author | Dimitry Andric <dim@FreeBSD.org> | 2017-12-18 20:10:56 +0000 |
---|---|---|
committer | Dimitry Andric <dim@FreeBSD.org> | 2017-12-18 20:10:56 +0000 |
commit | 044eb2f6afba375a914ac9d8024f8f5142bb912e (patch) | |
tree | 1475247dc9f9fe5be155ebd4c9069c75aadf8c20 /lib/Transforms | |
parent | eb70dddbd77e120e5d490bd8fbe7ff3f8fa81c6b (diff) | |
Diffstat (limited to 'lib/Transforms')
172 files changed, 25164 insertions, 9915 deletions
diff --git a/lib/Transforms/Coroutines/CoroFrame.cpp b/lib/Transforms/Coroutines/CoroFrame.cpp index 85e9003ec3c5..6334256bf03a 100644 --- a/lib/Transforms/Coroutines/CoroFrame.cpp +++ b/lib/Transforms/Coroutines/CoroFrame.cpp @@ -261,21 +261,19 @@ SuspendCrossingInfo::SuspendCrossingInfo(Function &F, coro::Shape &Shape) // We build up the list of spills for every case where a use is separated // from the definition by a suspend point. -struct Spill : std::pair<Value *, Instruction *> { - using base = std::pair<Value *, Instruction *>; - - Spill(Value *Def, User *U) : base(Def, cast<Instruction>(U)) {} - - Value *def() const { return first; } - Instruction *user() const { return second; } - BasicBlock *userBlock() const { return second->getParent(); } +namespace { +class Spill { + Value *Def; + Instruction *User; - std::pair<Value *, BasicBlock *> getKey() const { - return {def(), userBlock()}; - } +public: + Spill(Value *Def, llvm::User *U) : Def(Def), User(cast<Instruction>(U)) {} - bool operator<(Spill const &rhs) const { return getKey() < rhs.getKey(); } + Value *def() const { return Def; } + Instruction *user() const { return User; } + BasicBlock *userBlock() const { return User->getParent(); } }; +} // namespace // Note that there may be more than one record with the same value of Def in // the SpillInfo vector. @@ -673,8 +671,8 @@ static bool materializable(Instruction &V) { // Check for structural coroutine intrinsics that should not be spilled into // the coroutine frame. static bool isCoroutineStructureIntrinsic(Instruction &I) { - return isa<CoroIdInst>(&I) || isa<CoroBeginInst>(&I) || - isa<CoroSaveInst>(&I) || isa<CoroSuspendInst>(&I); + return isa<CoroIdInst>(&I) || isa<CoroSaveInst>(&I) || + isa<CoroSuspendInst>(&I); } // For every use of the value that is across suspend point, recreate that value @@ -839,7 +837,7 @@ void coro::buildCoroutineFrame(Function &F, Shape &Shape) { for (Instruction &I : instructions(F)) { // Values returned from coroutine structure intrinsics should not be part // of the Coroutine Frame. - if (isCoroutineStructureIntrinsic(I)) + if (isCoroutineStructureIntrinsic(I) || &I == Shape.CoroBegin) continue; // The Coroutine Promise always included into coroutine frame, no need to // check for suspend crossing. diff --git a/lib/Transforms/Coroutines/CoroSplit.cpp b/lib/Transforms/Coroutines/CoroSplit.cpp index 173dc05f0584..8712ca4823c6 100644 --- a/lib/Transforms/Coroutines/CoroSplit.cpp +++ b/lib/Transforms/Coroutines/CoroSplit.cpp @@ -19,17 +19,53 @@ // coroutine. 
//===----------------------------------------------------------------------===// +#include "CoroInstr.h" #include "CoroInternal.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" -#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/CallingConv.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" #include "llvm/IR/Verifier.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <initializer_list> +#include <iterator> using namespace llvm; @@ -342,7 +378,6 @@ static void replaceFrameSize(coro::Shape &Shape) { // Assumes that all the functions have the same signature. static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin, std::initializer_list<Function *> Fns) { - SmallVector<Constant *, 4> Args(Fns.begin(), Fns.end()); assert(!Args.empty()); Function *Part = *Fns.begin(); @@ -363,7 +398,6 @@ static void setCoroInfo(Function &F, CoroBeginInst *CoroBegin, // Store addresses of Resume/Destroy/Cleanup functions in the coroutine frame. static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn, Function *DestroyFn, Function *CleanupFn) { - IRBuilder<> Builder(Shape.FramePtr->getNextNode()); auto *ResumeAddr = Builder.CreateConstInBoundsGEP2_32( Shape.FrameTy, Shape.FramePtr, 0, coro::Shape::ResumeField, @@ -387,7 +421,7 @@ static void updateCoroFrame(coro::Shape &Shape, Function *ResumeFn, static void postSplitCleanup(Function &F) { removeUnreachableBlocks(F); - llvm::legacy::FunctionPassManager FPM(F.getParent()); + legacy::FunctionPassManager FPM(F.getParent()); FPM.add(createVerifierPass()); FPM.add(createSCCPPass()); @@ -400,6 +434,91 @@ static void postSplitCleanup(Function &F) { FPM.doFinalization(); } +// Assuming we arrived at the block NewBlock from Prev instruction, store +// PHI's incoming values in the ResolvedValues map. +static void +scanPHIsAndUpdateValueMap(Instruction *Prev, BasicBlock *NewBlock, + DenseMap<Value *, Value *> &ResolvedValues) { + auto *PrevBB = Prev->getParent(); + auto *I = &*NewBlock->begin(); + while (auto PN = dyn_cast<PHINode>(I)) { + auto V = PN->getIncomingValueForBlock(PrevBB); + // See if we already resolved it. + auto VI = ResolvedValues.find(V); + if (VI != ResolvedValues.end()) + V = VI->second; + // Remember the value. 
+ ResolvedValues[PN] = V; + I = I->getNextNode(); + } +} + +// Replace a sequence of branches leading to a ret, with a clone of a ret +// instruction. Suspend instruction represented by a switch, track the PHI +// values and select the correct case successor when possible. +static bool simplifyTerminatorLeadingToRet(Instruction *InitialInst) { + DenseMap<Value *, Value *> ResolvedValues; + + Instruction *I = InitialInst; + while (isa<TerminatorInst>(I)) { + if (isa<ReturnInst>(I)) { + if (I != InitialInst) + ReplaceInstWithInst(InitialInst, I->clone()); + return true; + } + if (auto *BR = dyn_cast<BranchInst>(I)) { + if (BR->isUnconditional()) { + BasicBlock *BB = BR->getSuccessor(0); + scanPHIsAndUpdateValueMap(I, BB, ResolvedValues); + I = BB->getFirstNonPHIOrDbgOrLifetime(); + continue; + } + } else if (auto *SI = dyn_cast<SwitchInst>(I)) { + Value *V = SI->getCondition(); + auto it = ResolvedValues.find(V); + if (it != ResolvedValues.end()) + V = it->second; + if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) { + BasicBlock *BB = SI->findCaseValue(Cond)->getCaseSuccessor(); + scanPHIsAndUpdateValueMap(I, BB, ResolvedValues); + I = BB->getFirstNonPHIOrDbgOrLifetime(); + continue; + } + } + return false; + } + return false; +} + +// Add musttail to any resume instructions that is immediately followed by a +// suspend (i.e. ret). We do this even in -O0 to support guaranteed tail call +// for symmetrical coroutine control transfer (C++ Coroutines TS extension). +// This transformation is done only in the resume part of the coroutine that has +// identical signature and calling convention as the coro.resume call. +static void addMustTailToCoroResumes(Function &F) { + bool changed = false; + + // Collect potential resume instructions. + SmallVector<CallInst *, 4> Resumes; + for (auto &I : instructions(F)) + if (auto *Call = dyn_cast<CallInst>(&I)) + if (auto *CalledValue = Call->getCalledValue()) + // CoroEarly pass replaced coro resumes with indirect calls to an + // address return by CoroSubFnInst intrinsic. See if it is one of those. + if (isa<CoroSubFnInst>(CalledValue->stripPointerCasts())) + Resumes.push_back(Call); + + // Set musttail on those that are followed by a ret instruction. + for (CallInst *Call : Resumes) + if (simplifyTerminatorLeadingToRet(Call->getNextNode())) { + Call->setTailCallKind(CallInst::TCK_MustTail); + changed = true; + } + + if (changed) + removeUnreachableBlocks(F); +} + // Coroutine has no suspend points. Remove heap allocation for the coroutine // frame if possible. static void handleNoSuspendCoroutine(CoroBeginInst *CoroBegin, Type *FrameTy) { @@ -488,7 +607,7 @@ static void simplifySuspendPoints(coro::Shape &Shape) { size_t I = 0, N = S.size(); if (N == 0) return; - for (;;) { + while (true) { if (simplifySuspendPoint(S[I], Shape.CoroBegin)) { if (--N == I) break; @@ -608,6 +727,8 @@ static void splitCoroutine(Function &F, CallGraph &CG, CallGraphSCC &SCC) { postSplitCleanup(*DestroyClone); postSplitCleanup(*CleanupClone); + addMustTailToCoroResumes(*ResumeClone); + // Store addresses resume/destroy/cleanup functions in the coroutine frame. 
updateCoroFrame(Shape, ResumeClone, DestroyClone, CleanupClone); @@ -681,6 +802,7 @@ namespace { struct CoroSplit : public CallGraphSCCPass { static char ID; // Pass identification, replacement for typeid + CoroSplit() : CallGraphSCCPass(ID) { initializeCoroSplitPass(*PassRegistry::getPassRegistry()); } @@ -729,11 +851,14 @@ struct CoroSplit : public CallGraphSCCPass { void getAnalysisUsage(AnalysisUsage &AU) const override { CallGraphSCCPass::getAnalysisUsage(AU); } + StringRef getPassName() const override { return "Coroutine Splitting"; } }; -} + +} // end anonymous namespace char CoroSplit::ID = 0; + INITIALIZE_PASS( CoroSplit, "coro-split", "Split coroutine into a set of functions driving its state machine", false, diff --git a/lib/Transforms/Coroutines/Coroutines.cpp b/lib/Transforms/Coroutines/Coroutines.cpp index 44e1f9b404ed..10411c1bd65d 100644 --- a/lib/Transforms/Coroutines/Coroutines.cpp +++ b/lib/Transforms/Coroutines/Coroutines.cpp @@ -1,4 +1,4 @@ -//===-- Coroutines.cpp ----------------------------------------------------===// +//===- Coroutines.cpp -----------------------------------------------------===// // // The LLVM Compiler Infrastructure // @@ -6,18 +6,38 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// // This file implements the common infrastructure for Coroutine Passes. +// //===----------------------------------------------------------------------===// +#include "CoroInstr.h" #include "CoroInternal.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/Verifier.h" -#include "llvm/InitializePasses.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Transforms/Coroutines.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/PassManagerBuilder.h" #include "llvm/Transforms/Utils/Local.h" +#include <cassert> +#include <cstddef> +#include <utility> using namespace llvm; @@ -117,7 +137,6 @@ static bool isCoroutineIntrinsicName(StringRef Name) { // that names are intrinsic names. 
bool coro::declaresIntrinsics(Module &M, std::initializer_list<StringRef> List) { - for (StringRef Name : List) { assert(isCoroutineIntrinsicName(Name) && "not a coroutine intrinsic"); if (M.getNamedValue(Name)) diff --git a/lib/Transforms/IPO/AlwaysInliner.cpp b/lib/Transforms/IPO/AlwaysInliner.cpp index b7d96007c24a..a4bbc99b1f90 100644 --- a/lib/Transforms/IPO/AlwaysInliner.cpp +++ b/lib/Transforms/IPO/AlwaysInliner.cpp @@ -15,15 +15,12 @@ #include "llvm/Transforms/IPO/AlwaysInliner.h" #include "llvm/ADT/SetVector.h" #include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineCost.h" -#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/CallingConv.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Transforms/IPO.h" diff --git a/lib/Transforms/IPO/ArgumentPromotion.cpp b/lib/Transforms/IPO/ArgumentPromotion.cpp index 72bae203ee94..b25cbcad3b9d 100644 --- a/lib/Transforms/IPO/ArgumentPromotion.cpp +++ b/lib/Transforms/IPO/ArgumentPromotion.cpp @@ -1,4 +1,4 @@ -//===-- ArgumentPromotion.cpp - Promote by-reference arguments ------------===// +//===- ArgumentPromotion.cpp - Promote by-reference arguments -------------===// // // The LLVM Compiler Infrastructure // @@ -31,30 +31,59 @@ #include "llvm/Transforms/IPO/ArgumentPromotion.h" #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/Twine.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/CGSCCPassManager.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/LazyCallGraph.h" #include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <functional> +#include <iterator> +#include <map> #include <set> +#include <string> +#include <utility> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "argpromotion" @@ -65,7 +94,7 @@ STATISTIC(NumByValArgsPromoted, "Number of byval arguments promoted"); STATISTIC(NumArgumentsDead, "Number of dead pointer args eliminated"); /// A vector used to hold the indices of a single GEP instruction -typedef std::vector<uint64_t> 
IndicesVector; +using IndicesVector = std::vector<uint64_t>; /// DoPromotion - This method actually performs the promotion of the specified /// arguments, and returns the new function. At this point, we know that it's @@ -75,13 +104,12 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, SmallPtrSetImpl<Argument *> &ByValArgsToTransform, Optional<function_ref<void(CallSite OldCS, CallSite NewCS)>> ReplaceCallSite) { - // Start by computing a new prototype for the function, which is the same as // the old function, but has modified arguments. FunctionType *FTy = F->getFunctionType(); std::vector<Type *> Params; - typedef std::set<std::pair<Type *, IndicesVector>> ScalarizeTable; + using ScalarizeTable = std::set<std::pair<Type *, IndicesVector>>; // ScalarizedElements - If we are promoting a pointer that has elements // accessed out of it, keep track of which elements are accessed so that we @@ -89,7 +117,6 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, // // Arguments that are directly loaded will have a zero element value here, to // handle cases where there are both a direct load and GEP accesses. - // std::map<Argument *, ScalarizeTable> ScalarizedElements; // OriginalLoads - Keep track of a representative load instruction from the @@ -335,7 +362,6 @@ doPromotion(Function *F, SmallPtrSetImpl<Argument *> &ArgsToPromote, // Loop over the argument list, transferring uses of the old arguments over to // the new arguments, also transferring over the names as well. - // for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end(), I2 = NF->arg_begin(); I != E; ++I) { @@ -537,7 +563,7 @@ static void markIndicesSafe(const IndicesVector &ToMark, /// arguments passed in. static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, AAResults &AAR, unsigned MaxElements) { - typedef std::set<IndicesVector> GEPIndicesSet; + using GEPIndicesSet = std::set<IndicesVector>; // Quick exit for unused arguments if (Arg->use_empty()) @@ -693,7 +719,7 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, BasicBlock *BB = Load->getParent(); MemoryLocation Loc = MemoryLocation::get(Load); - if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, MRI_Mod)) + if (AAR.canInstructionRangeModRef(BB->front(), *Load, Loc, ModRefInfo::Mod)) return false; // Pointer is invalidated! // Now check every path from the entry block to the load for transparency. @@ -714,7 +740,6 @@ static bool isSafeToPromoteArgument(Argument *Arg, bool isByValOrInAlloca, /// \brief Checks if a type could have padding bytes. static bool isDenselyPacked(Type *type, const DataLayout &DL) { - // There is no size information, so be conservative. if (!type->isSized()) return false; @@ -749,7 +774,6 @@ static bool isDenselyPacked(Type *type, const DataLayout &DL) { /// \brief Checks if the padding bytes of an argument could be accessed. static bool canPaddingBeAccessed(Argument *arg) { - assert(arg->hasByValAttr()); // Track all the pointers to the argument to make sure they are not captured. @@ -788,7 +812,6 @@ static bool canPaddingBeAccessed(Argument *arg) { /// are any promotable arguments and if it is safe to promote the function (for /// example, all callers are direct). If safe to promote some arguments, it /// calls the DoPromotion method. 
-/// static Function * promoteArguments(Function *F, function_ref<AAResults &(Function &F)> AARGetter, unsigned MaxElements, @@ -964,9 +987,17 @@ PreservedAnalyses ArgumentPromotionPass::run(LazyCallGraph::SCC &C, } namespace { + /// ArgPromotion - The 'by reference' to 'by value' argument promotion pass. -/// struct ArgPromotion : public CallGraphSCCPass { + // Pass identification, replacement for typeid + static char ID; + + explicit ArgPromotion(unsigned MaxElements = 3) + : CallGraphSCCPass(ID), MaxElements(MaxElements) { + initializeArgPromotionPass(*PassRegistry::getPassRegistry()); + } + void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); @@ -975,21 +1006,20 @@ struct ArgPromotion : public CallGraphSCCPass { } bool runOnSCC(CallGraphSCC &SCC) override; - static char ID; // Pass identification, replacement for typeid - explicit ArgPromotion(unsigned MaxElements = 3) - : CallGraphSCCPass(ID), MaxElements(MaxElements) { - initializeArgPromotionPass(*PassRegistry::getPassRegistry()); - } private: using llvm::Pass::doInitialization; + bool doInitialization(CallGraph &CG) override; + /// The maximum number of elements to expand, or 0 for unlimited. unsigned MaxElements; }; -} + +} // end anonymous namespace char ArgPromotion::ID = 0; + INITIALIZE_PASS_BEGIN(ArgPromotion, "argpromotion", "Promote 'by reference' arguments to scalars", false, false) diff --git a/lib/Transforms/IPO/CMakeLists.txt b/lib/Transforms/IPO/CMakeLists.txt index 67f18a307b9b..397561746f86 100644 --- a/lib/Transforms/IPO/CMakeLists.txt +++ b/lib/Transforms/IPO/CMakeLists.txt @@ -2,6 +2,7 @@ add_llvm_library(LLVMipo AlwaysInliner.cpp ArgumentPromotion.cpp BarrierNoopPass.cpp + CalledValuePropagation.cpp ConstantMerge.cpp CrossDSOCFI.cpp DeadArgumentElimination.cpp diff --git a/lib/Transforms/IPO/CalledValuePropagation.cpp b/lib/Transforms/IPO/CalledValuePropagation.cpp new file mode 100644 index 000000000000..c5f6336aa2be --- /dev/null +++ b/lib/Transforms/IPO/CalledValuePropagation.cpp @@ -0,0 +1,423 @@ +//===- CalledValuePropagation.cpp - Propagate called values -----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a transformation that attaches !callees metadata to +// indirect call sites. For a given call site, the metadata, if present, +// indicates the set of functions the call site could possibly target at +// run-time. This metadata is added to indirect call sites when the set of +// possible targets can be determined by analysis and is known to be small. The +// analysis driving the transformation is similar to constant propagation and +// makes uses of the generic sparse propagation solver. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/IPO/CalledValuePropagation.h" +#include "llvm/Analysis/SparsePropagation.h" +#include "llvm/Analysis/ValueLatticeUtils.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/Transforms/IPO.h" +using namespace llvm; + +#define DEBUG_TYPE "called-value-propagation" + +/// The maximum number of functions to track per lattice value. 
Once the number +/// of functions a call site can possibly target exceeds this threshold, it's +/// lattice value becomes overdefined. The number of possible lattice values is +/// bounded by Ch(F, M), where F is the number of functions in the module and M +/// is MaxFunctionsPerValue. As such, this value should be kept very small. We +/// likely can't do anything useful for call sites with a large number of +/// possible targets, anyway. +static cl::opt<unsigned> MaxFunctionsPerValue( + "cvp-max-functions-per-value", cl::Hidden, cl::init(4), + cl::desc("The maximum number of functions to track per lattice value")); + +namespace { +/// To enable interprocedural analysis, we assign LLVM values to the following +/// groups. The register group represents SSA registers, the return group +/// represents the return values of functions, and the memory group represents +/// in-memory values. An LLVM Value can technically be in more than one group. +/// It's necessary to distinguish these groups so we can, for example, track a +/// global variable separately from the value stored at its location. +enum class IPOGrouping { Register, Return, Memory }; + +/// Our LatticeKeys are PointerIntPairs composed of LLVM values and groupings. +using CVPLatticeKey = PointerIntPair<Value *, 2, IPOGrouping>; + +/// The lattice value type used by our custom lattice function. It holds the +/// lattice state, and a set of functions. +class CVPLatticeVal { +public: + /// The states of the lattice values. Only the FunctionSet state is + /// interesting. It indicates the set of functions to which an LLVM value may + /// refer. + enum CVPLatticeStateTy { Undefined, FunctionSet, Overdefined, Untracked }; + + /// Comparator for sorting the functions set. We want to keep the order + /// deterministic for testing, etc. + struct Compare { + bool operator()(const Function *LHS, const Function *RHS) const { + return LHS->getName() < RHS->getName(); + } + }; + + CVPLatticeVal() : LatticeState(Undefined) {} + CVPLatticeVal(CVPLatticeStateTy LatticeState) : LatticeState(LatticeState) {} + CVPLatticeVal(std::set<Function *, Compare> &&Functions) + : LatticeState(FunctionSet), Functions(Functions) {} + + /// Get a reference to the functions held by this lattice value. The number + /// of functions will be zero for states other than FunctionSet. + const std::set<Function *, Compare> &getFunctions() const { + return Functions; + } + + /// Returns true if the lattice value is in the FunctionSet state. + bool isFunctionSet() const { return LatticeState == FunctionSet; } + + bool operator==(const CVPLatticeVal &RHS) const { + return LatticeState == RHS.LatticeState && Functions == RHS.Functions; + } + + bool operator!=(const CVPLatticeVal &RHS) const { + return LatticeState != RHS.LatticeState || Functions != RHS.Functions; + } + +private: + /// Holds the state this lattice value is in. + CVPLatticeStateTy LatticeState; + + /// Holds functions indicating the possible targets of call sites. This set + /// is empty for lattice values in the undefined, overdefined, and untracked + /// states. The maximum size of the set is controlled by + /// MaxFunctionsPerValue. Since most LLVM values are expected to be in + /// uninteresting states (i.e., overdefined), CVPLatticeVal objects should be + /// small and efficiently copyable. + std::set<Function *, Compare> Functions; +}; + +/// The custom lattice function used by the generic sparse propagation solver. 
+/// It handles merging lattice values and computing new lattice values for +/// constants, arguments, values returned from trackable functions, and values +/// located in trackable global variables. It also computes the lattice values +/// that change as a result of executing instructions. +class CVPLatticeFunc + : public AbstractLatticeFunction<CVPLatticeKey, CVPLatticeVal> { +public: + CVPLatticeFunc() + : AbstractLatticeFunction(CVPLatticeVal(CVPLatticeVal::Undefined), + CVPLatticeVal(CVPLatticeVal::Overdefined), + CVPLatticeVal(CVPLatticeVal::Untracked)) {} + + /// Compute and return a CVPLatticeVal for the given CVPLatticeKey. + CVPLatticeVal ComputeLatticeVal(CVPLatticeKey Key) override { + switch (Key.getInt()) { + case IPOGrouping::Register: + if (isa<Instruction>(Key.getPointer())) { + return getUndefVal(); + } else if (auto *A = dyn_cast<Argument>(Key.getPointer())) { + if (canTrackArgumentsInterprocedurally(A->getParent())) + return getUndefVal(); + } else if (auto *C = dyn_cast<Constant>(Key.getPointer())) { + return computeConstant(C); + } + return getOverdefinedVal(); + case IPOGrouping::Memory: + case IPOGrouping::Return: + if (auto *GV = dyn_cast<GlobalVariable>(Key.getPointer())) { + if (canTrackGlobalVariableInterprocedurally(GV)) + return computeConstant(GV->getInitializer()); + } else if (auto *F = cast<Function>(Key.getPointer())) + if (canTrackReturnsInterprocedurally(F)) + return getUndefVal(); + } + return getOverdefinedVal(); + } + + /// Merge the two given lattice values. The interesting cases are merging two + /// FunctionSet values and a FunctionSet value with an Undefined value. For + /// these cases, we simply union the function sets. If the size of the union + /// is greater than the maximum functions we track, the merged value is + /// overdefined. + CVPLatticeVal MergeValues(CVPLatticeVal X, CVPLatticeVal Y) override { + if (X == getOverdefinedVal() || Y == getOverdefinedVal()) + return getOverdefinedVal(); + if (X == getUndefVal() && Y == getUndefVal()) + return getUndefVal(); + std::set<Function *, CVPLatticeVal::Compare> Union; + std::set_union(X.getFunctions().begin(), X.getFunctions().end(), + Y.getFunctions().begin(), Y.getFunctions().end(), + std::inserter(Union, Union.begin()), + CVPLatticeVal::Compare{}); + if (Union.size() > MaxFunctionsPerValue) + return getOverdefinedVal(); + return CVPLatticeVal(std::move(Union)); + } + + /// Compute the lattice values that change as a result of executing the given + /// instruction. The changed values are stored in \p ChangedValues. We handle + /// just a few kinds of instructions since we're only propagating values that + /// can be called. + void ComputeInstructionState( + Instruction &I, DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues, + SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) override { + switch (I.getOpcode()) { + case Instruction::Call: + return visitCallSite(cast<CallInst>(&I), ChangedValues, SS); + case Instruction::Invoke: + return visitCallSite(cast<InvokeInst>(&I), ChangedValues, SS); + case Instruction::Load: + return visitLoad(*cast<LoadInst>(&I), ChangedValues, SS); + case Instruction::Ret: + return visitReturn(*cast<ReturnInst>(&I), ChangedValues, SS); + case Instruction::Select: + return visitSelect(*cast<SelectInst>(&I), ChangedValues, SS); + case Instruction::Store: + return visitStore(*cast<StoreInst>(&I), ChangedValues, SS); + default: + return visitInst(I, ChangedValues, SS); + } + } + + /// Print the given CVPLatticeVal to the specified stream. 
+ void PrintLatticeVal(CVPLatticeVal LV, raw_ostream &OS) override { + if (LV == getUndefVal()) + OS << "Undefined "; + else if (LV == getOverdefinedVal()) + OS << "Overdefined"; + else if (LV == getUntrackedVal()) + OS << "Untracked "; + else + OS << "FunctionSet"; + } + + /// Print the given CVPLatticeKey to the specified stream. + void PrintLatticeKey(CVPLatticeKey Key, raw_ostream &OS) override { + if (Key.getInt() == IPOGrouping::Register) + OS << "<reg> "; + else if (Key.getInt() == IPOGrouping::Memory) + OS << "<mem> "; + else if (Key.getInt() == IPOGrouping::Return) + OS << "<ret> "; + if (isa<Function>(Key.getPointer())) + OS << Key.getPointer()->getName(); + else + OS << *Key.getPointer(); + } + + /// We collect a set of indirect calls when visiting call sites. This method + /// returns a reference to that set. + SmallPtrSetImpl<Instruction *> &getIndirectCalls() { return IndirectCalls; } + +private: + /// Holds the indirect calls we encounter during the analysis. We will attach + /// metadata to these calls after the analysis indicating the functions the + /// calls can possibly target. + SmallPtrSet<Instruction *, 32> IndirectCalls; + + /// Compute a new lattice value for the given constant. The constant, after + /// stripping any pointer casts, should be a Function. We ignore null + /// pointers as an optimization, since calling these values is undefined + /// behavior. + CVPLatticeVal computeConstant(Constant *C) { + if (isa<ConstantPointerNull>(C)) + return CVPLatticeVal(CVPLatticeVal::FunctionSet); + if (auto *F = dyn_cast<Function>(C->stripPointerCasts())) + return CVPLatticeVal({F}); + return getOverdefinedVal(); + } + + /// Handle return instructions. The function's return state is the merge of + /// the returned value state and the function's return state. + void visitReturn(ReturnInst &I, + DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues, + SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) { + Function *F = I.getParent()->getParent(); + if (F->getReturnType()->isVoidTy()) + return; + auto RegI = CVPLatticeKey(I.getReturnValue(), IPOGrouping::Register); + auto RetF = CVPLatticeKey(F, IPOGrouping::Return); + ChangedValues[RetF] = + MergeValues(SS.getValueState(RegI), SS.getValueState(RetF)); + } + + /// Handle call sites. The state of a called function's formal arguments is + /// the merge of the argument state with the call sites corresponding actual + /// argument state. The call site state is the merge of the call site state + /// with the returned value state of the called function. + void visitCallSite(CallSite CS, + DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues, + SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) { + Function *F = CS.getCalledFunction(); + Instruction *I = CS.getInstruction(); + auto RegI = CVPLatticeKey(I, IPOGrouping::Register); + + // If this is an indirect call, save it so we can quickly revisit it when + // attaching metadata. + if (!F) + IndirectCalls.insert(I); + + // If we can't track the function's return values, there's nothing to do. + if (!F || !canTrackReturnsInterprocedurally(F)) { + ChangedValues[RegI] = getOverdefinedVal(); + return; + } + + // Inform the solver that the called function is executable, and perform + // the merges for the arguments and return value. 
+ SS.MarkBlockExecutable(&F->front()); + auto RetF = CVPLatticeKey(F, IPOGrouping::Return); + for (Argument &A : F->args()) { + auto RegFormal = CVPLatticeKey(&A, IPOGrouping::Register); + auto RegActual = + CVPLatticeKey(CS.getArgument(A.getArgNo()), IPOGrouping::Register); + ChangedValues[RegFormal] = + MergeValues(SS.getValueState(RegFormal), SS.getValueState(RegActual)); + } + ChangedValues[RegI] = + MergeValues(SS.getValueState(RegI), SS.getValueState(RetF)); + } + + /// Handle select instructions. The select instruction state is the merge the + /// true and false value states. + void visitSelect(SelectInst &I, + DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues, + SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) { + auto RegI = CVPLatticeKey(&I, IPOGrouping::Register); + auto RegT = CVPLatticeKey(I.getTrueValue(), IPOGrouping::Register); + auto RegF = CVPLatticeKey(I.getFalseValue(), IPOGrouping::Register); + ChangedValues[RegI] = + MergeValues(SS.getValueState(RegT), SS.getValueState(RegF)); + } + + /// Handle load instructions. If the pointer operand of the load is a global + /// variable, we attempt to track the value. The loaded value state is the + /// merge of the loaded value state with the global variable state. + void visitLoad(LoadInst &I, + DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues, + SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) { + auto RegI = CVPLatticeKey(&I, IPOGrouping::Register); + if (auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand())) { + auto MemGV = CVPLatticeKey(GV, IPOGrouping::Memory); + ChangedValues[RegI] = + MergeValues(SS.getValueState(RegI), SS.getValueState(MemGV)); + } else { + ChangedValues[RegI] = getOverdefinedVal(); + } + } + + /// Handle store instructions. If the pointer operand of the store is a + /// global variable, we attempt to track the value. The global variable state + /// is the merge of the stored value state with the global variable state. + void visitStore(StoreInst &I, + DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues, + SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) { + auto *GV = dyn_cast<GlobalVariable>(I.getPointerOperand()); + if (!GV) + return; + auto RegI = CVPLatticeKey(I.getValueOperand(), IPOGrouping::Register); + auto MemGV = CVPLatticeKey(GV, IPOGrouping::Memory); + ChangedValues[MemGV] = + MergeValues(SS.getValueState(RegI), SS.getValueState(MemGV)); + } + + /// Handle all other instructions. All other instructions are marked + /// overdefined. + void visitInst(Instruction &I, + DenseMap<CVPLatticeKey, CVPLatticeVal> &ChangedValues, + SparseSolver<CVPLatticeKey, CVPLatticeVal> &SS) { + auto RegI = CVPLatticeKey(&I, IPOGrouping::Register); + ChangedValues[RegI] = getOverdefinedVal(); + } +}; +} // namespace + +namespace llvm { +/// A specialization of LatticeKeyInfo for CVPLatticeKeys. The generic solver +/// must translate between LatticeKeys and LLVM Values when adding Values to +/// its work list and inspecting the state of control-flow related values. +template <> struct LatticeKeyInfo<CVPLatticeKey> { + static inline Value *getValueFromLatticeKey(CVPLatticeKey Key) { + return Key.getPointer(); + } + static inline CVPLatticeKey getLatticeKeyFromValue(Value *V) { + return CVPLatticeKey(V, IPOGrouping::Register); + } +}; +} // namespace llvm + +static bool runCVP(Module &M) { + // Our custom lattice function and generic sparse propagation solver. 
+ CVPLatticeFunc Lattice; + SparseSolver<CVPLatticeKey, CVPLatticeVal> Solver(&Lattice); + + // For each function in the module, if we can't track its arguments, let the + // generic solver assume it is executable. + for (Function &F : M) + if (!F.isDeclaration() && !canTrackArgumentsInterprocedurally(&F)) + Solver.MarkBlockExecutable(&F.front()); + + // Solver our custom lattice. In doing so, we will also build a set of + // indirect call sites. + Solver.Solve(); + + // Attach metadata to the indirect call sites that were collected indicating + // the set of functions they can possibly target. + bool Changed = false; + MDBuilder MDB(M.getContext()); + for (Instruction *C : Lattice.getIndirectCalls()) { + CallSite CS(C); + auto RegI = CVPLatticeKey(CS.getCalledValue(), IPOGrouping::Register); + CVPLatticeVal LV = Solver.getExistingValueState(RegI); + if (!LV.isFunctionSet() || LV.getFunctions().empty()) + continue; + MDNode *Callees = MDB.createCallees(SmallVector<Function *, 4>( + LV.getFunctions().begin(), LV.getFunctions().end())); + C->setMetadata(LLVMContext::MD_callees, Callees); + Changed = true; + } + + return Changed; +} + +PreservedAnalyses CalledValuePropagationPass::run(Module &M, + ModuleAnalysisManager &) { + runCVP(M); + return PreservedAnalyses::all(); +} + +namespace { +class CalledValuePropagationLegacyPass : public ModulePass { +public: + static char ID; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } + + CalledValuePropagationLegacyPass() : ModulePass(ID) { + initializeCalledValuePropagationLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + return runCVP(M); + } +}; +} // namespace + +char CalledValuePropagationLegacyPass::ID = 0; +INITIALIZE_PASS(CalledValuePropagationLegacyPass, "called-value-propagation", + "Called Value Propagation", false, false) + +ModulePass *llvm::createCalledValuePropagationPass() { + return new CalledValuePropagationLegacyPass(); +} diff --git a/lib/Transforms/IPO/ConstantMerge.cpp b/lib/Transforms/IPO/ConstantMerge.cpp index 62b5a9c9ba26..e0b1037053f0 100644 --- a/lib/Transforms/IPO/ConstantMerge.cpp +++ b/lib/Transforms/IPO/ConstantMerge.cpp @@ -19,16 +19,23 @@ #include "llvm/Transforms/IPO/ConstantMerge.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Operator.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Transforms/IPO.h" +#include <algorithm> +#include <cassert> +#include <utility> + using namespace llvm; #define DEBUG_TYPE "constmerge" @@ -102,8 +109,7 @@ static bool mergeConstants(Module &M) { // constants together may allow us to merge other constants together if the // second level constants have initializers which point to the globals that // were just merged. - while (1) { - + while (true) { // First: Find the canonical constants others will be merged with. 
for (Module::global_iterator GVI = M.global_begin(), E = M.global_end(); GVI != E; ) { @@ -225,23 +231,27 @@ PreservedAnalyses ConstantMergePass::run(Module &M, ModuleAnalysisManager &) { } namespace { + struct ConstantMergeLegacyPass : public ModulePass { static char ID; // Pass identification, replacement for typeid + ConstantMergeLegacyPass() : ModulePass(ID) { initializeConstantMergeLegacyPassPass(*PassRegistry::getPassRegistry()); } // For this pass, process all of the globals in the module, eliminating // duplicate constants. - bool runOnModule(Module &M) { + bool runOnModule(Module &M) override { if (skipModule(M)) return false; return mergeConstants(M); } }; -} + +} // end anonymous namespace char ConstantMergeLegacyPass::ID = 0; + INITIALIZE_PASS(ConstantMergeLegacyPass, "constmerge", "Merge Duplicate Global Constants", false, false) diff --git a/lib/Transforms/IPO/CrossDSOCFI.cpp b/lib/Transforms/IPO/CrossDSOCFI.cpp index d94aa5da8560..886029ea58d5 100644 --- a/lib/Transforms/IPO/CrossDSOCFI.cpp +++ b/lib/Transforms/IPO/CrossDSOCFI.cpp @@ -13,9 +13,9 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/CrossDSOCFI.h" -#include "llvm/ADT/DenseSet.h" -#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/Triple.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" @@ -31,7 +31,6 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" using namespace llvm; @@ -81,7 +80,7 @@ ConstantInt *CrossDSOCFI::extractNumericTypeId(MDNode *MD) { void CrossDSOCFI::buildCFICheck(Module &M) { // FIXME: verify that __cfi_check ends up near the end of the code section, // but before the jump slots created in LowerTypeTests. - llvm::DenseSet<uint64_t> TypeIds; + SetVector<uint64_t> TypeIds; SmallVector<MDNode *, 2> Types; for (GlobalObject &GO : M.global_objects()) { Types.clear(); @@ -115,6 +114,11 @@ void CrossDSOCFI::buildCFICheck(Module &M) { // linker knows about the symbol; this pass replaces the function body. 
F->deleteBody(); F->setAlignment(4096); + + Triple T(M.getTargetTriple()); + if (T.isARM() || T.isThumb()) + F->addFnAttr("target-features", "+thumb-mode"); + auto args = F->arg_begin(); Value &CallSiteTypeId = *(args++); CallSiteTypeId.setName("CallSiteTypeId"); diff --git a/lib/Transforms/IPO/DeadArgumentElimination.cpp b/lib/Transforms/IPO/DeadArgumentElimination.cpp index 8e26849ea9e3..5446541550e5 100644 --- a/lib/Transforms/IPO/DeadArgumentElimination.cpp +++ b/lib/Transforms/IPO/DeadArgumentElimination.cpp @@ -1,4 +1,4 @@ -//===-- DeadArgumentElimination.cpp - Eliminate dead arguments ------------===// +//===- DeadArgumentElimination.cpp - Eliminate dead arguments -------------===// // // The LLVM Compiler Infrastructure // @@ -20,24 +20,36 @@ #include "llvm/Transforms/IPO/DeadArgumentElimination.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringExtras.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" -#include "llvm/IR/CallingConv.h" #include "llvm/IR/Constant.h" -#include "llvm/IR/DIBuilder.h" -#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include <set> -#include <tuple> +#include <cassert> +#include <cstdint> +#include <utility> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "deadargelim" @@ -46,9 +58,10 @@ STATISTIC(NumArgumentsEliminated, "Number of unread args removed"); STATISTIC(NumRetValsEliminated , "Number of unused return values removed"); STATISTIC(NumArgumentsReplacedWithUndef, "Number of unread args replaced with undef"); + namespace { + /// DAE - The dead argument elimination pass. - /// class DAE : public ModulePass { protected: // DAH uses this to specify a different ID. @@ -56,6 +69,7 @@ namespace { public: static char ID; // Pass identification, replacement for typeid + DAE() : ModulePass(ID) { initializeDAEPass(*PassRegistry::getPassRegistry()); } @@ -71,33 +85,38 @@ namespace { virtual bool ShouldHackArguments() const { return false; } }; -} +} // end anonymous namespace char DAE::ID = 0; + INITIALIZE_PASS(DAE, "deadargelim", "Dead Argument Elimination", false, false) namespace { + /// DAH - DeadArgumentHacking pass - Same as dead argument elimination, but /// deletes arguments to functions which are external. This is only for use /// by bugpoint. struct DAH : public DAE { static char ID; + DAH() : DAE(ID) {} bool ShouldHackArguments() const override { return true; } }; -} + +} // end anonymous namespace char DAH::ID = 0; + INITIALIZE_PASS(DAH, "deadarghaX0r", "Dead Argument Hacking (BUGPOINT USE ONLY; DO NOT USE)", false, false) /// createDeadArgEliminationPass - This pass removes arguments from functions /// which are not used by the body of the function. 
-/// ModulePass *llvm::createDeadArgEliminationPass() { return new DAE(); } + ModulePass *llvm::createDeadArgHackingPass() { return new DAH(); } /// DeleteDeadVarargs - If this is an function that takes a ... list, and if @@ -140,7 +159,7 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { // the old function, but doesn't have isVarArg set. FunctionType *FTy = Fn.getFunctionType(); - std::vector<Type*> Params(FTy->param_begin(), FTy->param_end()); + std::vector<Type *> Params(FTy->param_begin(), FTy->param_end()); FunctionType *NFTy = FunctionType::get(FTy->getReturnType(), Params, false); unsigned NumArgs = Params.size(); @@ -155,7 +174,7 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { // Loop over all of the callers of the function, transforming the call sites // to pass in a smaller number of arguments into the new function. // - std::vector<Value*> Args; + std::vector<Value *> Args; for (Value::user_iterator I = Fn.user_begin(), E = Fn.user_end(); I != E; ) { CallSite CS(*I++); if (!CS) @@ -214,7 +233,6 @@ bool DeadArgumentEliminationPass::DeleteDeadVarargs(Function &Fn) { // Loop over the argument list, transferring uses of the old arguments over to // the new arguments, also transferring over the names as well. While we're at // it, remove the dead arguments from the DeadArguments list. - // for (Function::arg_iterator I = Fn.arg_begin(), E = Fn.arg_end(), I2 = NF->arg_begin(); I != E; ++I, ++I2) { // Move the name and users over to the new version. @@ -343,7 +361,6 @@ DeadArgumentEliminationPass::MarkIfNotLive(RetOrArg Use, return MaybeLive; } - /// SurveyUse - This looks at a single use of an argument or return value /// and determines if it should be alive or not. Adds this use to MaybeLiveUses /// if it causes the used value to become MaybeLive. @@ -460,7 +477,6 @@ DeadArgumentEliminationPass::SurveyUses(const Value *V, // // We consider arguments of non-internal functions to be intrinsically alive as // well as arguments to functions which have their "address taken". -// void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { // Functions with inalloca parameters are expecting args in a particular // register and memory layout. @@ -478,11 +494,14 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { } unsigned RetCount = NumRetVals(&F); + // Assume all return values are dead - typedef SmallVector<Liveness, 5> RetVals; + using RetVals = SmallVector<Liveness, 5>; + RetVals RetValLiveness(RetCount, MaybeLive); - typedef SmallVector<UseVector, 5> RetUses; + using RetUses = SmallVector<UseVector, 5>; + // These vectors map each return value to the uses that make it MaybeLive, so // we can add those to the Uses map if the return value really turns out to be // MaybeLive. Initialized to a list of RetCount empty lists. @@ -601,15 +620,15 @@ void DeadArgumentEliminationPass::SurveyFunction(const Function &F) { void DeadArgumentEliminationPass::MarkValue(const RetOrArg &RA, Liveness L, const UseVector &MaybeLiveUses) { switch (L) { - case Live: MarkLive(RA); break; + case Live: + MarkLive(RA); + break; case MaybeLive: - { // Note any uses of this value, so this return value can be // marked live whenever one of the uses becomes live. for (const auto &MaybeLiveUse : MaybeLiveUses) Uses.insert(std::make_pair(MaybeLiveUse, RA)); break; - } } } @@ -762,7 +781,7 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // One return type? 
Just a simple value then, but only if we didn't use to // return a struct with that simple value before. NRetTy = RetTypes.front(); - else if (RetTypes.size() == 0) + else if (RetTypes.empty()) // No return types? Make it void, but only if we didn't use to return {}. NRetTy = Type::getVoidTy(F->getContext()); } @@ -808,7 +827,6 @@ bool DeadArgumentEliminationPass::RemoveDeadStuffFromFunction(Function *F) { // Loop over all of the callers of the function, transforming the call sites // to pass in a smaller number of arguments into the new function. - // std::vector<Value*> Args; while (!F->use_empty()) { CallSite CS(F->user_back()); diff --git a/lib/Transforms/IPO/ElimAvailExtern.cpp b/lib/Transforms/IPO/ElimAvailExtern.cpp index ecff88c88dcb..d5fef59286dd 100644 --- a/lib/Transforms/IPO/ElimAvailExtern.cpp +++ b/lib/Transforms/IPO/ElimAvailExtern.cpp @@ -1,5 +1,4 @@ -//===-- ElimAvailExtern.cpp - DCE unreachable internal functions -//----------------===// +//===- ElimAvailExtern.cpp - DCE unreachable internal functions -----------===// // // The LLVM Compiler Infrastructure // @@ -15,11 +14,15 @@ #include "llvm/Transforms/IPO/ElimAvailExtern.h" #include "llvm/ADT/Statistic.h" -#include "llvm/IR/Constants.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/GlobalStatus.h" + using namespace llvm; #define DEBUG_TYPE "elim-avail-extern" @@ -69,8 +72,10 @@ EliminateAvailableExternallyPass::run(Module &M, ModuleAnalysisManager &) { } namespace { + struct EliminateAvailableExternallyLegacyPass : public ModulePass { static char ID; // Pass identification, replacement for typeid + EliminateAvailableExternallyLegacyPass() : ModulePass(ID) { initializeEliminateAvailableExternallyLegacyPassPass( *PassRegistry::getPassRegistry()); @@ -78,16 +83,17 @@ struct EliminateAvailableExternallyLegacyPass : public ModulePass { // run - Do the EliminateAvailableExternally pass on the specified module, // optionally updating the specified callgraph to reflect the changes. 
- // - bool runOnModule(Module &M) { + bool runOnModule(Module &M) override { if (skipModule(M)) return false; return eliminateAvailableExternally(M); } }; -} + +} // end anonymous namespace char EliminateAvailableExternallyLegacyPass::ID = 0; + INITIALIZE_PASS(EliminateAvailableExternallyLegacyPass, "elim-avail-extern", "Eliminate Available Externally Globals", false, false) diff --git a/lib/Transforms/IPO/ExtractGV.cpp b/lib/Transforms/IPO/ExtractGV.cpp index d1147f7d844b..042cacb70ad0 100644 --- a/lib/Transforms/IPO/ExtractGV.cpp +++ b/lib/Transforms/IPO/ExtractGV.cpp @@ -12,8 +12,6 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/SetVector.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" diff --git a/lib/Transforms/IPO/ForceFunctionAttrs.cpp b/lib/Transforms/IPO/ForceFunctionAttrs.cpp index 968712138208..325a5d77aadb 100644 --- a/lib/Transforms/IPO/ForceFunctionAttrs.cpp +++ b/lib/Transforms/IPO/ForceFunctionAttrs.cpp @@ -52,11 +52,13 @@ static Attribute::AttrKind parseAttrKind(StringRef Kind) { .Case("returns_twice", Attribute::ReturnsTwice) .Case("safestack", Attribute::SafeStack) .Case("sanitize_address", Attribute::SanitizeAddress) + .Case("sanitize_hwaddress", Attribute::SanitizeHWAddress) .Case("sanitize_memory", Attribute::SanitizeMemory) .Case("sanitize_thread", Attribute::SanitizeThread) .Case("ssp", Attribute::StackProtect) .Case("sspreq", Attribute::StackProtectReq) .Case("sspstrong", Attribute::StackProtectStrong) + .Case("strictfp", Attribute::StrictFP) .Case("uwtable", Attribute::UWTable) .Default(Attribute::None); } diff --git a/lib/Transforms/IPO/FunctionAttrs.cpp b/lib/Transforms/IPO/FunctionAttrs.cpp index 813a4b6e2831..5352e32479bb 100644 --- a/lib/Transforms/IPO/FunctionAttrs.cpp +++ b/lib/Transforms/IPO/FunctionAttrs.cpp @@ -6,34 +6,61 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// -/// +// /// \file /// This file implements interprocedural passes which walk the /// call-graph deducing and/or propagating function attributes. 
-/// +// //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/FunctionAttrs.h" #include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringSwitch.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/CGSCCPassManager.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/CallGraphSCCPass.h" #include "llvm/Analysis/CaptureTracking.h" -#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" +#include <cassert> +#include <iterator> +#include <map> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "functionattrs" @@ -57,8 +84,10 @@ static cl::opt<bool> EnableNonnullArgPropagation( "caller functions.")); namespace { -typedef SmallSetVector<Function *, 8> SCCNodeSet; -} + +using SCCNodeSet = SmallSetVector<Function *, 8>; + +} // end anonymous namespace /// Returns the memory access attribute for function F using AAR for AA results, /// where SCCNodes is the current SCC. @@ -101,17 +130,18 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, SCCNodes.count(CS.getCalledFunction())) continue; FunctionModRefBehavior MRB = AAR.getModRefBehavior(CS); + ModRefInfo MRI = createModRefInfo(MRB); // If the call doesn't access memory, we're done. - if (!(MRB & MRI_ModRef)) + if (isNoModRef(MRI)) continue; if (!AliasAnalysis::onlyAccessesArgPointees(MRB)) { // The call could access any memory. If that includes writes, give up. - if (MRB & MRI_Mod) + if (isModSet(MRI)) return MAK_MayWrite; // If it reads, note it. - if (MRB & MRI_Ref) + if (isRefSet(MRI)) ReadsMemory = true; continue; } @@ -133,10 +163,10 @@ static MemoryAccessKind checkFunctionMemoryAccess(Function &F, bool ThisBody, if (AAR.pointsToConstantMemory(Loc, /*OrLocal=*/true)) continue; - if (MRB & MRI_Mod) + if (isModSet(MRI)) // Writes non-local memory. Give up. return MAK_MayWrite; - if (MRB & MRI_Ref) + if (isRefSet(MRI)) // Ok, it reads non-local memory. ReadsMemory = true; } @@ -237,6 +267,7 @@ static bool addReadAttrs(const SCCNodeSet &SCCNodes, AARGetterT &&AARGetter) { } namespace { + /// For a given pointer Argument, this retains a list of Arguments of functions /// in the same SCC that the pointer data flows into. 
We use this to build an /// SCC of the arguments. @@ -248,7 +279,7 @@ struct ArgumentGraphNode { class ArgumentGraph { // We store pointers to ArgumentGraphNode objects, so it's important that // that they not move around upon insert. - typedef std::map<Argument *, ArgumentGraphNode> ArgumentMapTy; + using ArgumentMapTy = std::map<Argument *, ArgumentGraphNode>; ArgumentMapTy ArgumentMap; @@ -263,7 +294,7 @@ class ArgumentGraph { public: ArgumentGraph() { SyntheticRoot.Definition = nullptr; } - typedef SmallVectorImpl<ArgumentGraphNode *>::iterator iterator; + using iterator = SmallVectorImpl<ArgumentGraphNode *>::iterator; iterator begin() { return SyntheticRoot.Uses.begin(); } iterator end() { return SyntheticRoot.Uses.end(); } @@ -281,8 +312,7 @@ public: /// consider that a capture, instead adding it to the "Uses" list and /// continuing with the analysis. struct ArgumentUsesTracker : public CaptureTracker { - ArgumentUsesTracker(const SCCNodeSet &SCCNodes) - : Captured(false), SCCNodes(SCCNodes) {} + ArgumentUsesTracker(const SCCNodeSet &SCCNodes) : SCCNodes(SCCNodes) {} void tooManyUses() override { Captured = true; } @@ -331,37 +361,45 @@ struct ArgumentUsesTracker : public CaptureTracker { return false; } - bool Captured; // True only if certainly captured (used outside our SCC). - SmallVector<Argument *, 4> Uses; // Uses within our SCC. + // True only if certainly captured (used outside our SCC). + bool Captured = false; + + // Uses within our SCC. + SmallVector<Argument *, 4> Uses; const SCCNodeSet &SCCNodes; }; -} + +} // end anonymous namespace namespace llvm { + template <> struct GraphTraits<ArgumentGraphNode *> { - typedef ArgumentGraphNode *NodeRef; - typedef SmallVectorImpl<ArgumentGraphNode *>::iterator ChildIteratorType; + using NodeRef = ArgumentGraphNode *; + using ChildIteratorType = SmallVectorImpl<ArgumentGraphNode *>::iterator; static NodeRef getEntryNode(NodeRef A) { return A; } static ChildIteratorType child_begin(NodeRef N) { return N->Uses.begin(); } static ChildIteratorType child_end(NodeRef N) { return N->Uses.end(); } }; + template <> struct GraphTraits<ArgumentGraph *> : public GraphTraits<ArgumentGraphNode *> { static NodeRef getEntryNode(ArgumentGraph *AG) { return AG->getEntryNode(); } + static ChildIteratorType nodes_begin(ArgumentGraph *AG) { return AG->begin(); } + static ChildIteratorType nodes_end(ArgumentGraph *AG) { return AG->end(); } }; -} + +} // end namespace llvm /// Returns Attribute::None, Attribute::ReadOnly or Attribute::ReadNone. static Attribute::AttrKind determinePointerReadAttrs(Argument *A, const SmallPtrSet<Argument *, 8> &SCCNodes) { - SmallVector<Use *, 32> Worklist; SmallSet<Use *, 32> Visited; @@ -502,8 +540,8 @@ static bool addArgumentReturnedAttrs(const SCCNodeSet &SCCNodes) { continue; // There is nothing to do if an argument is already marked as 'returned'. 
- if (any_of(F->args(), - [](const Argument &Arg) { return Arg.hasReturnedAttr(); })) + if (llvm::any_of(F->args(), + [](const Argument &Arg) { return Arg.hasReturnedAttr(); })) continue; auto FindRetArg = [&]() -> Value * { @@ -884,11 +922,13 @@ static bool isReturnNonNull(Function *F, const SCCNodeSet &SCCNodes, if (auto *Ret = dyn_cast<ReturnInst>(BB.getTerminator())) FlowsToReturn.insert(Ret->getReturnValue()); + auto &DL = F->getParent()->getDataLayout(); + for (unsigned i = 0; i != FlowsToReturn.size(); ++i) { Value *RetVal = FlowsToReturn[i]; // If this value is locally known to be non-null, we're good - if (isKnownNonNull(RetVal)) + if (isKnownNonZero(RetVal, DL)) continue; // Otherwise, we need to look upwards since we can't make any local @@ -1135,8 +1175,11 @@ PreservedAnalyses PostOrderFunctionAttrsPass::run(LazyCallGraph::SCC &C, } namespace { + struct PostOrderFunctionAttrsLegacyPass : public CallGraphSCCPass { - static char ID; // Pass identification, replacement for typeid + // Pass identification, replacement for typeid + static char ID; + PostOrderFunctionAttrsLegacyPass() : CallGraphSCCPass(ID) { initializePostOrderFunctionAttrsLegacyPassPass( *PassRegistry::getPassRegistry()); @@ -1151,7 +1194,8 @@ struct PostOrderFunctionAttrsLegacyPass : public CallGraphSCCPass { CallGraphSCCPass::getAnalysisUsage(AU); } }; -} + +} // end anonymous namespace char PostOrderFunctionAttrsLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(PostOrderFunctionAttrsLegacyPass, "functionattrs", @@ -1214,8 +1258,11 @@ bool PostOrderFunctionAttrsLegacyPass::runOnSCC(CallGraphSCC &SCC) { } namespace { + struct ReversePostOrderFunctionAttrsLegacyPass : public ModulePass { - static char ID; // Pass identification, replacement for typeid + // Pass identification, replacement for typeid + static char ID; + ReversePostOrderFunctionAttrsLegacyPass() : ModulePass(ID) { initializeReversePostOrderFunctionAttrsLegacyPassPass( *PassRegistry::getPassRegistry()); @@ -1229,9 +1276,11 @@ struct ReversePostOrderFunctionAttrsLegacyPass : public ModulePass { AU.addPreserved<CallGraphWrapperPass>(); } }; -} + +} // end anonymous namespace char ReversePostOrderFunctionAttrsLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(ReversePostOrderFunctionAttrsLegacyPass, "rpo-functionattrs", "Deduce function attributes in RPO", false, false) INITIALIZE_PASS_DEPENDENCY(CallGraphWrapperPass) @@ -1291,7 +1340,7 @@ static bool deduceFunctionAttributeInRPO(Module &M, CallGraph &CG) { } bool Changed = false; - for (auto *F : reverse(Worklist)) + for (auto *F : llvm::reverse(Worklist)) Changed |= addNoRecurseAttrsTopDown(*F); return Changed; diff --git a/lib/Transforms/IPO/FunctionImport.cpp b/lib/Transforms/IPO/FunctionImport.cpp index 233a36d2bc54..b6d6201cd23b 100644 --- a/lib/Transforms/IPO/FunctionImport.cpp +++ b/lib/Transforms/IPO/FunctionImport.cpp @@ -12,30 +12,54 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/FunctionImport.h" - +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringSet.h" -#include "llvm/ADT/Triple.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Bitcode/BitcodeReader.h" #include "llvm/IR/AutoUpgrade.h" -#include "llvm/IR/DiagnosticPrinter.h" -#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include 
"llvm/IR/GlobalObject.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Verifier.h" +#include "llvm/IR/ModuleSummaryIndex.h" #include "llvm/IRReader/IRReader.h" -#include "llvm/Linker/Linker.h" -#include "llvm/Object/IRObjectFile.h" +#include "llvm/Linker/IRMover.h" +#include "llvm/Object/ModuleSymbolTable.h" +#include "llvm/Object/SymbolicFile.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/FileSystem.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO/Internalize.h" +#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/FunctionImportUtils.h" - -#define DEBUG_TYPE "function-import" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include <cassert> +#include <memory> +#include <set> +#include <string> +#include <system_error> +#include <tuple> +#include <utility> using namespace llvm; +#define DEBUG_TYPE "function-import" + STATISTIC(NumImportedFunctions, "Number of functions imported"); STATISTIC(NumImportedModules, "Number of modules imported from"); STATISTIC(NumDeadSymbols, "Number of dead stripped symbols in index"); @@ -61,7 +85,7 @@ static cl::opt<float> ImportHotInstrFactor( "before processing newly imported functions")); static cl::opt<float> ImportHotMultiplier( - "import-hot-multiplier", cl::init(3.0), cl::Hidden, cl::value_desc("x"), + "import-hot-multiplier", cl::init(10.0), cl::Hidden, cl::value_desc("x"), cl::desc("Multiply the `import-instr-limit` threshold for hot callsites")); static cl::opt<float> ImportCriticalMultiplier( @@ -91,6 +115,18 @@ static cl::opt<bool> EnableImportMetadata( ), cl::Hidden, cl::desc("Enable import metadata like 'thinlto_src_module'")); +/// Summary file to use for function importing when using -function-import from +/// the command line. +static cl::opt<std::string> + SummaryFile("summary-file", + cl::desc("The summary file to use for function importing.")); + +/// Used when testing importing from distributed indexes via opt +// -function-import. +static cl::opt<bool> + ImportAllIndex("import-all-index", + cl::desc("Import all external functions in index.")); + // Load lazily a module from \p FileName in \p Context. static std::unique_ptr<Module> loadFile(const std::string &FileName, LLVMContext &Context) { @@ -109,8 +145,6 @@ static std::unique_ptr<Module> loadFile(const std::string &FileName, return Result; } -namespace { - /// Given a list of possible callee implementation for a call site, select one /// that fits the \p Threshold. /// @@ -129,21 +163,26 @@ selectCallee(const ModuleSummaryIndex &Index, CalleeSummaryList, [&](const std::unique_ptr<GlobalValueSummary> &SummaryPtr) { auto *GVSummary = SummaryPtr.get(); + // For SamplePGO, in computeImportForFunction the OriginalId + // may have been used to locate the callee summary list (See + // comment there). + // The mapping from OriginalId to GUID may return a GUID + // that corresponds to a static variable. Filter it out here. + // This can happen when + // 1) There is a call to a library function which is not defined + // in the index. 
+ // 2) There is a static variable with the OriginalGUID identical + // to the GUID of the library function in 1); + // When this happens, the logic for SamplePGO kicks in and + // the static variable in 2) will be found, which needs to be + // filtered out. + if (GVSummary->getSummaryKind() == GlobalValueSummary::GlobalVarKind) + return false; if (GlobalValue::isInterposableLinkage(GVSummary->linkage())) // There is no point in importing these, we can't inline them return false; - if (auto *AS = dyn_cast<AliasSummary>(GVSummary)) { - GVSummary = &AS->getAliasee(); - // Alias can't point to "available_externally". However when we import - // linkOnceODR the linkage does not change. So we import the alias - // and aliasee only in this case. - // FIXME: we should import alias as available_externally *function*, - // the destination module does need to know it is an alias. - if (!GlobalValue::isLinkOnceODRLinkage(GVSummary->linkage())) - return false; - } - auto *Summary = cast<FunctionSummary>(GVSummary); + auto *Summary = cast<FunctionSummary>(GVSummary->getBaseObject()); // If this is a local function, make sure we import the copy // in the caller's module. The only time a local function can @@ -174,9 +213,28 @@ selectCallee(const ModuleSummaryIndex &Index, return cast<GlobalValueSummary>(It->get()); } +namespace { + using EdgeInfo = std::tuple<const FunctionSummary *, unsigned /* Threshold */, GlobalValue::GUID>; +} // anonymous namespace + +static ValueInfo +updateValueInfoForIndirectCalls(const ModuleSummaryIndex &Index, ValueInfo VI) { + if (!VI.getSummaryList().empty()) + return VI; + // For SamplePGO, the indirect call targets for local functions will + // have its original name annotated in profile. We try to find the + // corresponding PGOFuncName as the GUID. + // FIXME: Consider updating the edges in the graph after building + // it, rather than needing to perform this mapping on each walk. + auto GUID = Index.getGUIDFromOriginalID(VI.getGUID()); + if (GUID == 0) + return nullptr; + return Index.getValueInfo(GUID); +} + /// Compute the list of functions to import for a given caller. Mark these /// imported functions and the symbols they reference in their source module as /// exported from their source module. @@ -191,17 +249,9 @@ static void computeImportForFunction( DEBUG(dbgs() << " edge -> " << VI.getGUID() << " Threshold:" << Threshold << "\n"); - if (VI.getSummaryList().empty()) { - // For SamplePGO, the indirect call targets for local functions will - // have its original name annotated in profile. We try to find the - // corresponding PGOFuncName as the GUID. - auto GUID = Index.getGUIDFromOriginalID(VI.getGUID()); - if (GUID == 0) - continue; - VI = Index.getValueInfo(GUID); - if (!VI) - continue; - } + VI = updateValueInfoForIndirectCalls(Index, VI); + if (!VI) + continue; if (DefinedGVSummaries.count(VI.getGUID())) { DEBUG(dbgs() << "ignored! Target already in destination module.\n"); @@ -227,16 +277,9 @@ static void computeImportForFunction( DEBUG(dbgs() << "ignored! 
No qualifying callee with summary found.\n"); continue; } - // "Resolve" the summary, traversing alias, - const FunctionSummary *ResolvedCalleeSummary; - if (isa<AliasSummary>(CalleeSummary)) { - ResolvedCalleeSummary = cast<FunctionSummary>( - &cast<AliasSummary>(CalleeSummary)->getAliasee()); - assert( - GlobalValue::isLinkOnceODRLinkage(ResolvedCalleeSummary->linkage()) && - "Unexpected alias to a non-linkonceODR in import list"); - } else - ResolvedCalleeSummary = cast<FunctionSummary>(CalleeSummary); + + // "Resolve" the summary + const auto *ResolvedCalleeSummary = cast<FunctionSummary>(CalleeSummary->getBaseObject()); assert(ResolvedCalleeSummary->instCount() <= NewThreshold && "selectCallee() didn't honor the threshold"); @@ -312,14 +355,12 @@ static void ComputeImportForModule( DEBUG(dbgs() << "Ignores Dead GUID: " << GVSummary.first << "\n"); continue; } - auto *Summary = GVSummary.second; - if (auto *AS = dyn_cast<AliasSummary>(Summary)) - Summary = &AS->getAliasee(); - auto *FuncSummary = dyn_cast<FunctionSummary>(Summary); + auto *FuncSummary = + dyn_cast<FunctionSummary>(GVSummary.second->getBaseObject()); if (!FuncSummary) // Skip import for global variables continue; - DEBUG(dbgs() << "Initalize import for " << GVSummary.first << "\n"); + DEBUG(dbgs() << "Initialize import for " << GVSummary.first << "\n"); computeImportForFunction(*FuncSummary, Index, ImportInstrLimit, DefinedGVSummaries, Worklist, ImportList, ExportLists); @@ -344,8 +385,6 @@ static void ComputeImportForModule( } } -} // anonymous namespace - /// Compute all the import and export for every module using the Index. void llvm::ComputeCrossModuleImport( const ModuleSummaryIndex &Index, @@ -395,11 +434,23 @@ void llvm::ComputeCrossModuleImport( #endif } +#ifndef NDEBUG +static void dumpImportListForModule(StringRef ModulePath, + FunctionImporter::ImportMapTy &ImportList) { + DEBUG(dbgs() << "* Module " << ModulePath << " imports from " + << ImportList.size() << " modules.\n"); + for (auto &Src : ImportList) { + auto SrcModName = Src.first(); + DEBUG(dbgs() << " - " << Src.second.size() << " functions imported from " + << SrcModName << "\n"); + } +} +#endif + /// Compute all the imports for the given module in the Index. void llvm::ComputeCrossModuleImportForModule( StringRef ModulePath, const ModuleSummaryIndex &Index, FunctionImporter::ImportMapTy &ImportList) { - // Collect the list of functions this module defines. // GUID -> Summary GVSummaryMapTy FunctionSummaryMap; @@ -410,13 +461,34 @@ void llvm::ComputeCrossModuleImportForModule( ComputeImportForModule(FunctionSummaryMap, Index, ImportList); #ifndef NDEBUG - DEBUG(dbgs() << "* Module " << ModulePath << " imports from " - << ImportList.size() << " modules.\n"); - for (auto &Src : ImportList) { - auto SrcModName = Src.first(); - DEBUG(dbgs() << " - " << Src.second.size() << " functions imported from " - << SrcModName << "\n"); + dumpImportListForModule(ModulePath, ImportList); +#endif +} + +// Mark all external summaries in Index for import into the given module. +// Used for distributed builds using a distributed index. +void llvm::ComputeCrossModuleImportForModuleFromIndex( + StringRef ModulePath, const ModuleSummaryIndex &Index, + FunctionImporter::ImportMapTy &ImportList) { + for (auto &GlobalList : Index) { + // Ignore entries for undefined references. 
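The -import-hot-multiplier default raised to 10.0 above scales the threshold that computeImportForFunction passes to selectCallee for hot call sites. A rough illustration of the effect, assuming the default -import-instr-limit of 100 and ignoring the decay applied to transitive imports (the helper and its numbers are hypothetical, not LLVM code):

    // Effective instruction-count limit for a callee reached from a call site
    // of the given hotness.
    static unsigned effectiveImportLimit(bool IsHotCallSite,
                                         unsigned BaseLimit = 100) {
      const double HotMultiplier = 10.0; // was 3.0 before this change
      return IsHotCallSite ? static_cast<unsigned>(BaseLimit * HotMultiplier)
                           : BaseLimit;
    }
    // effectiveImportLimit(false) == 100, effectiveImportLimit(true) == 1000
    // (previously 300), so considerably larger hot callees now qualify.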
+ if (GlobalList.second.SummaryList.empty()) + continue; + + auto GUID = GlobalList.first; + assert(GlobalList.second.SummaryList.size() == 1 && + "Expected individual combined index to have one summary per GUID"); + auto &Summary = GlobalList.second.SummaryList[0]; + // Skip the summaries for the importing module. These are included to + // e.g. record required linkage changes. + if (Summary->modulePath() == ModulePath) + continue; + // Doesn't matter what value we plug in to the map, just needs an entry + // to provoke importing by thinBackend. + ImportList[Summary->modulePath()][GUID] = 1; } +#ifndef NDEBUG + dumpImportListForModule(ModulePath, ImportList); #endif } @@ -453,6 +525,17 @@ void llvm::computeDeadSymbols( // Make value live and add it to the worklist if it was not live before. // FIXME: we should only make the prevailing copy live here auto visit = [&](ValueInfo VI) { + // FIXME: If we knew which edges were created for indirect call profiles, + // we could skip them here. Any that are live should be reached via + // other edges, e.g. reference edges. Otherwise, using a profile collected + // on a slightly different binary might provoke preserving, importing + // and ultimately promoting calls to functions not linked into this + // binary, which increases the binary size unnecessarily. Note that + // if this code changes, the importer needs to change so that edges + // to functions marked dead are skipped. + VI = updateValueInfoForIndirectCalls(Index, VI); + if (!VI) + return; for (auto &S : VI.getSummaryList()) if (S->isLive()) return; @@ -465,17 +548,12 @@ void llvm::computeDeadSymbols( while (!Worklist.empty()) { auto VI = Worklist.pop_back_val(); for (auto &Summary : VI.getSummaryList()) { - for (auto Ref : Summary->refs()) + GlobalValueSummary *Base = Summary->getBaseObject(); + for (auto Ref : Base->refs()) visit(Ref); - if (auto *FS = dyn_cast<FunctionSummary>(Summary.get())) + if (auto *FS = dyn_cast<FunctionSummary>(Base)) for (auto Call : FS->calls()) visit(Call.first); - if (auto *AS = dyn_cast<AliasSummary>(Summary.get())) { - auto AliaseeGUID = AS->getAliasee().getOriginalName(); - ValueInfo AliaseeVI = Index.getValueInfo(AliaseeGUID); - if (AliaseeVI) - visit(AliaseeVI); - } } } Index.setWithGlobalValueDeadStripping(); @@ -600,23 +678,9 @@ void llvm::thinLTOResolveWeakForLinkerModule( /// Run internalization on \p TheModule based on symmary analysis. void llvm::thinLTOInternalizeModule(Module &TheModule, const GVSummaryMapTy &DefinedGlobals) { - // Parse inline ASM and collect the list of symbols that are not defined in - // the current module. - StringSet<> AsmUndefinedRefs; - ModuleSymbolTable::CollectAsmSymbols( - TheModule, - [&AsmUndefinedRefs](StringRef Name, object::BasicSymbolRef::Flags Flags) { - if (Flags & object::BasicSymbolRef::SF_Undefined) - AsmUndefinedRefs.insert(Name); - }); - // Declare a callback for the internalize pass that will ask for every // candidate GlobalValue if it can be internalized or not. auto MustPreserveGV = [&](const GlobalValue &GV) -> bool { - // Can't be internalized if referenced in inline asm. - if (AsmUndefinedRefs.count(GV.getName())) - return true; - // Lookup the linkage recorded in the summaries during global analysis. auto GS = DefinedGlobals.find(GV.getGUID()); if (GS == DefinedGlobals.end()) { @@ -647,12 +711,25 @@ void llvm::thinLTOInternalizeModule(Module &TheModule, // FIXME: See if we can just internalize directly here via linkage changes // based on the index, rather than invoking internalizeModule. 
- llvm::internalizeModule(TheModule, MustPreserveGV); + internalizeModule(TheModule, MustPreserveGV); +} + +/// Make alias a clone of its aliasee. +static Function *replaceAliasWithAliasee(Module *SrcModule, GlobalAlias *GA) { + Function *Fn = cast<Function>(GA->getBaseObject()); + + ValueToValueMapTy VMap; + Function *NewFn = CloneFunction(Fn, VMap); + // Clone should use the original alias's linkage and name, and we ensure + // all uses of alias instead use the new clone (casted if necessary). + NewFn->setLinkage(GA->getLinkage()); + GA->replaceAllUsesWith(ConstantExpr::getBitCast(NewFn, GA->getType())); + NewFn->takeName(GA); + return NewFn; } // Automatically import functions in Module \p DestModule based on the summaries // index. -// Expected<bool> FunctionImporter::importFunctions( Module &DestModule, const FunctionImporter::ImportMapTy &ImportList) { DEBUG(dbgs() << "Starting import for Module " @@ -699,10 +776,9 @@ Expected<bool> FunctionImporter::importFunctions( // Add 'thinlto_src_module' metadata for statistics and debugging. F.setMetadata( "thinlto_src_module", - llvm::MDNode::get( - DestModule.getContext(), - {llvm::MDString::get(DestModule.getContext(), - SrcModule->getSourceFileName())})); + MDNode::get(DestModule.getContext(), + {MDString::get(DestModule.getContext(), + SrcModule->getSourceFileName())})); } GlobalsToImport.insert(&F); } @@ -722,13 +798,6 @@ Expected<bool> FunctionImporter::importFunctions( } } for (GlobalAlias &GA : SrcModule->aliases()) { - // FIXME: This should eventually be controlled entirely by the summary. - if (FunctionImportGlobalProcessing::doImportAsDefinition( - &GA, &GlobalsToImport)) { - GlobalsToImport.insert(&GA); - continue; - } - if (!GA.hasName()) continue; auto GUID = GA.getGUID(); @@ -737,25 +806,25 @@ Expected<bool> FunctionImporter::importFunctions( << " " << GA.getName() << " from " << SrcModule->getSourceFileName() << "\n"); if (Import) { - // Alias can't point to "available_externally". However when we import - // linkOnceODR the linkage does not change. So we import the alias - // and aliasee only in this case. This has been handled by - // computeImportForFunction() - GlobalObject *GO = GA.getBaseObject(); - assert(GO->hasLinkOnceODRLinkage() && - "Unexpected alias to a non-linkonceODR in import list"); -#ifndef NDEBUG - if (!GlobalsToImport.count(GO)) - DEBUG(dbgs() << " alias triggers importing aliasee " << GO->getGUID() - << " " << GO->getName() << " from " - << SrcModule->getSourceFileName() << "\n"); -#endif - if (Error Err = GO->materialize()) - return std::move(Err); - GlobalsToImport.insert(GO); if (Error Err = GA.materialize()) return std::move(Err); - GlobalsToImport.insert(&GA); + // Import alias as a copy of its aliasee. + GlobalObject *Base = GA.getBaseObject(); + if (Error Err = Base->materialize()) + return std::move(Err); + auto *Fn = replaceAliasWithAliasee(SrcModule.get(), &GA); + DEBUG(dbgs() << "Is importing aliasee fn " << Base->getGUID() + << " " << Base->getName() << " from " + << SrcModule->getSourceFileName() << "\n"); + if (EnableImportMetadata) { + // Add 'thinlto_src_module' metadata for statistics and debugging. 
+ Fn->setMetadata( + "thinlto_src_module", + MDNode::get(DestModule.getContext(), + {MDString::get(DestModule.getContext(), + SrcModule->getSourceFileName())})); + } + GlobalsToImport.insert(Fn); } } @@ -789,12 +858,6 @@ Expected<bool> FunctionImporter::importFunctions( return ImportedCount; } -/// Summary file to use for function importing when using -function-import from -/// the command line. -static cl::opt<std::string> - SummaryFile("summary-file", - cl::desc("The summary file to use for function importing.")); - static bool doImportingForModule(Module &M) { if (SummaryFile.empty()) report_fatal_error("error: -function-import requires -summary-file\n"); @@ -809,8 +872,15 @@ static bool doImportingForModule(Module &M) { // First step is collecting the import list. FunctionImporter::ImportMapTy ImportList; - ComputeCrossModuleImportForModule(M.getModuleIdentifier(), *Index, - ImportList); + // If requested, simply import all functions in the index. This is used + // when testing distributed backend handling via the opt tool, when + // we have distributed indexes containing exactly the summaries to import. + if (ImportAllIndex) + ComputeCrossModuleImportForModuleFromIndex(M.getModuleIdentifier(), *Index, + ImportList); + else + ComputeCrossModuleImportForModule(M.getModuleIdentifier(), *Index, + ImportList); // Conservatively mark all internal values as promoted. This interface is // only used when doing importing via the function importing pass. The pass @@ -848,17 +918,18 @@ static bool doImportingForModule(Module &M) { } namespace { + /// Pass that performs cross-module function import provided a summary file. class FunctionImportLegacyPass : public ModulePass { public: /// Pass identification, replacement for typeid static char ID; + explicit FunctionImportLegacyPass() : ModulePass(ID) {} + /// Specify pass name for debug output StringRef getPassName() const override { return "Function Importing"; } - explicit FunctionImportLegacyPass() : ModulePass(ID) {} - bool runOnModule(Module &M) override { if (skipModule(M)) return false; @@ -866,7 +937,8 @@ public: return doImportingForModule(M); } }; -} // anonymous namespace + +} // end anonymous namespace PreservedAnalyses FunctionImportPass::run(Module &M, ModuleAnalysisManager &AM) { @@ -881,7 +953,9 @@ INITIALIZE_PASS(FunctionImportLegacyPass, "function-import", "Summary Based Function Import", false, false) namespace llvm { + Pass *createFunctionImportPass() { return new FunctionImportLegacyPass(); } -} + +} // end namespace llvm diff --git a/lib/Transforms/IPO/GlobalDCE.cpp b/lib/Transforms/IPO/GlobalDCE.cpp index c91e8b454927..ada9eb80e680 100644 --- a/lib/Transforms/IPO/GlobalDCE.cpp +++ b/lib/Transforms/IPO/GlobalDCE.cpp @@ -18,7 +18,6 @@ #include "llvm/Transforms/IPO/GlobalDCE.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/Statistic.h" -#include "llvm/IR/Constants.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/Pass.h" @@ -115,7 +114,7 @@ void GlobalDCEPass::UpdateGVDependencies(GlobalValue &GV) { ComputeDependencies(User, Deps); Deps.erase(&GV); // Remove self-reference. 
for (GlobalValue *GVU : Deps) { - GVDependencies.insert(std::make_pair(GVU, &GV)); + GVDependencies[GVU].insert(&GV); } } @@ -199,8 +198,8 @@ PreservedAnalyses GlobalDCEPass::run(Module &M, ModuleAnalysisManager &MAM) { AliveGlobals.end()}; while (!NewLiveGVs.empty()) { GlobalValue *LGV = NewLiveGVs.pop_back_val(); - for (auto &&GVD : make_range(GVDependencies.equal_range(LGV))) - MarkLive(*GVD.second, &NewLiveGVs); + for (auto *GVD : GVDependencies[LGV]) + MarkLive(*GVD, &NewLiveGVs); } // Now that all globals which are needed are in the AliveGlobals set, we loop diff --git a/lib/Transforms/IPO/GlobalOpt.cpp b/lib/Transforms/IPO/GlobalOpt.cpp index 93eab680ca6b..4bb2984e3b47 100644 --- a/lib/Transforms/IPO/GlobalOpt.cpp +++ b/lib/Transforms/IPO/GlobalOpt.cpp @@ -20,22 +20,41 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/CallingConv.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MathExtras.h" @@ -45,7 +64,11 @@ #include "llvm/Transforms/Utils/Evaluator.h" #include "llvm/Transforms/Utils/GlobalStatus.h" #include "llvm/Transforms/Utils/Local.h" -#include <algorithm> +#include <cassert> +#include <cstdint> +#include <utility> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "globalopt" @@ -139,7 +162,7 @@ static bool IsSafeComputationToRemove(Value *V, const TargetLibraryInfo *TLI) { } V = I->getOperand(0); - } while (1); + } while (true); } /// This GV is a pointer root. Loop over all users of the global and clean up @@ -220,7 +243,7 @@ static bool CleanupPointerRootUsers(GlobalVariable *GV, break; I->eraseFromParent(); I = J; - } while (1); + } while (true); I->eraseFromParent(); } } @@ -348,7 +371,6 @@ static bool isSafeSROAElementUse(Value *V) { return true; } - /// U is a direct user of the specified global value. Look at it and its uses /// and decide whether it is safe to SROA this global. 
static bool IsUserOfGlobalSafeForSRA(User *U, GlobalValue *GV) { @@ -402,11 +424,8 @@ static bool IsUserOfGlobalSafeForSRA(User *U, GlobalValue *GV) { } } - for (User *UU : U->users()) - if (!isSafeSROAElementUse(UU)) - return false; - - return true; + return llvm::all_of(U->users(), + [](User *UU) { return isSafeSROAElementUse(UU); }); } /// Look at all uses of the global and decide whether it is safe for us to @@ -419,6 +438,27 @@ static bool GlobalUsersSafeToSRA(GlobalValue *GV) { return true; } +/// Copy over the debug info for a variable to its SRA replacements. +static void transferSRADebugInfo(GlobalVariable *GV, GlobalVariable *NGV, + uint64_t FragmentOffsetInBits, + uint64_t FragmentSizeInBits, + unsigned NumElements) { + SmallVector<DIGlobalVariableExpression *, 1> GVs; + GV->getDebugInfo(GVs); + for (auto *GVE : GVs) { + DIVariable *Var = GVE->getVariable(); + DIExpression *Expr = GVE->getExpression(); + if (NumElements > 1) { + if (auto E = DIExpression::createFragmentExpression( + Expr, FragmentOffsetInBits, FragmentSizeInBits)) + Expr = *E; + else + return; + } + auto *NGVE = DIGlobalVariableExpression::get(GVE->getContext(), Var, Expr); + NGV->addDebugInfo(NGVE); + } +} /// Perform scalar replacement of aggregates on the specified global variable. /// This opens the door for other optimizations by exposing the behavior of the @@ -434,7 +474,7 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { Constant *Init = GV->getInitializer(); Type *Ty = Init->getType(); - std::vector<GlobalVariable*> NewGlobals; + std::vector<GlobalVariable *> NewGlobals; Module::GlobalListType &Globals = GV->getParent()->getGlobalList(); // Get the alignment of the global, either explicit or target-specific. @@ -443,9 +483,11 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { StartAlignment = DL.getABITypeAlignment(GV->getType()); if (StructType *STy = dyn_cast<StructType>(Ty)) { - NewGlobals.reserve(STy->getNumElements()); + uint64_t FragmentOffset = 0; + unsigned NumElements = STy->getNumElements(); + NewGlobals.reserve(NumElements); const StructLayout &Layout = *DL.getStructLayout(STy); - for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + for (unsigned i = 0, e = NumElements; i != e; ++i) { Constant *In = Init->getAggregateElement(i); assert(In && "Couldn't get element of initializer?"); GlobalVariable *NGV = new GlobalVariable(STy->getElementType(i), false, @@ -465,15 +507,22 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { unsigned NewAlign = (unsigned)MinAlign(StartAlignment, FieldOffset); if (NewAlign > DL.getABITypeAlignment(STy->getElementType(i))) NGV->setAlignment(NewAlign); + + // Copy over the debug info for the variable. + FragmentOffset = alignTo(FragmentOffset, NewAlign); + uint64_t Size = DL.getTypeSizeInBits(NGV->getValueType()); + transferSRADebugInfo(GV, NGV, FragmentOffset, Size, NumElements); + FragmentOffset += Size; } } else if (SequentialType *STy = dyn_cast<SequentialType>(Ty)) { unsigned NumElements = STy->getNumElements(); if (NumElements > 16 && GV->hasNUsesOrMore(16)) return nullptr; // It's not worth it. 
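transferSRADebugInfo above attaches a fragment expression to each replacement global (when the aggregate has more than one element) so the original variable stays describable in debug info. As a concrete, hypothetical case, consider splitting a global of type struct { int32_t a; int64_t b; } under a typical 64-bit data layout: a occupies bits [0, 32) and b occupies bits [64, 128) because of alignment padding, so the replacement global for b would carry:

    // Hypothetical offsets for the struct above; only the fragment-creation
    // step is shown.
    if (auto E = DIExpression::createFragmentExpression(
            Expr, /*OffsetInBits=*/64, /*SizeInBits=*/64))
      Expr = *E; // appends DW_OP_LLVM_fragment 64 64 to the variable's expression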
NewGlobals.reserve(NumElements); - - uint64_t EltSize = DL.getTypeAllocSize(STy->getElementType()); - unsigned EltAlign = DL.getABITypeAlignment(STy->getElementType()); + auto ElTy = STy->getElementType(); + uint64_t EltSize = DL.getTypeAllocSize(ElTy); + unsigned EltAlign = DL.getABITypeAlignment(ElTy); + uint64_t FragmentSizeInBits = DL.getTypeSizeInBits(ElTy); for (unsigned i = 0, e = NumElements; i != e; ++i) { Constant *In = Init->getAggregateElement(i); assert(In && "Couldn't get element of initializer?"); @@ -494,6 +543,8 @@ static GlobalVariable *SRAGlobal(GlobalVariable *GV, const DataLayout &DL) { unsigned NewAlign = (unsigned)MinAlign(StartAlignment, EltSize*i); if (NewAlign > EltAlign) NGV->setAlignment(NewAlign); + transferSRADebugInfo(GV, NGV, FragmentSizeInBits * i, FragmentSizeInBits, + NumElements); } } @@ -689,7 +740,6 @@ static bool OptimizeAwayTrappingUsesOfValue(Value *V, Constant *NewV) { return Changed; } - /// The specified global has only one non-null value stored into it. If there /// are uses of the loaded value that would trap if the loaded value is /// dynamically null, then we know that they cannot be reachable with a null @@ -1045,7 +1095,6 @@ static bool LoadUsesSimpleEnoughForHeapSRA(const Value *V, return true; } - /// If all users of values loaded from GV are simple enough to perform HeapSRA, /// return true. static bool AllGlobalLoadUsesSimpleEnoughForHeapSRA(const GlobalVariable *GV, @@ -1095,9 +1144,9 @@ static bool AllGlobalLoadUsesSimpleEnoughForHeapSRA(const GlobalVariable *GV, } static Value *GetHeapSROAValue(Value *V, unsigned FieldNo, - DenseMap<Value*, std::vector<Value*> > &InsertedScalarizedValues, - std::vector<std::pair<PHINode*, unsigned> > &PHIsToRewrite) { - std::vector<Value*> &FieldVals = InsertedScalarizedValues[V]; + DenseMap<Value *, std::vector<Value *>> &InsertedScalarizedValues, + std::vector<std::pair<PHINode *, unsigned>> &PHIsToRewrite) { + std::vector<Value *> &FieldVals = InsertedScalarizedValues[V]; if (FieldNo >= FieldVals.size()) FieldVals.resize(FieldNo+1); @@ -1139,8 +1188,8 @@ static Value *GetHeapSROAValue(Value *V, unsigned FieldNo, /// Given a load instruction and a value derived from the load, rewrite the /// derived value to use the HeapSRoA'd load. static void RewriteHeapSROALoadUser(Instruction *LoadUser, - DenseMap<Value*, std::vector<Value*> > &InsertedScalarizedValues, - std::vector<std::pair<PHINode*, unsigned> > &PHIsToRewrite) { + DenseMap<Value *, std::vector<Value *>> &InsertedScalarizedValues, + std::vector<std::pair<PHINode *, unsigned>> &PHIsToRewrite) { // If this is a comparison against null, handle it. if (ICmpInst *SCI = dyn_cast<ICmpInst>(LoadUser)) { assert(isa<ConstantPointerNull>(SCI->getOperand(1))); @@ -1187,7 +1236,7 @@ static void RewriteHeapSROALoadUser(Instruction *LoadUser, // processed. PHINode *PN = cast<PHINode>(LoadUser); if (!InsertedScalarizedValues.insert(std::make_pair(PN, - std::vector<Value*>())).second) + std::vector<Value *>())).second) return; // If this is the first time we've seen this PHI, recursively process all @@ -1202,8 +1251,8 @@ static void RewriteHeapSROALoadUser(Instruction *LoadUser, /// global. Eliminate all uses of Ptr, making them use FieldGlobals instead. /// All uses of loaded values satisfy AllGlobalLoadUsesSimpleEnoughForHeapSRA. 
static void RewriteUsesOfLoadForHeapSRoA(LoadInst *Load, - DenseMap<Value*, std::vector<Value*> > &InsertedScalarizedValues, - std::vector<std::pair<PHINode*, unsigned> > &PHIsToRewrite) { + DenseMap<Value *, std::vector<Value *>> &InsertedScalarizedValues, + std::vector<std::pair<PHINode *, unsigned> > &PHIsToRewrite) { for (auto UI = Load->user_begin(), E = Load->user_end(); UI != E;) { Instruction *User = cast<Instruction>(*UI++); RewriteHeapSROALoadUser(User, InsertedScalarizedValues, PHIsToRewrite); @@ -1232,8 +1281,8 @@ static GlobalVariable *PerformHeapAllocSRoA(GlobalVariable *GV, CallInst *CI, // Okay, at this point, there are no users of the malloc. Insert N // new mallocs at the same place as CI, and N globals. - std::vector<Value*> FieldGlobals; - std::vector<Value*> FieldMallocs; + std::vector<Value *> FieldGlobals; + std::vector<Value *> FieldMallocs; SmallVector<OperandBundleDef, 1> OpBundles; CI->getOperandBundlesAsDefs(OpBundles); @@ -1330,10 +1379,10 @@ static GlobalVariable *PerformHeapAllocSRoA(GlobalVariable *GV, CallInst *CI, /// As we process loads, if we can't immediately update all uses of the load, /// keep track of what scalarized loads are inserted for a given load. - DenseMap<Value*, std::vector<Value*> > InsertedScalarizedValues; + DenseMap<Value *, std::vector<Value *>> InsertedScalarizedValues; InsertedScalarizedValues[GV] = FieldGlobals; - std::vector<std::pair<PHINode*, unsigned> > PHIsToRewrite; + std::vector<std::pair<PHINode *, unsigned>> PHIsToRewrite; // Okay, the malloc site is completely handled. All of the uses of GV are now // loads, and all uses of those loads are simple. Rewrite them to use loads @@ -1379,7 +1428,7 @@ static GlobalVariable *PerformHeapAllocSRoA(GlobalVariable *GV, CallInst *CI, } // Drop all inter-phi links and any loads that made it this far. - for (DenseMap<Value*, std::vector<Value*> >::iterator + for (DenseMap<Value *, std::vector<Value *>>::iterator I = InsertedScalarizedValues.begin(), E = InsertedScalarizedValues.end(); I != E; ++I) { if (PHINode *PN = dyn_cast<PHINode>(I->first)) @@ -1389,7 +1438,7 @@ static GlobalVariable *PerformHeapAllocSRoA(GlobalVariable *GV, CallInst *CI, } // Delete all the phis and loads now that inter-references are dead. - for (DenseMap<Value*, std::vector<Value*> >::iterator + for (DenseMap<Value *, std::vector<Value *>>::iterator I = InsertedScalarizedValues.begin(), E = InsertedScalarizedValues.end(); I != E; ++I) { if (PHINode *PN = dyn_cast<PHINode>(I->first)) @@ -1576,12 +1625,57 @@ static bool TryToShrinkGlobalToBoolean(GlobalVariable *GV, Constant *OtherVal) { assert(InitVal->getType() != Type::getInt1Ty(GV->getContext()) && "No reason to shrink to bool!"); + SmallVector<DIGlobalVariableExpression *, 1> GVs; + GV->getDebugInfo(GVs); + // If initialized to zero and storing one into the global, we can use a cast // instead of a select to synthesize the desired value. bool IsOneZero = false; - if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)) + bool EmitOneOrZero = true; + if (ConstantInt *CI = dyn_cast<ConstantInt>(OtherVal)){ IsOneZero = InitVal->isNullValue() && CI->isOne(); + if (ConstantInt *CIInit = dyn_cast<ConstantInt>(GV->getInitializer())){ + uint64_t ValInit = CIInit->getZExtValue(); + uint64_t ValOther = CI->getZExtValue(); + uint64_t ValMinus = ValOther - ValInit; + + for(auto *GVe : GVs){ + DIGlobalVariable *DGV = GVe->getVariable(); + DIExpression *E = GVe->getExpression(); + + // It is expected that the address of global optimized variable is on + // top of the stack. 
After optimization, value of that variable will + // be ether 0 for initial value or 1 for other value. The following + // expression should return constant integer value depending on the + // value at global object address: + // val * (ValOther - ValInit) + ValInit: + // DW_OP_deref DW_OP_constu <ValMinus> + // DW_OP_mul DW_OP_constu <ValInit> DW_OP_plus DW_OP_stack_value + E = DIExpression::get(NewGV->getContext(), + {dwarf::DW_OP_deref, + dwarf::DW_OP_constu, + ValMinus, + dwarf::DW_OP_mul, + dwarf::DW_OP_constu, + ValInit, + dwarf::DW_OP_plus, + dwarf::DW_OP_stack_value}); + DIGlobalVariableExpression *DGVE = + DIGlobalVariableExpression::get(NewGV->getContext(), DGV, E); + NewGV->addDebugInfo(DGVE); + } + EmitOneOrZero = false; + } + } + + if (EmitOneOrZero) { + // FIXME: This will only emit address for debugger on which will + // be written only 0 or 1. + for(auto *GV : GVs) + NewGV->addDebugInfo(GV); + } + while (!GV->use_empty()) { Instruction *UI = cast<Instruction>(GV->user_back()); if (StoreInst *SI = dyn_cast<StoreInst>(UI)) { @@ -2202,7 +2296,7 @@ static void setUsedInitializer(GlobalVariable &V, // Type of pointer to the array of pointers. PointerType *Int8PtrTy = Type::getInt8PtrTy(V.getContext(), 0); - SmallVector<llvm::Constant *, 8> UsedArray; + SmallVector<Constant *, 8> UsedArray; for (GlobalValue *GV : Init) { Constant *Cast = ConstantExpr::getPointerBitCastOrAddrSpaceCast(GV, Int8PtrTy); @@ -2215,14 +2309,15 @@ static void setUsedInitializer(GlobalVariable &V, Module *M = V.getParent(); V.removeFromParent(); GlobalVariable *NV = - new GlobalVariable(*M, ATy, false, llvm::GlobalValue::AppendingLinkage, - llvm::ConstantArray::get(ATy, UsedArray), ""); + new GlobalVariable(*M, ATy, false, GlobalValue::AppendingLinkage, + ConstantArray::get(ATy, UsedArray), ""); NV->takeName(&V); NV->setSection("llvm.metadata"); delete &V; } namespace { + /// An easy to access representation of llvm.used and llvm.compiler.used. class LLVMUsed { SmallPtrSet<GlobalValue *, 8> Used; @@ -2235,25 +2330,34 @@ public: UsedV = collectUsedGlobalVariables(M, Used, false); CompilerUsedV = collectUsedGlobalVariables(M, CompilerUsed, true); } - typedef SmallPtrSet<GlobalValue *, 8>::iterator iterator; - typedef iterator_range<iterator> used_iterator_range; + + using iterator = SmallPtrSet<GlobalValue *, 8>::iterator; + using used_iterator_range = iterator_range<iterator>; + iterator usedBegin() { return Used.begin(); } iterator usedEnd() { return Used.end(); } + used_iterator_range used() { return used_iterator_range(usedBegin(), usedEnd()); } + iterator compilerUsedBegin() { return CompilerUsed.begin(); } iterator compilerUsedEnd() { return CompilerUsed.end(); } + used_iterator_range compilerUsed() { return used_iterator_range(compilerUsedBegin(), compilerUsedEnd()); } + bool usedCount(GlobalValue *GV) const { return Used.count(GV); } + bool compilerUsedCount(GlobalValue *GV) const { return CompilerUsed.count(GV); } + bool usedErase(GlobalValue *GV) { return Used.erase(GV); } bool compilerUsedErase(GlobalValue *GV) { return CompilerUsed.erase(GV); } bool usedInsert(GlobalValue *GV) { return Used.insert(GV).second; } + bool compilerUsedInsert(GlobalValue *GV) { return CompilerUsed.insert(GV).second; } @@ -2265,7 +2369,8 @@ public: setUsedInitializer(*CompilerUsedV, CompilerUsed); } }; -} + +} // end anonymous namespace static bool hasUseOtherThanLLVMUsed(GlobalAlias &GA, const LLVMUsed &U) { if (GA.use_empty()) // No use at all. 
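The expression built above keeps the original values visible in the debugger after the global is shrunk to a bool. With hypothetical values, say the global is initialized to 7 and the only other value ever stored is 42 (ValInit = 7, ValOther = 42, ValMinus = 35), the emitted expression is:

    // DW_OP_deref pushes the shrunken i1 value (0 or 1) stored at the
    // variable's address; the rest computes value * 35 + 7 as a stack value,
    // i.e. 7 when the bool is 0 and 42 when it is 1.
    DIExpression *E = DIExpression::get(
        NewGV->getContext(),
        {dwarf::DW_OP_deref, dwarf::DW_OP_constu, 35, dwarf::DW_OP_mul,
         dwarf::DW_OP_constu, 7, dwarf::DW_OP_plus, dwarf::DW_OP_stack_value});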
@@ -2580,8 +2685,10 @@ PreservedAnalyses GlobalOptPass::run(Module &M, ModuleAnalysisManager &AM) { } namespace { + struct GlobalOptLegacyPass : public ModulePass { static char ID; // Pass identification, replacement for typeid + GlobalOptLegacyPass() : ModulePass(ID) { initializeGlobalOptLegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -2603,9 +2710,11 @@ struct GlobalOptLegacyPass : public ModulePass { AU.addRequired<DominatorTreeWrapperPass>(); } }; -} + +} // end anonymous namespace char GlobalOptLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(GlobalOptLegacyPass, "globalopt", "Global Variable Optimizer", false, false) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) diff --git a/lib/Transforms/IPO/GlobalSplit.cpp b/lib/Transforms/IPO/GlobalSplit.cpp index e47d881d1127..792f4b3052a3 100644 --- a/lib/Transforms/IPO/GlobalSplit.cpp +++ b/lib/Transforms/IPO/GlobalSplit.cpp @@ -15,22 +15,30 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/GlobalSplit.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Transforms/IPO.h" - -#include <set> +#include <cstdint> +#include <vector> using namespace llvm; -namespace { - -bool splitGlobal(GlobalVariable &GV) { +static bool splitGlobal(GlobalVariable &GV) { // If the address of the global is taken outside of the module, we cannot // apply this transformation. if (!GV.hasLocalLinkage()) @@ -130,7 +138,7 @@ bool splitGlobal(GlobalVariable &GV) { return true; } -bool splitGlobals(Module &M) { +static bool splitGlobals(Module &M) { // First, see if the module uses either of the llvm.type.test or // llvm.type.checked.load intrinsics, which indicates that splitting globals // may be beneficial. 
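splitGlobals, now a static helper, bails out early unless the module actually uses llvm.type.test or llvm.type.checked.load. The body of that check falls outside the diff context shown; a sketch of what such a check amounts to, assuming it runs inside splitGlobals(Module &M):

    Function *TypeTestFunc =
        M.getFunction(Intrinsic::getName(Intrinsic::type_test));
    Function *TypeCheckedLoadFunc =
        M.getFunction(Intrinsic::getName(Intrinsic::type_checked_load));
    // If neither intrinsic has any users, splitting globals cannot help the
    // type-test lowering passes, so skip the whole module.
    if ((!TypeTestFunc || TypeTestFunc->use_empty()) &&
        (!TypeCheckedLoadFunc || TypeCheckedLoadFunc->use_empty()))
      return false;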
@@ -151,12 +159,16 @@ bool splitGlobals(Module &M) { return Changed; } +namespace { + struct GlobalSplit : public ModulePass { static char ID; + GlobalSplit() : ModulePass(ID) { initializeGlobalSplitPass(*PassRegistry::getPassRegistry()); } - bool runOnModule(Module &M) { + + bool runOnModule(Module &M) override { if (skipModule(M)) return false; @@ -164,11 +176,12 @@ struct GlobalSplit : public ModulePass { } }; -} +} // end anonymous namespace -INITIALIZE_PASS(GlobalSplit, "globalsplit", "Global splitter", false, false) char GlobalSplit::ID = 0; +INITIALIZE_PASS(GlobalSplit, "globalsplit", "Global splitter", false, false) + ModulePass *llvm::createGlobalSplitPass() { return new GlobalSplit; } diff --git a/lib/Transforms/IPO/IPO.cpp b/lib/Transforms/IPO/IPO.cpp index 5bb305ca84d0..d5d35ee89e0e 100644 --- a/lib/Transforms/IPO/IPO.cpp +++ b/lib/Transforms/IPO/IPO.cpp @@ -25,6 +25,7 @@ using namespace llvm; void llvm::initializeIPO(PassRegistry &Registry) { initializeArgPromotionPass(Registry); + initializeCalledValuePropagationLegacyPassPass(Registry); initializeConstantMergeLegacyPassPass(Registry); initializeCrossDSOCFIPass(Registry); initializeDAEPass(Registry); @@ -67,6 +68,10 @@ void LLVMAddArgumentPromotionPass(LLVMPassManagerRef PM) { unwrap(PM)->add(createArgumentPromotionPass()); } +void LLVMAddCalledValuePropagationPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createCalledValuePropagationPass()); +} + void LLVMAddConstantMergePass(LLVMPassManagerRef PM) { unwrap(PM)->add(createConstantMergePass()); } diff --git a/lib/Transforms/IPO/InferFunctionAttrs.cpp b/lib/Transforms/IPO/InferFunctionAttrs.cpp index 15d7515cc842..470f97b8ba61 100644 --- a/lib/Transforms/IPO/InferFunctionAttrs.cpp +++ b/lib/Transforms/IPO/InferFunctionAttrs.cpp @@ -8,7 +8,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/InferFunctionAttrs.h" -#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/Function.h" #include "llvm/IR/LLVMContext.h" diff --git a/lib/Transforms/IPO/InlineSimple.cpp b/lib/Transforms/IPO/InlineSimple.cpp index 50e7cc89a3b3..b259a0abd63c 100644 --- a/lib/Transforms/IPO/InlineSimple.cpp +++ b/lib/Transforms/IPO/InlineSimple.cpp @@ -12,7 +12,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -21,7 +20,6 @@ #include "llvm/IR/CallingConv.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/Transforms/IPO.h" @@ -57,12 +55,23 @@ public: InlineCost getInlineCost(CallSite CS) override { Function *Callee = CS.getCalledFunction(); TargetTransformInfo &TTI = TTIWP->getTTI(*Callee); + + bool RemarksEnabled = false; + const auto &BBs = CS.getCaller()->getBasicBlockList(); + if (!BBs.empty()) { + auto DI = OptimizationRemark(DEBUG_TYPE, "", DebugLoc(), &BBs.front()); + if (DI.isEnabled()) + RemarksEnabled = true; + } + OptimizationRemarkEmitter ORE(CS.getCaller()); + std::function<AssumptionCache &(Function &)> GetAssumptionCache = [&](Function &F) -> AssumptionCache & { return ACT->getAssumptionCache(F); }; return llvm::getInlineCost(CS, Params, TTI, GetAssumptionCache, - /*GetBFI=*/None, PSI); + /*GetBFI=*/None, PSI, + 
RemarksEnabled ? &ORE : nullptr); } bool runOnSCC(CallGraphSCC &SCC) override; diff --git a/lib/Transforms/IPO/Inliner.cpp b/lib/Transforms/IPO/Inliner.cpp index 317770d133b3..4449c87ddefa 100644 --- a/lib/Transforms/IPO/Inliner.cpp +++ b/lib/Transforms/IPO/Inliner.cpp @@ -14,29 +14,60 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/Inliner.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" #include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/CGSCCPassManager.h" #include "llvm/Analysis/CallGraph.h" #include "llvm/Analysis/InlineCost.h" -#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/Analysis/LazyCallGraph.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/ImportedFunctionsInliningStatistics.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" +#include <algorithm> +#include <cassert> +#include <functional> +#include <tuple> +#include <utility> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "inline" @@ -63,13 +94,16 @@ static cl::opt<bool> cl::init(false), cl::Hidden); namespace { + enum class InlinerFunctionImportStatsOpts { No = 0, Basic = 1, Verbose = 2, }; -cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats( +} // end anonymous namespace + +static cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats( "inliner-function-import-stats", cl::init(InlinerFunctionImportStatsOpts::No), cl::values(clEnumValN(InlinerFunctionImportStatsOpts::Basic, "basic", @@ -77,10 +111,8 @@ cl::opt<InlinerFunctionImportStatsOpts> InlinerFunctionImportStats( clEnumValN(InlinerFunctionImportStatsOpts::Verbose, "verbose", "printing of statistics for each inlined function")), cl::Hidden, cl::desc("Enable inliner stats for imported functions")); -} // namespace -LegacyInlinerBase::LegacyInlinerBase(char &ID) - : CallGraphSCCPass(ID), InsertLifetime(true) {} +LegacyInlinerBase::LegacyInlinerBase(char &ID) : CallGraphSCCPass(ID) {} LegacyInlinerBase::LegacyInlinerBase(char &ID, bool InsertLifetime) : CallGraphSCCPass(ID), InsertLifetime(InsertLifetime) {} @@ -96,7 +128,7 @@ void LegacyInlinerBase::getAnalysisUsage(AnalysisUsage &AU) const { 
CallGraphSCCPass::getAnalysisUsage(AU); } -typedef DenseMap<ArrayType *, std::vector<AllocaInst *>> InlinedArrayAllocasTy; +using InlinedArrayAllocasTy = DenseMap<ArrayType *, std::vector<AllocaInst *>>; /// Look at all of the allocas that we inlined through this call site. If we /// have already inlined other allocas through other calls into this function, @@ -161,7 +193,6 @@ static void mergeInlinedArrayAllocas( // function. Also, AllocasForType can be empty of course! bool MergedAwayAlloca = false; for (AllocaInst *AvailableAlloca : AllocasForType) { - unsigned Align1 = AI->getAlignment(), Align2 = AvailableAlloca->getAlignment(); @@ -267,7 +298,6 @@ static bool shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, int &TotalSecondaryCost, function_ref<InlineCost(CallSite CS)> GetInlineCost) { - // For now we only handle local or inline functions. if (!Caller->hasLocalLinkage() && !Caller->hasLinkOnceODRLinkage()) return false; @@ -335,11 +365,14 @@ shouldBeDeferred(Function *Caller, CallSite CS, InlineCost IC, return false; } -/// Return true if the inliner should attempt to inline at the given CallSite. -static bool shouldInline(CallSite CS, - function_ref<InlineCost(CallSite CS)> GetInlineCost, - OptimizationRemarkEmitter &ORE) { +/// Return the cost only if the inliner should attempt to inline at the given +/// CallSite. If we return the cost, we will emit an optimisation remark later +/// using that cost, so we won't do so from this function. +static Optional<InlineCost> +shouldInline(CallSite CS, function_ref<InlineCost(CallSite CS)> GetInlineCost, + OptimizationRemarkEmitter &ORE) { using namespace ore; + InlineCost IC = GetInlineCost(CS); Instruction *Call = CS.getInstruction(); Function *Callee = CS.getCalledFunction(); @@ -348,32 +381,33 @@ static bool shouldInline(CallSite CS, if (IC.isAlways()) { DEBUG(dbgs() << " Inlining: cost=always" << ", Call: " << *CS.getInstruction() << "\n"); - ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "AlwaysInline", Call) - << NV("Callee", Callee) - << " should always be inlined (cost=always)"); - return true; + return IC; } if (IC.isNever()) { DEBUG(dbgs() << " NOT Inlining: cost=never" << ", Call: " << *CS.getInstruction() << "\n"); - ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", Call) + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", Call) << NV("Callee", Callee) << " not inlined into " << NV("Caller", Caller) - << " because it should never be inlined (cost=never)"); - return false; + << " because it should never be inlined (cost=never)"; + }); + return None; } if (!IC) { DEBUG(dbgs() << " NOT Inlining: cost=" << IC.getCost() - << ", thres=" << (IC.getCostDelta() + IC.getCost()) + << ", thres=" << IC.getThreshold() << ", Call: " << *CS.getInstruction() << "\n"); - ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "TooCostly", Call) + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "TooCostly", Call) << NV("Callee", Callee) << " not inlined into " << NV("Caller", Caller) << " because too costly to inline (cost=" - << NV("Cost", IC.getCost()) << ", threshold=" - << NV("Threshold", IC.getCostDelta() + IC.getCost()) << ")"); - return false; + << NV("Cost", IC.getCost()) + << ", threshold=" << NV("Threshold", IC.getThreshold()) << ")"; + }); + return None; } int TotalSecondaryCost = 0; @@ -381,23 +415,23 @@ static bool shouldInline(CallSite CS, DEBUG(dbgs() << " NOT Inlining: " << *CS.getInstruction() << " Cost = " << IC.getCost() << ", outer Cost = " << TotalSecondaryCost << 
'\n'); - ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "IncreaseCostInOtherContexts", + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "IncreaseCostInOtherContexts", Call) << "Not inlining. Cost of inlining " << NV("Callee", Callee) << " increases the cost of inlining " << NV("Caller", Caller) - << " in other contexts"); - return false; + << " in other contexts"; + }); + + // IC does not bool() to false, so get an InlineCost that will. + // This will not be inspected to make an error message. + return None; } DEBUG(dbgs() << " Inlining: cost=" << IC.getCost() - << ", thres=" << (IC.getCostDelta() + IC.getCost()) + << ", thres=" << IC.getThreshold() << ", Call: " << *CS.getInstruction() << '\n'); - ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "CanBeInlined", Call) - << NV("Callee", Callee) << " can be inlined into " - << NV("Caller", Caller) << " with cost=" << NV("Cost", IC.getCost()) - << " (threshold=" - << NV("Threshold", IC.getCostDelta() + IC.getCost()) << ")"); - return true; + return IC; } /// Return true if the specified inline history ID @@ -475,11 +509,14 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, if (Function *Callee = CS.getCalledFunction()) if (Callee->isDeclaration()) { using namespace ore; - ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "NoDefinition", &I) + + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NoDefinition", &I) << NV("Callee", Callee) << " will not be inlined into " << NV("Caller", CS.getCaller()) << " because its definition is unavailable" - << setIsVerbose()); + << setIsVerbose(); + }); continue; } @@ -545,9 +582,10 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, // just become a regular analysis dependency. OptimizationRemarkEmitter ORE(Caller); + Optional<InlineCost> OIC = shouldInline(CS, GetInlineCost, ORE); // If the policy determines that we should inline this function, // delete the call instead. - if (!shouldInline(CS, GetInlineCost, ORE)) + if (!OIC) continue; // If this call site is dead and it is to a readonly function, we should @@ -562,26 +600,40 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, ++NumCallsDeleted; } else { // Get DebugLoc to report. CS will be invalid after Inliner. - DebugLoc DLoc = Instr->getDebugLoc(); + DebugLoc DLoc = CS->getDebugLoc(); BasicBlock *Block = CS.getParent(); // Attempt to inline the function. using namespace ore; + if (!InlineCallIfPossible(CS, InlineInfo, InlinedArrayAllocas, InlineHistoryID, InsertLifetime, AARGetter, ImportedFunctionsStats)) { - ORE.emit( - OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, Block) - << NV("Callee", Callee) << " will not be inlined into " - << NV("Caller", Caller)); + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, + Block) + << NV("Callee", Callee) << " will not be inlined into " + << NV("Caller", Caller); + }); continue; } ++NumInlined; - // Report the inline decision. - ORE.emit(OptimizationRemark(DEBUG_TYPE, "Inlined", DLoc, Block) - << NV("Callee", Callee) << " inlined into " - << NV("Caller", Caller)); + ORE.emit([&]() { + bool AlwaysInline = OIC->isAlways(); + StringRef RemarkName = AlwaysInline ? 
"AlwaysInline" : "Inlined"; + OptimizationRemark R(DEBUG_TYPE, RemarkName, DLoc, Block); + R << NV("Callee", Callee) << " inlined into "; + R << NV("Caller", Caller); + if (AlwaysInline) + R << " with cost=always"; + else { + R << " with cost=" << NV("Cost", OIC->getCost()); + R << " (threshold=" << NV("Threshold", OIC->getThreshold()); + R << ")"; + } + return R; + }); // If inlining this function gave us any new call sites, throw them // onto our worklist to process. They are useful inline candidates. @@ -601,7 +653,6 @@ inlineCallsImpl(CallGraphSCC &SCC, CallGraph &CG, if (Callee && Callee->use_empty() && Callee->hasLocalLinkage() && // TODO: Can remove if in SCC now. !SCCFunctions.count(Callee) && - // The function may be apparently dead, but if there are indirect // callgraph references to the node, we cannot delete it yet, this // could invalidate the CGSCC iterator. @@ -840,6 +891,10 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, FunctionAnalysisManager &FAM = AM.getResult<FunctionAnalysisManagerCGSCCProxy>(*C, CG) .getManager(); + + // Get the remarks emission analysis for the caller. + auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); + std::function<AssumptionCache &(Function &)> GetAssumptionCache = [&](Function &F) -> AssumptionCache & { return FAM.getResult<AssumptionAnalysis>(F); @@ -852,12 +907,9 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, Function &Callee = *CS.getCalledFunction(); auto &CalleeTTI = FAM.getResult<TargetIRAnalysis>(Callee); return getInlineCost(CS, Params, CalleeTTI, GetAssumptionCache, {GetBFI}, - PSI); + PSI, &ORE); }; - // Get the remarks emission analysis for the caller. - auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); - // Now process as many calls as we have within this caller in the sequnece. // We bail out as soon as the caller has to change so we can update the // call graph and prepare the context of that new caller. @@ -872,8 +924,22 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, InlineHistoryIncludes(&Callee, InlineHistoryID, InlineHistory)) continue; + // Check if this inlining may repeat breaking an SCC apart that has + // already been split once before. In that case, inlining here may + // trigger infinite inlining, much like is prevented within the inliner + // itself by the InlineHistory above, but spread across CGSCC iterations + // and thus hidden from the full inline history. + if (CG.lookupSCC(*CG.lookup(Callee)) == C && + UR.InlinedInternalEdges.count({&N, C})) { + DEBUG(dbgs() << "Skipping inlining internal SCC edge from a node " + "previously split out of this SCC by inlining: " + << F.getName() << " -> " << Callee.getName() << "\n"); + continue; + } + + Optional<InlineCost> OIC = shouldInline(CS, GetInlineCost, ORE); // Check whether we want to inline this callsite. - if (!shouldInline(CS, GetInlineCost, ORE)) + if (!OIC) continue; // Setup the data structure used to plumb customization into the @@ -883,11 +949,39 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, &FAM.getResult<BlockFrequencyAnalysis>(*(CS.getCaller())), &FAM.getResult<BlockFrequencyAnalysis>(Callee)); - if (!InlineFunction(CS, IFI)) + // Get DebugLoc to report. CS will be invalid after Inliner. 
+ DebugLoc DLoc = CS->getDebugLoc(); + BasicBlock *Block = CS.getParent(); + + using namespace ore; + + if (!InlineFunction(CS, IFI)) { + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NotInlined", DLoc, Block) + << NV("Callee", &Callee) << " will not be inlined into " + << NV("Caller", &F); + }); continue; + } DidInline = true; InlinedCallees.insert(&Callee); + ORE.emit([&]() { + bool AlwaysInline = OIC->isAlways(); + StringRef RemarkName = AlwaysInline ? "AlwaysInline" : "Inlined"; + OptimizationRemark R(DEBUG_TYPE, RemarkName, DLoc, Block); + R << NV("Callee", &Callee) << " inlined into "; + R << NV("Caller", &F); + if (AlwaysInline) + R << " with cost=always"; + else { + R << " with cost=" << NV("Cost", OIC->getCost()); + R << " (threshold=" << NV("Threshold", OIC->getThreshold()); + R << ")"; + } + return R; + }); + // Add any new callsites to defined functions to the worklist. if (!IFI.InlinedCallSites.empty()) { int NewHistoryID = InlineHistory.size(); @@ -949,17 +1043,38 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, for (LazyCallGraph::Edge &E : *CalleeN) RC->insertTrivialRefEdge(N, E.getNode()); } - InlinedCallees.clear(); // At this point, since we have made changes we have at least removed // a call instruction. However, in the process we do some incremental // simplification of the surrounding code. This simplification can // essentially do all of the same things as a function pass and we can // re-use the exact same logic for updating the call graph to reflect the - // change.. + // change. + LazyCallGraph::SCC *OldC = C; C = &updateCGAndAnalysisManagerForFunctionPass(CG, *C, N, AM, UR); DEBUG(dbgs() << "Updated inlining SCC: " << *C << "\n"); RC = &C->getOuterRefSCC(); + + // If this causes an SCC to split apart into multiple smaller SCCs, there + // is a subtle risk we need to prepare for. Other transformations may + // expose an "infinite inlining" opportunity later, and because of the SCC + // mutation, we will revisit this function and potentially re-inline. If we + // do, and that re-inlining also has the potentially to mutate the SCC + // structure, the infinite inlining problem can manifest through infinite + // SCC splits and merges. To avoid this, we capture the originating caller + // node and the SCC containing the call edge. This is a slight over + // approximation of the possible inlining decisions that must be avoided, + // but is relatively efficient to store. + // FIXME: This seems like a very heavyweight way of retaining the inline + // history, we should look for a more efficient way of tracking it. 
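The new-PM inliner hunk above guards against repeated SCC splitting by remembering (caller node, original SCC) pairs in UR.InlinedInternalEdges and skipping any inline along such an edge when the callee still sits in the caller's current SCC. The sketch below reduces that bookkeeping to plain containers to show the shape of the check; Node, SCC and the member names are stand-ins, not LLVM types.

    // Simplified, non-LLVM sketch of the infinite-inlining guard: remember
    // which (caller node, original SCC) pairs already split an SCC and refuse
    // to inline along such an edge again.
    #include <set>
    #include <utility>

    struct Node;   // stand-in for LazyCallGraph::Node
    struct SCC;    // stand-in for LazyCallGraph::SCC

    struct InlineSplitGuard {
      std::set<std::pair<Node *, SCC *>> InlinedInternalEdges;

      // Record that inlining from N split the SCC it used to share with the
      // callee (called when the post-inline SCC differs from the old one).
      void recordSplit(Node *N, SCC *OldC) {
        InlinedInternalEdges.insert({N, OldC});
      }

      // Before inlining: skip if the callee is still in the caller's current
      // SCC and this edge has already split that SCC once before.
      bool shouldSkip(Node *N, SCC *CalleeSCC, SCC *CurrentSCC) const {
        return CalleeSCC == CurrentSCC &&
               InlinedInternalEdges.count({N, CurrentSCC}) != 0;
      }
    };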
+ if (C != OldC && llvm::any_of(InlinedCallees, [&](Function *Callee) { + return CG.lookupSCC(*CG.lookup(*Callee)) == OldC; + })) { + DEBUG(dbgs() << "Inlined an internal call edge and split an SCC, " + "retaining this to avoid infinite inlining.\n"); + UR.InlinedInternalEdges.insert({&N, OldC}); + } + InlinedCallees.clear(); } // Now that we've finished inlining all of the calls across this SCC, delete @@ -976,8 +1091,8 @@ PreservedAnalyses InlinerPass::run(LazyCallGraph::SCC &InitialC, FunctionAnalysisManager &FAM = AM.getResult<FunctionAnalysisManagerCGSCCProxy>(DeadC, CG) .getManager(); - FAM.clear(*DeadF); - AM.clear(DeadC); + FAM.clear(*DeadF, DeadF->getName()); + AM.clear(DeadC, DeadC.getName()); auto &DeadRC = DeadC.getOuterRefSCC(); CG.removeDeadFunction(*DeadF); diff --git a/lib/Transforms/IPO/LoopExtractor.cpp b/lib/Transforms/IPO/LoopExtractor.cpp index c74b0a35e296..36b6bdba2cd0 100644 --- a/lib/Transforms/IPO/LoopExtractor.cpp +++ b/lib/Transforms/IPO/LoopExtractor.cpp @@ -80,7 +80,7 @@ INITIALIZE_PASS(SingleLoopExtractor, "loop-extract-single", // Pass *llvm::createLoopExtractorPass() { return new LoopExtractor(); } -bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &) { +bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &LPM) { if (skipLoop(L)) return false; @@ -143,7 +143,8 @@ bool LoopExtractor::runOnLoop(Loop *L, LPPassManager &) { Changed = true; // After extraction, the loop is replaced by a function call, so // we shouldn't try to run any more loop passes on it. - LI.markAsRemoved(L); + LPM.markLoopAsDeleted(*L); + LI.erase(L); } ++NumExtracted; } diff --git a/lib/Transforms/IPO/LowerTypeTests.cpp b/lib/Transforms/IPO/LowerTypeTests.cpp index 693df5e7ba92..8db7e1e142d2 100644 --- a/lib/Transforms/IPO/LowerTypeTests.cpp +++ b/lib/Transforms/IPO/LowerTypeTests.cpp @@ -1,4 +1,4 @@ -//===-- LowerTypeTests.cpp - type metadata lowering pass ------------------===// +//===- LowerTypeTests.cpp - type metadata lowering pass -------------------===// // // The LLVM Compiler Infrastructure // @@ -13,32 +13,70 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/LowerTypeTests.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/TinyPtrVector.h" #include "llvm/ADT/Triple.h" #include "llvm/Analysis/TypeMetadataUtils.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalObject.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/ModuleSummaryIndex.h" #include "llvm/IR/ModuleSummaryIndexYAML.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Allocator.h" +#include 
"llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FileSystem.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/TrailingObjects.h" +#include "llvm/Support/YAMLTraits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/ModuleUtils.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <memory> +#include <set> +#include <string> +#include <system_error> +#include <utility> +#include <vector> using namespace llvm; using namespace lowertypetests; @@ -197,6 +235,7 @@ struct ByteArrayInfo { uint64_t BitSize; GlobalVariable *ByteArray; GlobalVariable *MaskGlobal; + uint8_t *MaskPtr = nullptr; }; /// A POD-like structure that we use to store a global reference together with @@ -205,16 +244,19 @@ struct ByteArrayInfo { /// operation involving a map lookup; this data structure helps to reduce the /// number of times we need to do this lookup. class GlobalTypeMember final : TrailingObjects<GlobalTypeMember, MDNode *> { + friend TrailingObjects; + GlobalObject *GO; size_t NTypes; + // For functions: true if this is a definition (either in the merged module or // in one of the thinlto modules). bool IsDefinition; + // For functions: true if this function is either defined or used in a thinlto // module and its jumptable entry needs to be exported to thinlto backends. bool IsExported; - friend TrailingObjects; size_t numTrailingObjects(OverloadToken<MDNode *>) const { return NTypes; } public: @@ -231,15 +273,19 @@ public: GTM->getTrailingObjects<MDNode *>()); return GTM; } + GlobalObject *getGlobal() const { return GO; } + bool isDefinition() const { return IsDefinition; } + bool isExported() const { return IsExported; } + ArrayRef<MDNode *> types() const { return makeArrayRef(getTrailingObjects<MDNode *>(), NTypes); } @@ -258,6 +304,7 @@ class LowerTypeTestsModule { IntegerType *Int1Ty = Type::getInt1Ty(M.getContext()); IntegerType *Int8Ty = Type::getInt8Ty(M.getContext()); PointerType *Int8PtrTy = Type::getInt8PtrTy(M.getContext()); + ArrayType *Int8Arr0Ty = ArrayType::get(Type::getInt8Ty(M.getContext()), 0); IntegerType *Int32Ty = Type::getInt32Ty(M.getContext()); PointerType *Int32PtrTy = PointerType::getUnqual(Int32Ty); IntegerType *Int64Ty = Type::getInt64Ty(M.getContext()); @@ -307,7 +354,8 @@ class LowerTypeTestsModule { Function *WeakInitializerFn = nullptr; - void exportTypeId(StringRef TypeId, const TypeIdLowering &TIL); + bool shouldExportConstantsAsAbsoluteSymbols(); + uint8_t *exportTypeId(StringRef TypeId, const TypeIdLowering &TIL); TypeIdLowering importTypeId(StringRef TypeId); void importTypeTest(CallInst *CI); void importFunction(Function *F, bool isDefinition); @@ -329,6 +377,7 @@ class LowerTypeTestsModule { unsigned getJumpTableEntrySize(); Type *getJumpTableEntryType(); void createJumpTableEntry(raw_ostream &AsmOS, raw_ostream &ConstraintOS, + Triple::ArchType JumpTableArch, SmallVectorImpl<Value *> &AsmArgs, Function *Dest); void verifyTypeMDNode(GlobalObject *GO, MDNode *Type); void buildBitSetsFromFunctions(ArrayRef<Metadata *> TypeIds, @@ -350,6 +399,7 @@ class LowerTypeTestsModule { public: LowerTypeTestsModule(Module &M, ModuleSummaryIndex *ExportSummary, const ModuleSummaryIndex *ImportSummary); + bool lower(); // Lower the module using the action and summary 
passed as command line @@ -385,11 +435,12 @@ struct LowerTypeTests : public ModulePass { } }; -} // anonymous namespace +} // end anonymous namespace + +char LowerTypeTests::ID = 0; INITIALIZE_PASS(LowerTypeTests, "lowertypetests", "Lower type metadata", false, false) -char LowerTypeTests::ID = 0; ModulePass * llvm::createLowerTypeTestsPass(ModuleSummaryIndex *ExportSummary, @@ -473,6 +524,8 @@ void LowerTypeTestsModule::allocateByteArrays() { BAI->MaskGlobal->replaceAllUsesWith( ConstantExpr::getIntToPtr(ConstantInt::get(Int8Ty, Mask), Int8PtrTy)); BAI->MaskGlobal->eraseFromParent(); + if (BAI->MaskPtr) + *BAI->MaskPtr = Mask; } Constant *ByteArrayConst = ConstantDataArray::get(M.getContext(), BAB.Bytes); @@ -634,6 +687,10 @@ Value *LowerTypeTestsModule::lowerTypeTestCall(Metadata *TypeId, CallInst *CI, Br->getMetadata(LLVMContext::MD_prof)); ReplaceInstWithInst(InitialBB->getTerminator(), NewBr); + // Update phis in Else resulting from InitialBB being split + for (auto &Phi : Else->phis()) + Phi.addIncoming(Phi.getIncomingValueForBlock(Then), InitialBB); + IRBuilder<> ThenB(CI); return createBitSetTest(ThenB, TIL, BitOffset); } @@ -720,13 +777,21 @@ void LowerTypeTestsModule::buildBitSetsFromGlobalVariables( } } +bool LowerTypeTestsModule::shouldExportConstantsAsAbsoluteSymbols() { + return (Arch == Triple::x86 || Arch == Triple::x86_64) && + ObjectFormat == Triple::ELF; +} + /// Export the given type identifier so that ThinLTO backends may import it. /// Type identifiers are exported by adding coarse-grained information about how /// to test the type identifier to the summary, and creating symbols in the /// object file (aliases and absolute symbols) containing fine-grained /// information about the type identifier. -void LowerTypeTestsModule::exportTypeId(StringRef TypeId, - const TypeIdLowering &TIL) { +/// +/// Returns a pointer to the location in which to store the bitmask, if +/// applicable. 
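The new ByteArrayInfo::MaskPtr field and the allocateByteArrays hunk above, together with the exportTypeId changes that follow, implement a deferred write: when constants are kept in the summary rather than exported as absolute symbols, the byte-array mask is not known at export time, so exportTypeId hands back the address of the summary's BitMask slot and allocateByteArrays fills it once the combined byte array is laid out. A reduced sketch of that pattern, with simplified stand-in types:

    // Reduced sketch of the deferred mask write (stand-in types, not the real
    // TypeTestResolution/ByteArrayInfo definitions).
    #include <cstdint>

    struct TypeTestResolutionSummary {
      uint8_t BitMask = 0;          // filled in later when not exported as a symbol
    };

    struct ByteArrayRecord {
      uint8_t *MaskPtr = nullptr;   // where to publish the mask, if anywhere
    };

    // Export time: the mask is unknown, so hand back the slot to fill.
    inline uint8_t *exportMaskSlot(TypeTestResolutionSummary &TTRes) {
      return &TTRes.BitMask;
    }

    // Byte-array allocation time: the layout is final and the mask is known.
    inline void publishMask(ByteArrayRecord &BAI, uint8_t Mask) {
      if (BAI.MaskPtr)
        *BAI.MaskPtr = Mask;
    }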
+uint8_t *LowerTypeTestsModule::exportTypeId(StringRef TypeId, + const TypeIdLowering &TIL) { TypeTestResolution &TTRes = ExportSummary->getOrInsertTypeIdSummary(TypeId).TTRes; TTRes.TheKind = TIL.TheKind; @@ -738,14 +803,21 @@ void LowerTypeTestsModule::exportTypeId(StringRef TypeId, GA->setVisibility(GlobalValue::HiddenVisibility); }; + auto ExportConstant = [&](StringRef Name, uint64_t &Storage, Constant *C) { + if (shouldExportConstantsAsAbsoluteSymbols()) + ExportGlobal(Name, ConstantExpr::getIntToPtr(C, Int8PtrTy)); + else + Storage = cast<ConstantInt>(C)->getZExtValue(); + }; + if (TIL.TheKind != TypeTestResolution::Unsat) ExportGlobal("global_addr", TIL.OffsetedGlobal); if (TIL.TheKind == TypeTestResolution::ByteArray || TIL.TheKind == TypeTestResolution::Inline || TIL.TheKind == TypeTestResolution::AllOnes) { - ExportGlobal("align", ConstantExpr::getIntToPtr(TIL.AlignLog2, Int8PtrTy)); - ExportGlobal("size_m1", ConstantExpr::getIntToPtr(TIL.SizeM1, Int8PtrTy)); + ExportConstant("align", TTRes.AlignLog2, TIL.AlignLog2); + ExportConstant("size_m1", TTRes.SizeM1, TIL.SizeM1); uint64_t BitSize = cast<ConstantInt>(TIL.SizeM1)->getZExtValue() + 1; if (TIL.TheKind == TypeTestResolution::Inline) @@ -756,12 +828,16 @@ void LowerTypeTestsModule::exportTypeId(StringRef TypeId, if (TIL.TheKind == TypeTestResolution::ByteArray) { ExportGlobal("byte_array", TIL.TheByteArray); - ExportGlobal("bit_mask", TIL.BitMask); + if (shouldExportConstantsAsAbsoluteSymbols()) + ExportGlobal("bit_mask", TIL.BitMask); + else + return &TTRes.BitMask; } if (TIL.TheKind == TypeTestResolution::Inline) - ExportGlobal("inline_bits", - ConstantExpr::getIntToPtr(TIL.InlineBits, Int8PtrTy)); + ExportConstant("inline_bits", TTRes.InlineBits, TIL.InlineBits); + + return nullptr; } LowerTypeTestsModule::TypeIdLowering @@ -774,16 +850,34 @@ LowerTypeTestsModule::importTypeId(StringRef TypeId) { TypeIdLowering TIL; TIL.TheKind = TTRes.TheKind; - auto ImportGlobal = [&](StringRef Name, unsigned AbsWidth) { - Constant *C = - M.getOrInsertGlobal(("__typeid_" + TypeId + "_" + Name).str(), Int8Ty); - auto *GV = dyn_cast<GlobalVariable>(C); - // We only need to set metadata if the global is newly created, in which - // case it would not have hidden visibility. - if (!GV || GV->getVisibility() == GlobalValue::HiddenVisibility) + auto ImportGlobal = [&](StringRef Name) { + // Give the global a type of length 0 so that it is not assumed not to alias + // with any other global. + Constant *C = M.getOrInsertGlobal(("__typeid_" + TypeId + "_" + Name).str(), + Int8Arr0Ty); + if (auto *GV = dyn_cast<GlobalVariable>(C)) + GV->setVisibility(GlobalValue::HiddenVisibility); + C = ConstantExpr::getBitCast(C, Int8PtrTy); + return C; + }; + + auto ImportConstant = [&](StringRef Name, uint64_t Const, unsigned AbsWidth, + Type *Ty) { + if (!shouldExportConstantsAsAbsoluteSymbols()) { + Constant *C = + ConstantInt::get(isa<IntegerType>(Ty) ? 
Ty : Int64Ty, Const); + if (!isa<IntegerType>(Ty)) + C = ConstantExpr::getIntToPtr(C, Ty); + return C; + } + + Constant *C = ImportGlobal(Name); + auto *GV = cast<GlobalVariable>(C->stripPointerCasts()); + if (isa<IntegerType>(Ty)) + C = ConstantExpr::getPtrToInt(C, Ty); + if (GV->getMetadata(LLVMContext::MD_absolute_symbol)) return C; - GV->setVisibility(GlobalValue::HiddenVisibility); auto SetAbsRange = [&](uint64_t Min, uint64_t Max) { auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min)); auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max)); @@ -792,30 +886,30 @@ LowerTypeTestsModule::importTypeId(StringRef TypeId) { }; if (AbsWidth == IntPtrTy->getBitWidth()) SetAbsRange(~0ull, ~0ull); // Full set. - else if (AbsWidth) + else SetAbsRange(0, 1ull << AbsWidth); return C; }; if (TIL.TheKind != TypeTestResolution::Unsat) - TIL.OffsetedGlobal = ImportGlobal("global_addr", 0); + TIL.OffsetedGlobal = ImportGlobal("global_addr"); if (TIL.TheKind == TypeTestResolution::ByteArray || TIL.TheKind == TypeTestResolution::Inline || TIL.TheKind == TypeTestResolution::AllOnes) { - TIL.AlignLog2 = ConstantExpr::getPtrToInt(ImportGlobal("align", 8), Int8Ty); - TIL.SizeM1 = ConstantExpr::getPtrToInt( - ImportGlobal("size_m1", TTRes.SizeM1BitWidth), IntPtrTy); + TIL.AlignLog2 = ImportConstant("align", TTRes.AlignLog2, 8, Int8Ty); + TIL.SizeM1 = + ImportConstant("size_m1", TTRes.SizeM1, TTRes.SizeM1BitWidth, IntPtrTy); } if (TIL.TheKind == TypeTestResolution::ByteArray) { - TIL.TheByteArray = ImportGlobal("byte_array", 0); - TIL.BitMask = ImportGlobal("bit_mask", 8); + TIL.TheByteArray = ImportGlobal("byte_array"); + TIL.BitMask = ImportConstant("bit_mask", TTRes.BitMask, 8, Int8PtrTy); } if (TIL.TheKind == TypeTestResolution::Inline) - TIL.InlineBits = ConstantExpr::getPtrToInt( - ImportGlobal("inline_bits", 1 << TTRes.SizeM1BitWidth), + TIL.InlineBits = ImportConstant( + "inline_bits", TTRes.InlineBits, 1 << TTRes.SizeM1BitWidth, TTRes.SizeM1BitWidth <= 5 ? Int32Ty : Int64Ty); return TIL; @@ -894,6 +988,7 @@ void LowerTypeTestsModule::lowerTypeTestCalls( BSI.print(dbgs()); }); + ByteArrayInfo *BAI = nullptr; TypeIdLowering TIL; TIL.OffsetedGlobal = ConstantExpr::getGetElementPtr( Int8Ty, CombinedGlobalAddr, ConstantInt::get(IntPtrTy, BSI.ByteOffset)), @@ -915,15 +1010,18 @@ void LowerTypeTestsModule::lowerTypeTestCalls( } else { TIL.TheKind = TypeTestResolution::ByteArray; ++NumByteArraysCreated; - ByteArrayInfo *BAI = createByteArray(BSI); + BAI = createByteArray(BSI); TIL.TheByteArray = BAI->ByteArray; TIL.BitMask = BAI->MaskGlobal; } TypeIdUserInfo &TIUI = TypeIdUsers[TypeId]; - if (TIUI.IsExported) - exportTypeId(cast<MDString>(TypeId)->getString(), TIL); + if (TIUI.IsExported) { + uint8_t *MaskPtr = exportTypeId(cast<MDString>(TypeId)->getString(), TIL); + if (BAI) + BAI->MaskPtr = MaskPtr; + } // Lower each call to llvm.type.test for this type identifier. for (CallInst *CI : TIUI.CallSites) { @@ -979,15 +1077,16 @@ unsigned LowerTypeTestsModule::getJumpTableEntrySize() { // constraints and arguments to AsmOS, ConstraintOS and AsmArgs. 
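The importTypeId hunk above introduces an ImportConstant helper: on targets where constants are not exported as absolute symbols (anything other than x86/x86_64 ELF, per shouldExportConstantsAsAbsoluteSymbols), the scalar travels in the summary and is materialized directly as a ConstantInt; otherwise an external global with !absolute_symbol range metadata is still used. A trimmed sketch of the first path only; the helper name and the nullptr fallback for the omitted branch are mine.

    // Trimmed sketch of the summary-constant import path. Only the
    // non-absolute-symbol branch is shown; the helper name is illustrative.
    #include "llvm/IR/Constants.h"
    #include "llvm/IR/DerivedTypes.h"
    #include <cstdint>
    using namespace llvm;

    static Constant *importSummaryConstant(uint64_t SummaryValue, Type *Ty,
                                           IntegerType *Int64Ty) {
      // Materialize the value directly; use i64 when the requested type is a
      // pointer, then cast the integer to that pointer type.
      Constant *C =
          ConstantInt::get(isa<IntegerType>(Ty) ? Ty : Int64Ty, SummaryValue);
      if (!isa<IntegerType>(Ty))
        C = ConstantExpr::getIntToPtr(C, Ty);
      return C;
      // The absolute-symbol branch (external global plus !absolute_symbol
      // range metadata) is handled by the full importTypeId above.
    }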
void LowerTypeTestsModule::createJumpTableEntry( raw_ostream &AsmOS, raw_ostream &ConstraintOS, - SmallVectorImpl<Value *> &AsmArgs, Function *Dest) { + Triple::ArchType JumpTableArch, SmallVectorImpl<Value *> &AsmArgs, + Function *Dest) { unsigned ArgIndex = AsmArgs.size(); - if (Arch == Triple::x86 || Arch == Triple::x86_64) { + if (JumpTableArch == Triple::x86 || JumpTableArch == Triple::x86_64) { AsmOS << "jmp ${" << ArgIndex << ":c}@plt\n"; AsmOS << "int3\nint3\nint3\n"; - } else if (Arch == Triple::arm || Arch == Triple::aarch64) { + } else if (JumpTableArch == Triple::arm || JumpTableArch == Triple::aarch64) { AsmOS << "b $" << ArgIndex << "\n"; - } else if (Arch == Triple::thumb) { + } else if (JumpTableArch == Triple::thumb) { AsmOS << "b.w $" << ArgIndex << "\n"; } else { report_fatal_error("Unsupported architecture for jump tables"); @@ -1074,6 +1173,46 @@ void LowerTypeTestsModule::replaceWeakDeclarationWithJumpTablePtr( PlaceholderFn->eraseFromParent(); } +static bool isThumbFunction(Function *F, Triple::ArchType ModuleArch) { + Attribute TFAttr = F->getFnAttribute("target-features"); + if (!TFAttr.hasAttribute(Attribute::None)) { + SmallVector<StringRef, 6> Features; + TFAttr.getValueAsString().split(Features, ','); + for (StringRef Feature : Features) { + if (Feature == "-thumb-mode") + return false; + else if (Feature == "+thumb-mode") + return true; + } + } + + return ModuleArch == Triple::thumb; +} + +// Each jump table must be either ARM or Thumb as a whole for the bit-test math +// to work. Pick one that matches the majority of members to minimize interop +// veneers inserted by the linker. +static Triple::ArchType +selectJumpTableArmEncoding(ArrayRef<GlobalTypeMember *> Functions, + Triple::ArchType ModuleArch) { + if (ModuleArch != Triple::arm && ModuleArch != Triple::thumb) + return ModuleArch; + + unsigned ArmCount = 0, ThumbCount = 0; + for (const auto GTM : Functions) { + if (!GTM->isDefinition()) { + // PLT stubs are always ARM. + ++ArmCount; + continue; + } + + Function *F = cast<Function>(GTM->getGlobal()); + ++(isThumbFunction(F, ModuleArch) ? ThumbCount : ArmCount); + } + + return ArmCount > ThumbCount ? Triple::arm : Triple::thumb; +} + void LowerTypeTestsModule::createJumpTable( Function *F, ArrayRef<GlobalTypeMember *> Functions) { std::string AsmStr, ConstraintStr; @@ -1081,8 +1220,10 @@ void LowerTypeTestsModule::createJumpTable( SmallVector<Value *, 16> AsmArgs; AsmArgs.reserve(Functions.size() * 2); + Triple::ArchType JumpTableArch = selectJumpTableArmEncoding(Functions, Arch); + for (unsigned I = 0; I != Functions.size(); ++I) - createJumpTableEntry(AsmOS, ConstraintOS, AsmArgs, + createJumpTableEntry(AsmOS, ConstraintOS, JumpTableArch, AsmArgs, cast<Function>(Functions[I]->getGlobal())); // Try to emit the jump table at the end of the text segment. @@ -1098,11 +1239,15 @@ void LowerTypeTestsModule::createJumpTable( // Luckily, this function does not get any prologue even without the // attribute. if (OS != Triple::Win32) - F->addFnAttr(llvm::Attribute::Naked); - // Thumb jump table assembly needs Thumb2. The following attribute is added by - // Clang for -march=armv7. - if (Arch == Triple::thumb) + F->addFnAttr(Attribute::Naked); + if (JumpTableArch == Triple::arm) + F->addFnAttr("target-features", "-thumb-mode"); + if (JumpTableArch == Triple::thumb) { + F->addFnAttr("target-features", "+thumb-mode"); + // Thumb jump table assembly needs Thumb2. The following attribute is added + // by Clang for -march=armv7. 
F->addFnAttr("target-cpu", "cortex-a8"); + } BasicBlock *BB = BasicBlock::Create(M.getContext(), "entry", F); IRBuilder<> IRB(BB); @@ -1256,7 +1401,7 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsNative( FAlias->takeName(F); if (FAlias->hasName()) F->setName(FAlias->getName() + ".cfi"); - F->replaceAllUsesWith(FAlias); + F->replaceUsesExceptBlockAddr(FAlias); } if (!F->isDeclarationForLinker()) F->setLinkage(GlobalValue::InternalLinkage); @@ -1303,7 +1448,7 @@ void LowerTypeTestsModule::buildBitSetsFromFunctionsWASM( void LowerTypeTestsModule::buildBitSetsFromDisjointSet( ArrayRef<Metadata *> TypeIds, ArrayRef<GlobalTypeMember *> Globals) { - llvm::DenseMap<Metadata *, uint64_t> TypeIdIndices; + DenseMap<Metadata *, uint64_t> TypeIdIndices; for (unsigned I = 0; I != TypeIds.size(); ++I) TypeIdIndices[TypeIds[I]] = I; @@ -1335,38 +1480,25 @@ void LowerTypeTestsModule::buildBitSetsFromDisjointSet( for (auto &&MemSet : TypeMembers) GLB.addFragment(MemSet); - // Build the bitsets from this disjoint set. - if (Globals.empty() || isa<GlobalVariable>(Globals[0]->getGlobal())) { - // Build a vector of global variables with the computed layout. - std::vector<GlobalTypeMember *> OrderedGVs(Globals.size()); - auto OGI = OrderedGVs.begin(); - for (auto &&F : GLB.Fragments) { - for (auto &&Offset : F) { - auto GV = dyn_cast<GlobalVariable>(Globals[Offset]->getGlobal()); - if (!GV) - report_fatal_error("Type identifier may not contain both global " - "variables and functions"); - *OGI++ = Globals[Offset]; - } - } - - buildBitSetsFromGlobalVariables(TypeIds, OrderedGVs); - } else { - // Build a vector of functions with the computed layout. - std::vector<GlobalTypeMember *> OrderedFns(Globals.size()); - auto OFI = OrderedFns.begin(); - for (auto &&F : GLB.Fragments) { - for (auto &&Offset : F) { - auto Fn = dyn_cast<Function>(Globals[Offset]->getGlobal()); - if (!Fn) - report_fatal_error("Type identifier may not contain both global " - "variables and functions"); - *OFI++ = Globals[Offset]; - } + // Build a vector of globals with the computed layout. + bool IsGlobalSet = + Globals.empty() || isa<GlobalVariable>(Globals[0]->getGlobal()); + std::vector<GlobalTypeMember *> OrderedGTMs(Globals.size()); + auto OGTMI = OrderedGTMs.begin(); + for (auto &&F : GLB.Fragments) { + for (auto &&Offset : F) { + if (IsGlobalSet != isa<GlobalVariable>(Globals[Offset]->getGlobal())) + report_fatal_error("Type identifier may not contain both global " + "variables and functions"); + *OGTMI++ = Globals[Offset]; } - - buildBitSetsFromFunctions(TypeIds, OrderedFns); } + + // Build the bitsets from this disjoint set. + if (IsGlobalSet) + buildBitSetsFromGlobalVariables(TypeIds, OrderedGTMs); + else + buildBitSetsFromFunctions(TypeIds, OrderedGTMs); } /// Lower all type tests in this module. @@ -1457,8 +1589,8 @@ bool LowerTypeTestsModule::lower() { // Equivalence class set containing type identifiers and the globals that // reference them. This is used to partition the set of type identifiers in // the module into disjoint sets. 
- typedef EquivalenceClasses<PointerUnion<GlobalTypeMember *, Metadata *>> - GlobalClassesTy; + using GlobalClassesTy = + EquivalenceClasses<PointerUnion<GlobalTypeMember *, Metadata *>>; GlobalClassesTy GlobalClasses; // Verify the type metadata and build a few data structures to let us @@ -1473,7 +1605,7 @@ bool LowerTypeTestsModule::lower() { unsigned Index; std::vector<GlobalTypeMember *> RefGlobals; }; - llvm::DenseMap<Metadata *, TIInfo> TypeIdInfo; + DenseMap<Metadata *, TIInfo> TypeIdInfo; unsigned I = 0; SmallVector<MDNode *, 2> Types; @@ -1512,14 +1644,27 @@ bool LowerTypeTestsModule::lower() { FunctionType::get(Type::getVoidTy(M.getContext()), false), GlobalVariable::ExternalLinkage, FunctionName, &M); - if (Linkage == CFL_Definition) - F->eraseMetadata(LLVMContext::MD_type); + // If the function is available_externally, remove its definition so + // that it is handled the same way as a declaration. Later we will try + // to create an alias using this function's linkage, which will fail if + // the linkage is available_externally. This will also result in us + // following the code path below to replace the type metadata. + if (F->hasAvailableExternallyLinkage()) { + F->setLinkage(GlobalValue::ExternalLinkage); + F->deleteBody(); + F->setComdat(nullptr); + F->clearMetadata(); + } + // If the function in the full LTO module is a declaration, replace its + // type metadata with the type metadata we found in cfi.functions. That + // metadata is presumed to be more accurate than the metadata attached + // to the declaration. if (F->isDeclaration()) { if (Linkage == CFL_WeakDeclaration) F->setLinkage(GlobalValue::ExternalWeakLinkage); - SmallVector<MDNode *, 2> Types; + F->eraseMetadata(LLVMContext::MD_type); for (unsigned I = 2; I < FuncMD->getNumOperands(); ++I) F->addMetadata(LLVMContext::MD_type, *cast<MDNode>(FuncMD->getOperand(I).get())); @@ -1548,7 +1693,7 @@ bool LowerTypeTestsModule::lower() { GlobalTypeMember::create(Alloc, &GO, IsDefinition, IsExported, Types); for (MDNode *Type : Types) { verifyTypeMDNode(&GO, Type); - auto &Info = TypeIdInfo[cast<MDNode>(Type)->getOperand(1)]; + auto &Info = TypeIdInfo[Type->getOperand(1)]; Info.Index = ++I; Info.RefGlobals.push_back(GTM); } @@ -1596,12 +1741,12 @@ bool LowerTypeTestsModule::lower() { for (auto &P : *ExportSummary) { for (auto &S : P.second.SummaryList) { - auto *FS = dyn_cast<FunctionSummary>(S.get()); - if (!FS || !ExportSummary->isGlobalValueLive(FS)) + if (!ExportSummary->isGlobalValueLive(S.get())) continue; - for (GlobalValue::GUID G : FS->type_tests()) - for (Metadata *MD : MetadataByGUID[G]) - AddTypeIdUse(MD).IsExported = true; + if (auto *FS = dyn_cast<FunctionSummary>(S->getBaseObject())) + for (GlobalValue::GUID G : FS->type_tests()) + for (Metadata *MD : MetadataByGUID[G]) + AddTypeIdUse(MD).IsExported = true; } } } diff --git a/lib/Transforms/IPO/MergeFunctions.cpp b/lib/Transforms/IPO/MergeFunctions.cpp index 0e478ba607be..76b90391fbb1 100644 --- a/lib/Transforms/IPO/MergeFunctions.cpp +++ b/lib/Transforms/IPO/MergeFunctions.cpp @@ -89,28 +89,45 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/Hashing.h" -#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" 
-#include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/IR/ValueMap.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/FunctionComparator.h" +#include <algorithm> +#include <cassert> +#include <iterator> +#include <set> +#include <utility> #include <vector> using namespace llvm; @@ -119,7 +136,6 @@ using namespace llvm; STATISTIC(NumFunctionsMerged, "Number of functions merged"); STATISTIC(NumThunksWritten, "Number of thunks generated"); -STATISTIC(NumAliasesWritten, "Number of aliases generated"); STATISTIC(NumDoubleWeak, "Number of new functions created"); static cl::opt<unsigned> NumFunctionsForSanityCheck( @@ -154,10 +170,12 @@ namespace { class FunctionNode { mutable AssertingVH<Function> F; FunctionComparator::FunctionHash Hash; + public: // Note the hash is recalculated potentially multiple times, but it is cheap. FunctionNode(Function *F) : F(F), Hash(FunctionComparator::functionHash(*F)) {} + Function *getFunc() const { return F; } FunctionComparator::FunctionHash getHash() const { return Hash; } @@ -174,13 +192,12 @@ public: /// by considering all pointer types to be equivalent. Once identified, /// MergeFunctions will fold them by replacing a call to one to a call to a /// bitcast of the other. -/// class MergeFunctions : public ModulePass { public: static char ID; + MergeFunctions() - : ModulePass(ID), FnTree(FunctionNodeCmp(&GlobalNumbers)), FNodesInTree(), - HasGlobalAliases(false) { + : ModulePass(ID), FnTree(FunctionNodeCmp(&GlobalNumbers)) { initializeMergeFunctionsPass(*PassRegistry::getPassRegistry()); } @@ -191,8 +208,10 @@ private: // not need to become larger with another pointer. class FunctionNodeCmp { GlobalNumberState* GlobalNumbers; + public: FunctionNodeCmp(GlobalNumberState* GN) : GlobalNumbers(GN) {} + bool operator()(const FunctionNode &LHS, const FunctionNode &RHS) const { // Order first by hashes, then full function comparison. if (LHS.getHash() != RHS.getHash()) @@ -201,7 +220,7 @@ private: return FCmp.compare() == -1; } }; - typedef std::set<FunctionNode, FunctionNodeCmp> FnTreeType; + using FnTreeType = std::set<FunctionNode, FunctionNodeCmp>; GlobalNumberState GlobalNumbers; @@ -209,9 +228,9 @@ private: /// analyzed again. std::vector<WeakTrackingVH> Deferred; +#ifndef NDEBUG /// Checks the rules of order relation introduced among functions set. /// Returns true, if sanity check has been passed, and false if failed. -#ifndef NDEBUG bool doSanityCheck(std::vector<WeakTrackingVH> &Worklist); #endif @@ -236,9 +255,6 @@ private: /// again. void mergeTwoFunctions(Function *F, Function *G); - /// Replace G with a thunk or an alias to F. Deletes G. 
- void writeThunkOrAlias(Function *F, Function *G); - /// Fill PDIUnrelatedWL with instructions from the entry block that are /// unrelated to parameter related debug info. void filterInstsUnrelatedToPDI(BasicBlock *GEntryBlock, @@ -256,29 +272,25 @@ private: /// delete G. void writeThunk(Function *F, Function *G); - /// Replace G with an alias to F. Deletes G. - void writeAlias(Function *F, Function *G); - /// Replace function F with function G in the function tree. void replaceFunctionInTree(const FunctionNode &FN, Function *G); /// The set of all distinct functions. Use the insert() and remove() methods /// to modify it. The map allows efficient lookup and deferring of Functions. FnTreeType FnTree; + // Map functions to the iterators of the FunctionNode which contains them // in the FnTree. This must be updated carefully whenever the FnTree is // modified, i.e. in insert(), remove(), and replaceFunctionInTree(), to avoid // dangling iterators into FnTree. The invariant that preserves this is that // there is exactly one mapping F -> FN for each FunctionNode FN in FnTree. ValueMap<Function*, FnTreeType::iterator> FNodesInTree; - - /// Whether or not the target supports global aliases. - bool HasGlobalAliases; }; } // end anonymous namespace char MergeFunctions::ID = 0; + INITIALIZE_PASS(MergeFunctions, "mergefunc", "Merge Functions", false, false) ModulePass *llvm::createMergeFunctionsPass() { @@ -454,19 +466,6 @@ void MergeFunctions::replaceDirectCallers(Function *Old, Function *New) { } } -// Replace G with an alias to F if possible, or else a thunk to F. Deletes G. -void MergeFunctions::writeThunkOrAlias(Function *F, Function *G) { - if (HasGlobalAliases && G->hasGlobalUnnamedAddr()) { - if (G->hasExternalLinkage() || G->hasLocalLinkage() || - G->hasWeakLinkage()) { - writeAlias(F, G); - return; - } - } - - writeThunk(F, G); -} - // Helper for writeThunk, // Selects proper bitcast operation, // but a bit simpler then CastInst::getCastOpcode. @@ -499,7 +498,6 @@ static Value *createCast(IRBuilder<> &Builder, Value *V, Type *DestTy) { // parameter debug info, from the entry block. void MergeFunctions::eraseInstsUnrelatedToPDI( std::vector<Instruction *> &PDIUnrelatedWL) { - DEBUG(dbgs() << " Erasing instructions (in reverse order of appearance in " "entry block) unrelated to parameter debug info from entry " "block: {\n"); @@ -517,7 +515,6 @@ void MergeFunctions::eraseInstsUnrelatedToPDI( // Reduce G to its entry block. void MergeFunctions::eraseTail(Function *G) { - std::vector<BasicBlock *> WorklistBB; for (Function::iterator BBI = std::next(G->begin()), BBE = G->end(); BBI != BBE; ++BBI) { @@ -542,7 +539,6 @@ void MergeFunctions::eraseTail(Function *G) { // PDIUnrelatedWL with such instructions. void MergeFunctions::filterInstsUnrelatedToPDI( BasicBlock *GEntryBlock, std::vector<Instruction *> &PDIUnrelatedWL) { - std::set<Instruction *> PDIRelated; for (BasicBlock::iterator BI = GEntryBlock->begin(), BIE = GEntryBlock->end(); BI != BIE; ++BI) { @@ -652,9 +648,18 @@ void MergeFunctions::filterInstsUnrelatedToPDI( // call sites to point to F even when within the same translation unit. void MergeFunctions::writeThunk(Function *F, Function *G) { if (!G->isInterposable() && !MergeFunctionsPDI) { - // Redirect direct callers of G to F. (See note on MergeFunctionsPDI - // above). - replaceDirectCallers(G, F); + if (G->hasGlobalUnnamedAddr()) { + // G might have been a key in our GlobalNumberState, and it's illegal + // to replace a key in ValueMap<GlobalValue *> with a non-global. 
+ GlobalNumbers.erase(G); + // If G's address is not significant, replace it entirely. + Constant *BitcastF = ConstantExpr::getBitCast(F, G->getType()); + G->replaceAllUsesWith(BitcastF); + } else { + // Redirect direct callers of G to F. (See note on MergeFunctionsPDI + // above). + replaceDirectCallers(G, F); + } } // If G was internal then we may have replaced all uses of G with F. If so, @@ -665,6 +670,16 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { return; } + // Don't merge tiny functions using a thunk, since it can just end up + // making the function larger. + if (F->size() == 1) { + if (F->front().size() <= 2) { + DEBUG(dbgs() << "writeThunk: " << F->getName() + << " is too small to bother creating a thunk for\n"); + return; + } + } + BasicBlock *GEntryBlock = nullptr; std::vector<Instruction *> PDIUnrelatedWL; BasicBlock *BB = nullptr; @@ -691,7 +706,7 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { SmallVector<Value *, 16> Args; unsigned i = 0; FunctionType *FFTy = F->getFunctionType(); - for (Argument & AI : H->args()) { + for (Argument &AI : H->args()) { Args.push_back(createCast(Builder, &AI, FFTy->getParamType(i))); ++i; } @@ -734,20 +749,6 @@ void MergeFunctions::writeThunk(Function *F, Function *G) { ++NumThunksWritten; } -// Replace G with an alias to F and delete G. -void MergeFunctions::writeAlias(Function *F, Function *G) { - auto *GA = GlobalAlias::create(G->getLinkage(), "", F); - F->setAlignment(std::max(F->getAlignment(), G->getAlignment())); - GA->takeName(G); - GA->setVisibility(G->getVisibility()); - removeUsers(G); - G->replaceAllUsesWith(GA); - G->eraseFromParent(); - - DEBUG(dbgs() << "writeAlias: " << GA->getName() << '\n'); - ++NumAliasesWritten; -} - // Merge two equivalent functions. Upon completion, Function G is deleted. void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { if (F->isInterposable()) { @@ -763,19 +764,14 @@ void MergeFunctions::mergeTwoFunctions(Function *F, Function *G) { unsigned MaxAlignment = std::max(G->getAlignment(), H->getAlignment()); - if (HasGlobalAliases) { - writeAlias(F, G); - writeAlias(F, H); - } else { - writeThunk(F, G); - writeThunk(F, H); - } + writeThunk(F, G); + writeThunk(F, H); F->setAlignment(MaxAlignment); F->setLinkage(GlobalValue::PrivateLinkage); ++NumDoubleWeak; } else { - writeThunkOrAlias(F, G); + writeThunk(F, G); } ++NumFunctionsMerged; @@ -816,18 +812,6 @@ bool MergeFunctions::insert(Function *NewFunction) { const FunctionNode &OldF = *Result.first; - // Don't merge tiny functions, since it can just end up making the function - // larger. - // FIXME: Should still merge them if they are unnamed_addr and produce an - // alias. - if (NewFunction->size() == 1) { - if (NewFunction->front().size() <= 2) { - DEBUG(dbgs() << NewFunction->getName() - << " is to small to bother merging\n"); - return false; - } - } - // Impose a total order (by name) on the replacement of functions. This is // important when operating on more than one module independently to prevent // cycles of thunks calling each other when the modules are linked together. 
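The MergeFunctions diff above removes the alias-writing path entirely: when G's address is not significant, the pass erases G from GlobalNumberState (replacing a ValueMap key with a non-global would be illegal) and replaces all of G's uses with a bitcast of F; otherwise it only redirects direct callers. The size-based bail-out also moved from insert() into writeThunk(). A stand-in sketch of just that relocated check:

    // The "too small to thunk" test that moved from insert() into writeThunk(),
    // expressed over stand-in traits: a single-block function with at most two
    // instructions is not worth the call sequence a thunk would add.
    struct FuncTraits {
      unsigned NumBlocks;        // F->size()
      unsigned EntryBlockSize;   // F->front().size()
    };

    inline bool tooSmallToThunk(const FuncTraits &F) {
      return F.NumBlocks == 1 && F.EntryBlockSize <= 2;
    }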
diff --git a/lib/Transforms/IPO/PartialInlining.cpp b/lib/Transforms/IPO/PartialInlining.cpp index 8840435af642..683655f1f68b 100644 --- a/lib/Transforms/IPO/PartialInlining.cpp +++ b/lib/Transforms/IPO/PartialInlining.cpp @@ -13,45 +13,126 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/PartialInlining.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" -#include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/DebugLoc.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" +#include "llvm/IR/User.h" #include "llvm/Pass.h" +#include "llvm/Support/BlockFrequency.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/CodeExtractor.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <functional> +#include <iterator> +#include <memory> +#include <tuple> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "partial-inlining" STATISTIC(NumPartialInlined, "Number of callsites functions partially inlined into."); +STATISTIC(NumColdOutlinePartialInlined, "Number of times functions with " + "cold outlined regions were partially " + "inlined into its caller(s)."); +STATISTIC(NumColdRegionsFound, + "Number of cold single entry/exit regions found."); +STATISTIC(NumColdRegionsOutlined, + "Number of cold single entry/exit regions outlined."); // Command line option to disable partial-inlining. The default is false: static cl::opt<bool> DisablePartialInlining("disable-partial-inlining", cl::init(false), - cl::Hidden, cl::desc("Disable partial ininling")); + cl::Hidden, cl::desc("Disable partial inlining")); +// Command line option to disable multi-region partial-inlining. The default is +// false: +static cl::opt<bool> DisableMultiRegionPartialInline( + "disable-mr-partial-inlining", cl::init(false), cl::Hidden, + cl::desc("Disable multi-region partial inlining")); + +// Command line option to force outlining in regions with live exit variables. +// The default is false: +static cl::opt<bool> + ForceLiveExit("pi-force-live-exit-outline", cl::init(false), cl::Hidden, + cl::desc("Force outline regions with live exits")); + +// Command line option to enable marking outline functions with Cold Calling +// Convention. 
The default is false: +static cl::opt<bool> + MarkOutlinedColdCC("pi-mark-coldcc", cl::init(false), cl::Hidden, + cl::desc("Mark outline function calls with ColdCC")); + +#ifndef NDEBUG +// Command line option to debug partial-inlining. The default is none: +static cl::opt<bool> TracePartialInlining("trace-partial-inlining", + cl::init(false), cl::Hidden, + cl::desc("Trace partial inlining.")); +#endif + // This is an option used by testing: static cl::opt<bool> SkipCostAnalysis("skip-partial-inlining-cost-analysis", cl::init(false), cl::ZeroOrMore, cl::ReallyHidden, cl::desc("Skip Cost Analysis")); +// Used to determine if a cold region is worth outlining based on +// its inlining cost compared to the original function. Default is set at 10%. +// ie. if the cold region reduces the inlining cost of the original function by +// at least 10%. +static cl::opt<float> MinRegionSizeRatio( + "min-region-size-ratio", cl::init(0.1), cl::Hidden, + cl::desc("Minimum ratio comparing relative sizes of each " + "outline candidate and original function")); +// Used to tune the minimum number of execution counts needed in the predecessor +// block to the cold edge. ie. confidence interval. +static cl::opt<unsigned> + MinBlockCounterExecution("min-block-execution", cl::init(100), cl::Hidden, + cl::desc("Minimum block executions to consider " + "its BranchProbabilityInfo valid")); +// Used to determine when an edge is considered cold. Default is set to 10%. ie. +// if the branch probability is 10% or less, then it is deemed as 'cold'. +static cl::opt<float> ColdBranchRatio( + "cold-branch-ratio", cl::init(0.1), cl::Hidden, + cl::desc("Minimum BranchProbability to consider a region cold.")); static cl::opt<unsigned> MaxNumInlineBlocks( "max-num-inline-blocks", cl::init(5), cl::Hidden, - cl::desc("Max Number of Blocks To be Partially Inlined")); + cl::desc("Max number of blocks to be partially inlined")); // Command line option to set the maximum number of partial inlining allowed // for the module. The default value of -1 means no limit. @@ -75,9 +156,8 @@ static cl::opt<unsigned> ExtraOutliningPenalty( namespace { struct FunctionOutliningInfo { - FunctionOutliningInfo() - : Entries(), ReturnBlock(nullptr), NonReturnBlock(nullptr), - ReturnBlockPreds() {} + FunctionOutliningInfo() = default; + // Returns the number of blocks to be inlined including all blocks // in Entries and one return block. unsigned GetNumInlinedBlocks() const { return Entries.size() + 1; } @@ -85,30 +165,69 @@ struct FunctionOutliningInfo { // A set of blocks including the function entry that guard // the region to be outlined. SmallVector<BasicBlock *, 4> Entries; + // The return block that is not included in the outlined region. - BasicBlock *ReturnBlock; + BasicBlock *ReturnBlock = nullptr; + // The dominating block of the region to be outlined. 
- BasicBlock *NonReturnBlock; + BasicBlock *NonReturnBlock = nullptr; + // The set of blocks in Entries that that are predecessors to ReturnBlock SmallVector<BasicBlock *, 4> ReturnBlockPreds; }; +struct FunctionOutliningMultiRegionInfo { + FunctionOutliningMultiRegionInfo() + : ORI() {} + + // Container for outline regions + struct OutlineRegionInfo { + OutlineRegionInfo(SmallVector<BasicBlock *, 8> Region, + BasicBlock *EntryBlock, BasicBlock *ExitBlock, + BasicBlock *ReturnBlock) + : Region(Region), EntryBlock(EntryBlock), ExitBlock(ExitBlock), + ReturnBlock(ReturnBlock) {} + SmallVector<BasicBlock *, 8> Region; + BasicBlock *EntryBlock; + BasicBlock *ExitBlock; + BasicBlock *ReturnBlock; + }; + + SmallVector<OutlineRegionInfo, 4> ORI; +}; + struct PartialInlinerImpl { + PartialInlinerImpl( std::function<AssumptionCache &(Function &)> *GetAC, std::function<TargetTransformInfo &(Function &)> *GTTI, Optional<function_ref<BlockFrequencyInfo &(Function &)>> GBFI, - ProfileSummaryInfo *ProfSI) - : GetAssumptionCache(GetAC), GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI) {} + ProfileSummaryInfo *ProfSI, + std::function<OptimizationRemarkEmitter &(Function &)> *GORE) + : GetAssumptionCache(GetAC), GetTTI(GTTI), GetBFI(GBFI), PSI(ProfSI), + GetORE(GORE) {} + bool run(Module &M); - Function *unswitchFunction(Function *F); + // Main part of the transformation that calls helper functions to find + // outlining candidates, clone & outline the function, and attempt to + // partially inline the resulting function. Returns true if + // inlining was successful, false otherwise. Also returns the outline + // function (only if we partially inlined early returns) as there is a + // possibility to further "peel" early return statements that were left in the + // outline function due to code size. + std::pair<bool, Function *> unswitchFunction(Function *F); // This class speculatively clones the the function to be partial inlined. // At the end of partial inlining, the remaining callsites to the cloned // function that are not partially inlined will be fixed up to reference // the original function, and the cloned function will be erased. struct FunctionCloner { - FunctionCloner(Function *F, FunctionOutliningInfo *OI); + // Two constructors, one for single region outlining, the other for + // multi-region outlining. + FunctionCloner(Function *F, FunctionOutliningInfo *OI, + OptimizationRemarkEmitter &ORE); + FunctionCloner(Function *F, FunctionOutliningMultiRegionInfo *OMRI, + OptimizationRemarkEmitter &ORE); ~FunctionCloner(); // Prepare for function outlining: making sure there is only @@ -116,20 +235,34 @@ struct PartialInlinerImpl { // the return block. void NormalizeReturnBlock(); - // Do function outlining: - Function *doFunctionOutlining(); + // Do function outlining for cold regions. + bool doMultiRegionFunctionOutlining(); + // Do function outlining for region after early return block(s). + // NOTE: For vararg functions that do the vararg handling in the outlined + // function, we temporarily generate IR that does not properly + // forward varargs to the outlined function. Calling InlineFunction + // will update calls to the outlined functions to properly forward + // the varargs. 
+ Function *doSingleRegionFunctionOutlining(); Function *OrigFunc = nullptr; Function *ClonedFunc = nullptr; - Function *OutlinedFunc = nullptr; - BasicBlock *OutliningCallBB = nullptr; + + typedef std::pair<Function *, BasicBlock *> FuncBodyCallerPair; + // Keep track of Outlined Functions and the basic block they're called from. + SmallVector<FuncBodyCallerPair, 4> OutlinedFunctions; + // ClonedFunc is inlined in one of its callers after function // outlining. bool IsFunctionInlined = false; // The cost of the region to be outlined. int OutlinedRegionCost = 0; + // ClonedOI is specific to outlining non-early return blocks. std::unique_ptr<FunctionOutliningInfo> ClonedOI = nullptr; + // ClonedOMRI is specific to outlining cold regions. + std::unique_ptr<FunctionOutliningMultiRegionInfo> ClonedOMRI = nullptr; std::unique_ptr<BlockFrequencyInfo> ClonedFuncBFI = nullptr; + OptimizationRemarkEmitter &ORE; }; private: @@ -138,6 +271,7 @@ private: std::function<TargetTransformInfo &(Function &)> *GetTTI; Optional<function_ref<BlockFrequencyInfo &(Function &)>> GetBFI; ProfileSummaryInfo *PSI; + std::function<OptimizationRemarkEmitter &(Function &)> *GetORE; // Return the frequency of the OutlininingBB relative to F's entry point. // The result is no larger than 1 and is represented using BP. @@ -148,8 +282,7 @@ private: // Return true if the callee of CS should be partially inlined with // profit. bool shouldPartialInline(CallSite CS, FunctionCloner &Cloner, - BlockFrequency WeightedOutliningRcost, - OptimizationRemarkEmitter &ORE); + BlockFrequency WeightedOutliningRcost); // Try to inline DuplicateFunction (cloned from F with call to // the OutlinedFunction into its callers. Return true @@ -196,17 +329,20 @@ private: // - The second value is the estimated size of the new call sequence in // basic block Cloner.OutliningCallBB; std::tuple<int, int> computeOutliningCosts(FunctionCloner &Cloner); + // Compute the 'InlineCost' of block BB. InlineCost is a proxy used to // approximate both the size and runtime cost (Note that in the current // inline cost analysis, there is no clear distinction there either). 
static int computeBBInlineCost(BasicBlock *BB); std::unique_ptr<FunctionOutliningInfo> computeOutliningInfo(Function *F); - + std::unique_ptr<FunctionOutliningMultiRegionInfo> + computeOutliningColdRegionsInfo(Function *F); }; struct PartialInlinerLegacyPass : public ModulePass { static char ID; // Pass identification, replacement for typeid + PartialInlinerLegacyPass() : ModulePass(ID) { initializePartialInlinerLegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -216,6 +352,7 @@ struct PartialInlinerLegacyPass : public ModulePass { AU.addRequired<ProfileSummaryInfoWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); } + bool runOnModule(Module &M) override { if (skipModule(M)) return false; @@ -225,6 +362,7 @@ struct PartialInlinerLegacyPass : public ModulePass { &getAnalysis<TargetTransformInfoWrapperPass>(); ProfileSummaryInfo *PSI = getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + std::unique_ptr<OptimizationRemarkEmitter> UPORE; std::function<AssumptionCache &(Function &)> GetAssumptionCache = [&ACT](Function &F) -> AssumptionCache & { @@ -236,9 +374,185 @@ struct PartialInlinerLegacyPass : public ModulePass { return TTIWP->getTTI(F); }; - return PartialInlinerImpl(&GetAssumptionCache, &GetTTI, None, PSI).run(M); + std::function<OptimizationRemarkEmitter &(Function &)> GetORE = + [&UPORE](Function &F) -> OptimizationRemarkEmitter & { + UPORE.reset(new OptimizationRemarkEmitter(&F)); + return *UPORE.get(); + }; + + return PartialInlinerImpl(&GetAssumptionCache, &GetTTI, NoneType::None, PSI, + &GetORE) + .run(M); } }; + +} // end anonymous namespace + +std::unique_ptr<FunctionOutliningMultiRegionInfo> +PartialInlinerImpl::computeOutliningColdRegionsInfo(Function *F) { + BasicBlock *EntryBlock = &F->front(); + + DominatorTree DT(*F); + LoopInfo LI(DT); + BranchProbabilityInfo BPI(*F, LI); + std::unique_ptr<BlockFrequencyInfo> ScopedBFI; + BlockFrequencyInfo *BFI; + if (!GetBFI) { + ScopedBFI.reset(new BlockFrequencyInfo(*F, BPI, LI)); + BFI = ScopedBFI.get(); + } else + BFI = &(*GetBFI)(*F); + + auto &ORE = (*GetORE)(*F); + + // Return if we don't have profiling information. + if (!PSI->hasInstrumentationProfile()) + return std::unique_ptr<FunctionOutliningMultiRegionInfo>(); + + std::unique_ptr<FunctionOutliningMultiRegionInfo> OutliningInfo = + llvm::make_unique<FunctionOutliningMultiRegionInfo>(); + + auto IsSingleEntry = [](SmallVectorImpl<BasicBlock *> &BlockList) { + BasicBlock *Dom = BlockList.front(); + return BlockList.size() > 1 && + std::distance(pred_begin(Dom), pred_end(Dom)) == 1; + }; + + auto IsSingleExit = + [&ORE](SmallVectorImpl<BasicBlock *> &BlockList) -> BasicBlock * { + BasicBlock *ExitBlock = nullptr; + for (auto *Block : BlockList) { + for (auto SI = succ_begin(Block); SI != succ_end(Block); ++SI) { + if (!is_contained(BlockList, *SI)) { + if (ExitBlock) { + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "MultiExitRegion", + &SI->front()) + << "Region dominated by " + << ore::NV("Block", BlockList.front()->getName()) + << " has more than one region exit edge."; + }); + return nullptr; + } else + ExitBlock = Block; + } + } + } + return ExitBlock; + }; + + auto BBProfileCount = [BFI](BasicBlock *BB) { + return BFI->getBlockProfileCount(BB) + ? BFI->getBlockProfileCount(BB).getValue() + : 0; + }; + + // Use the same computeBBInlineCost function to compute the cost savings of + // the outlining the candidate region. 
+ int OverallFunctionCost = 0; + for (auto &BB : *F) + OverallFunctionCost += computeBBInlineCost(&BB); + +#ifndef NDEBUG + if (TracePartialInlining) + dbgs() << "OverallFunctionCost = " << OverallFunctionCost << "\n"; +#endif + int MinOutlineRegionCost = + static_cast<int>(OverallFunctionCost * MinRegionSizeRatio); + BranchProbability MinBranchProbability( + static_cast<int>(ColdBranchRatio * MinBlockCounterExecution), + MinBlockCounterExecution); + bool ColdCandidateFound = false; + BasicBlock *CurrEntry = EntryBlock; + std::vector<BasicBlock *> DFS; + DenseMap<BasicBlock *, bool> VisitedMap; + DFS.push_back(CurrEntry); + VisitedMap[CurrEntry] = true; + // Use Depth First Search on the basic blocks to find CFG edges that are + // considered cold. + // Cold regions considered must also have its inline cost compared to the + // overall inline cost of the original function. The region is outlined only + // if it reduced the inline cost of the function by 'MinOutlineRegionCost' or + // more. + while (!DFS.empty()) { + auto *thisBB = DFS.back(); + DFS.pop_back(); + // Only consider regions with predecessor blocks that are considered + // not-cold (default: part of the top 99.99% of all block counters) + // AND greater than our minimum block execution count (default: 100). + if (PSI->isColdBB(thisBB, BFI) || + BBProfileCount(thisBB) < MinBlockCounterExecution) + continue; + for (auto SI = succ_begin(thisBB); SI != succ_end(thisBB); ++SI) { + if (VisitedMap[*SI]) + continue; + VisitedMap[*SI] = true; + DFS.push_back(*SI); + // If branch isn't cold, we skip to the next one. + BranchProbability SuccProb = BPI.getEdgeProbability(thisBB, *SI); + if (SuccProb > MinBranchProbability) + continue; +#ifndef NDEBUG + if (TracePartialInlining) { + dbgs() << "Found cold edge: " << thisBB->getName() << "->" + << (*SI)->getName() << "\nBranch Probability = " << SuccProb + << "\n"; + } +#endif + SmallVector<BasicBlock *, 8> DominateVector; + DT.getDescendants(*SI, DominateVector); + // We can only outline single entry regions (for now). + if (!IsSingleEntry(DominateVector)) + continue; + BasicBlock *ExitBlock = nullptr; + // We can only outline single exit regions (for now). + if (!(ExitBlock = IsSingleExit(DominateVector))) + continue; + int OutlineRegionCost = 0; + for (auto *BB : DominateVector) + OutlineRegionCost += computeBBInlineCost(BB); + +#ifndef NDEBUG + if (TracePartialInlining) + dbgs() << "OutlineRegionCost = " << OutlineRegionCost << "\n"; +#endif + + if (OutlineRegionCost < MinOutlineRegionCost) { + ORE.emit([&]() { + return OptimizationRemarkAnalysis(DEBUG_TYPE, "TooCostly", + &SI->front()) + << ore::NV("Callee", F) << " inline cost-savings smaller than " + << ore::NV("Cost", MinOutlineRegionCost); + }); + continue; + } + // For now, ignore blocks that belong to a SISE region that is a + // candidate for outlining. In the future, we may want to look + // at inner regions because the outer region may have live-exit + // variables. 
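The cold-region scan above keys off two thresholds built from the new flags: a region must cost at least MinRegionSizeRatio of the whole function before outlining it is considered worthwhile, and an edge only counts as cold when its branch probability is at or below ColdBranchRatio, expressed over MinBlockCounterExecution executions. A small worked example with the default values (0.1, 100, 0.1), using plain arithmetic rather than BlockFrequencyInfo/BranchProbability:

    // Worked example with the default flag values; plain arithmetic stands in
    // for BlockFrequencyInfo/BranchProbability.
    #include <cassert>

    int main() {
      const double MinRegionSizeRatio = 0.1;          // -min-region-size-ratio
      const unsigned MinBlockCounterExecution = 100;  // -min-block-execution
      const double ColdBranchRatio = 0.1;             // -cold-branch-ratio

      // Hypothetical function with a total inline cost of 400.
      int OverallFunctionCost = 400;
      int MinOutlineRegionCost =
          static_cast<int>(OverallFunctionCost * MinRegionSizeRatio);   // 40
      assert(MinOutlineRegionCost == 40);

      // An edge is cold when its probability is at most 10/100.
      unsigned ColdNumerator =
          static_cast<unsigned>(ColdBranchRatio * MinBlockCounterExecution); // 10
      assert(ColdNumerator == 10);

      // A region of cost 55 reached over an edge taken 5 times per 100
      // predecessor executions passes both tests (55 >= 40 and 5 <= 10) and
      // becomes an outlining candidate; a region of cost 30 is rejected as
      // "TooCostly" because the savings are too small relative to the function.
      return 0;
    }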
+ for (auto *BB : DominateVector) + VisitedMap[BB] = true; + // ReturnBlock here means the block after the outline call + BasicBlock *ReturnBlock = ExitBlock->getSingleSuccessor(); + // assert(ReturnBlock && "ReturnBlock is NULL somehow!"); + FunctionOutliningMultiRegionInfo::OutlineRegionInfo RegInfo( + DominateVector, DominateVector.front(), ExitBlock, ReturnBlock); + RegInfo.Region = DominateVector; + OutliningInfo->ORI.push_back(RegInfo); +#ifndef NDEBUG + if (TracePartialInlining) { + dbgs() << "Found Cold Candidate starting at block: " + << DominateVector.front()->getName() << "\n"; + } +#endif + ColdCandidateFound = true; + NumColdRegionsFound++; + } + } + if (ColdCandidateFound) + return OutliningInfo; + else + return std::unique_ptr<FunctionOutliningMultiRegionInfo>(); } std::unique_ptr<FunctionOutliningInfo> @@ -319,7 +633,6 @@ PartialInlinerImpl::computeOutliningInfo(Function *F) { OutliningInfo->Entries.push_back(CurrEntry); CurrEntry = OtherSucc; - } while (true); if (!CandidateFound) @@ -413,15 +726,19 @@ static bool hasProfileData(Function *F, FunctionOutliningInfo *OI) { BranchProbability PartialInlinerImpl::getOutliningCallBBRelativeFreq(FunctionCloner &Cloner) { - + BasicBlock *OutliningCallBB = Cloner.OutlinedFunctions.back().second; auto EntryFreq = Cloner.ClonedFuncBFI->getBlockFreq(&Cloner.ClonedFunc->getEntryBlock()); auto OutliningCallFreq = - Cloner.ClonedFuncBFI->getBlockFreq(Cloner.OutliningCallBB); - - auto OutlineRegionRelFreq = - BranchProbability::getBranchProbability(OutliningCallFreq.getFrequency(), - EntryFreq.getFrequency()); + Cloner.ClonedFuncBFI->getBlockFreq(OutliningCallBB); + // FIXME Hackery needed because ClonedFuncBFI is based on the function BEFORE + // we outlined any regions, so we may encounter situations where the + // OutliningCallFreq is *slightly* bigger than the EntryFreq. 
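getOutliningCallBBRelativeFreq, described by the FIXME above and patched in the hunk that follows, now clamps the outlining call's frequency to the entry frequency because ClonedFuncBFI was computed before any region was outlined and can overshoot slightly. A minimal arithmetic sketch of that clamp and ratio, with plain integers standing in for BlockFrequency/BranchProbability:

    // Minimal sketch of the clamped relative-frequency computation; plain
    // integers stand in for BlockFrequency/BranchProbability.
    #include <cstdint>

    inline double outlineCallRelativeFreq(uint64_t EntryFreq,
                                          uint64_t OutlineCallFreq) {
      // Stale pre-outlining BFI can report the call block as slightly hotter
      // than the entry block; cap it so the resulting probability stays <= 1.
      if (OutlineCallFreq > EntryFreq)
        OutlineCallFreq = EntryFreq;
      return EntryFreq ? static_cast<double>(OutlineCallFreq) / EntryFreq : 0.0;
    }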
+ if (OutliningCallFreq.getFrequency() > EntryFreq.getFrequency()) { + OutliningCallFreq = EntryFreq; + } + auto OutlineRegionRelFreq = BranchProbability::getBranchProbability( + OutliningCallFreq.getFrequency(), EntryFreq.getFrequency()); if (hasProfileData(Cloner.OrigFunc, Cloner.ClonedOI.get())) return OutlineRegionRelFreq; @@ -448,10 +765,10 @@ PartialInlinerImpl::getOutliningCallBBRelativeFreq(FunctionCloner &Cloner) { } bool PartialInlinerImpl::shouldPartialInline( - CallSite CS, FunctionCloner &Cloner, BlockFrequency WeightedOutliningRcost, - OptimizationRemarkEmitter &ORE) { - + CallSite CS, FunctionCloner &Cloner, + BlockFrequency WeightedOutliningRcost) { using namespace ore; + if (SkipCostAnalysis) return true; @@ -461,30 +778,37 @@ bool PartialInlinerImpl::shouldPartialInline( Function *Caller = CS.getCaller(); auto &CalleeTTI = (*GetTTI)(*Callee); + auto &ORE = (*GetORE)(*Caller); InlineCost IC = getInlineCost(CS, getInlineParams(), CalleeTTI, - *GetAssumptionCache, GetBFI, PSI); + *GetAssumptionCache, GetBFI, PSI, &ORE); if (IC.isAlways()) { - ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "AlwaysInline", Call) + ORE.emit([&]() { + return OptimizationRemarkAnalysis(DEBUG_TYPE, "AlwaysInline", Call) << NV("Callee", Cloner.OrigFunc) - << " should always be fully inlined, not partially"); + << " should always be fully inlined, not partially"; + }); return false; } if (IC.isNever()) { - ORE.emit(OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", Call) + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NeverInline", Call) << NV("Callee", Cloner.OrigFunc) << " not partially inlined into " << NV("Caller", Caller) - << " because it should never be inlined (cost=never)"); + << " because it should never be inlined (cost=never)"; + }); return false; } if (!IC) { - ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "TooCostly", Call) + ORE.emit([&]() { + return OptimizationRemarkAnalysis(DEBUG_TYPE, "TooCostly", Call) << NV("Callee", Cloner.OrigFunc) << " not partially inlined into " << NV("Caller", Caller) << " because too costly to inline (cost=" << NV("Cost", IC.getCost()) << ", threshold=" - << NV("Threshold", IC.getCostDelta() + IC.getCost()) << ")"); + << NV("Threshold", IC.getCostDelta() + IC.getCost()) << ")"; + }); return false; } const DataLayout &DL = Caller->getParent()->getDataLayout(); @@ -495,23 +819,28 @@ bool PartialInlinerImpl::shouldPartialInline( // Weighted saving is smaller than weighted cost, return false if (NormWeightedSavings < WeightedOutliningRcost) { - ORE.emit( - OptimizationRemarkAnalysis(DEBUG_TYPE, "OutliningCallcostTooHigh", Call) - << NV("Callee", Cloner.OrigFunc) << " not partially inlined into " - << NV("Caller", Caller) << " runtime overhead (overhead=" - << NV("Overhead", (unsigned)WeightedOutliningRcost.getFrequency()) - << ", savings=" - << NV("Savings", (unsigned)NormWeightedSavings.getFrequency()) << ")" - << " of making the outlined call is too high"); + ORE.emit([&]() { + return OptimizationRemarkAnalysis(DEBUG_TYPE, "OutliningCallcostTooHigh", + Call) + << NV("Callee", Cloner.OrigFunc) << " not partially inlined into " + << NV("Caller", Caller) << " runtime overhead (overhead=" + << NV("Overhead", (unsigned)WeightedOutliningRcost.getFrequency()) + << ", savings=" + << NV("Savings", (unsigned)NormWeightedSavings.getFrequency()) + << ")" + << " of making the outlined call is too high"; + }); return false; } - ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "CanBePartiallyInlined", Call) + ORE.emit([&]() { + return 
OptimizationRemarkAnalysis(DEBUG_TYPE, "CanBePartiallyInlined", Call) << NV("Callee", Cloner.OrigFunc) << " can be partially inlined into " << NV("Caller", Caller) << " with cost=" << NV("Cost", IC.getCost()) << " (threshold=" - << NV("Threshold", IC.getCostDelta() + IC.getCost()) << ")"); + << NV("Threshold", IC.getCostDelta() + IC.getCost()) << ")"; + }); return true; } @@ -566,23 +895,26 @@ int PartialInlinerImpl::computeBBInlineCost(BasicBlock *BB) { std::tuple<int, int> PartialInlinerImpl::computeOutliningCosts(FunctionCloner &Cloner) { - - // Now compute the cost of the call sequence to the outlined function - // 'OutlinedFunction' in BB 'OutliningCallBB': - int OutliningFuncCallCost = computeBBInlineCost(Cloner.OutliningCallBB); - - // Now compute the cost of the extracted/outlined function itself: - int OutlinedFunctionCost = 0; - for (BasicBlock &BB : *Cloner.OutlinedFunc) { - OutlinedFunctionCost += computeBBInlineCost(&BB); + int OutliningFuncCallCost = 0, OutlinedFunctionCost = 0; + for (auto FuncBBPair : Cloner.OutlinedFunctions) { + Function *OutlinedFunc = FuncBBPair.first; + BasicBlock* OutliningCallBB = FuncBBPair.second; + // Now compute the cost of the call sequence to the outlined function + // 'OutlinedFunction' in BB 'OutliningCallBB': + OutliningFuncCallCost += computeBBInlineCost(OutliningCallBB); + + // Now compute the cost of the extracted/outlined function itself: + for (BasicBlock &BB : *OutlinedFunc) + OutlinedFunctionCost += computeBBInlineCost(&BB); } - assert(OutlinedFunctionCost >= Cloner.OutlinedRegionCost && "Outlined function cost should be no less than the outlined region"); + // The code extractor introduces a new root and exit stub blocks with // additional unconditional branches. Those branches will be eliminated // later with bb layout. The cost should be adjusted accordingly: - OutlinedFunctionCost -= 2 * InlineConstants::InstrCost; + OutlinedFunctionCost -= + 2 * InlineConstants::InstrCost * Cloner.OutlinedFunctions.size(); int OutliningRuntimeOverhead = OutliningFuncCallCost + @@ -636,9 +968,9 @@ void PartialInlinerImpl::computeCallsiteToProfCountMap( } } -PartialInlinerImpl::FunctionCloner::FunctionCloner(Function *F, - FunctionOutliningInfo *OI) - : OrigFunc(F) { +PartialInlinerImpl::FunctionCloner::FunctionCloner( + Function *F, FunctionOutliningInfo *OI, OptimizationRemarkEmitter &ORE) + : OrigFunc(F), ORE(ORE) { ClonedOI = llvm::make_unique<FunctionOutliningInfo>(); // Clone the function, so that we can hack away on it. @@ -659,8 +991,39 @@ PartialInlinerImpl::FunctionCloner::FunctionCloner(Function *F, F->replaceAllUsesWith(ClonedFunc); } -void PartialInlinerImpl::FunctionCloner::NormalizeReturnBlock() { +PartialInlinerImpl::FunctionCloner::FunctionCloner( + Function *F, FunctionOutliningMultiRegionInfo *OI, + OptimizationRemarkEmitter &ORE) + : OrigFunc(F), ORE(ORE) { + ClonedOMRI = llvm::make_unique<FunctionOutliningMultiRegionInfo>(); + + // Clone the function, so that we can hack away on it. + ValueToValueMapTy VMap; + ClonedFunc = CloneFunction(F, VMap); + // Go through all Outline Candidate Regions and update all BasicBlock + // information. 
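// Illustrative sketch, not part of the patch: the pointer remapping that the
// loop below performs with the CloneFunction value map, reduced to a plain
// std::unordered_map over a toy Block type. All type and variable names here
// are invented for the example.
#include <cassert>
#include <unordered_map>
#include <vector>

struct Block { const char *Name; };

struct RegionInfo {
  std::vector<Block *> Region;
  Block *Entry = nullptr;
  Block *Exit = nullptr;
  Block *Return = nullptr; // may legitimately stay null, as below
};

// Rewrites every block pointer recorded against the original function so that
// it refers to the corresponding block of the clone.
static RegionInfo remapRegion(const RegionInfo &Old,
                              const std::unordered_map<Block *, Block *> &VMap) {
  RegionInfo New;
  for (Block *BB : Old.Region)
    New.Region.push_back(VMap.at(BB));
  New.Entry = VMap.at(Old.Entry);
  New.Exit = VMap.at(Old.Exit);
  New.Return = Old.Return ? VMap.at(Old.Return) : nullptr;
  return New;
}

int main() {
  Block A{"entry"}, B{"exit"}, A2{"entry.clone"}, B2{"exit.clone"};
  std::unordered_map<Block *, Block *> VMap{{&A, &A2}, {&B, &B2}};
  RegionInfo Old;
  Old.Region = {&A, &B};
  Old.Entry = &A;
  Old.Exit = &B;
  RegionInfo New = remapRegion(Old, VMap);
  assert(New.Entry == &A2 && New.Exit == &B2 && New.Return == nullptr);
  return 0;
}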
+ for (FunctionOutliningMultiRegionInfo::OutlineRegionInfo RegionInfo : + OI->ORI) { + SmallVector<BasicBlock *, 8> Region; + for (BasicBlock *BB : RegionInfo.Region) { + Region.push_back(cast<BasicBlock>(VMap[BB])); + } + BasicBlock *NewEntryBlock = cast<BasicBlock>(VMap[RegionInfo.EntryBlock]); + BasicBlock *NewExitBlock = cast<BasicBlock>(VMap[RegionInfo.ExitBlock]); + BasicBlock *NewReturnBlock = nullptr; + if (RegionInfo.ReturnBlock) + NewReturnBlock = cast<BasicBlock>(VMap[RegionInfo.ReturnBlock]); + FunctionOutliningMultiRegionInfo::OutlineRegionInfo MappedRegionInfo( + Region, NewEntryBlock, NewExitBlock, NewReturnBlock); + ClonedOMRI->ORI.push_back(MappedRegionInfo); + } + // Go ahead and update all uses to the duplicate, so that we can just + // use the inliner functionality when we're done hacking. + F->replaceAllUsesWith(ClonedFunc); +} + +void PartialInlinerImpl::FunctionCloner::NormalizeReturnBlock() { auto getFirstPHI = [](BasicBlock *BB) { BasicBlock::iterator I = BB->begin(); PHINode *FirstPhi = nullptr; @@ -676,6 +1039,11 @@ void PartialInlinerImpl::FunctionCloner::NormalizeReturnBlock() { return FirstPhi; }; + // Shouldn't need to normalize PHIs if we're not outlining non-early return + // blocks. + if (!ClonedOI) + return; + // Special hackery is needed with PHI nodes that have inputs from more than // one extracted block. For simplicity, just split the PHIs into a two-level // sequence of PHIs, some of which will go in the extracted region, and some @@ -726,16 +1094,90 @@ void PartialInlinerImpl::FunctionCloner::NormalizeReturnBlock() { DeadPhis.push_back(OldPhi); } ++I; - } - for (auto *DP : DeadPhis) - DP->eraseFromParent(); + } + for (auto *DP : DeadPhis) + DP->eraseFromParent(); + + for (auto E : ClonedOI->ReturnBlockPreds) { + E->getTerminator()->replaceUsesOfWith(PreReturn, ClonedOI->ReturnBlock); + } +} + +bool PartialInlinerImpl::FunctionCloner::doMultiRegionFunctionOutlining() { + + auto ComputeRegionCost = [](SmallVectorImpl<BasicBlock *> &Region) { + int Cost = 0; + for (BasicBlock* BB : Region) + Cost += computeBBInlineCost(BB); + return Cost; + }; + + assert(ClonedOMRI && "Expecting OutlineInfo for multi region outline"); + + if (ClonedOMRI->ORI.empty()) + return false; + + // The CodeExtractor needs a dominator tree. + DominatorTree DT; + DT.recalculate(*ClonedFunc); + + // Manually calculate a BlockFrequencyInfo and BranchProbabilityInfo. + LoopInfo LI(DT); + BranchProbabilityInfo BPI(*ClonedFunc, LI); + ClonedFuncBFI.reset(new BlockFrequencyInfo(*ClonedFunc, BPI, LI)); - for (auto E : ClonedOI->ReturnBlockPreds) { - E->getTerminator()->replaceUsesOfWith(PreReturn, ClonedOI->ReturnBlock); + SetVector<Value *> Inputs, Outputs, Sinks; + for (FunctionOutliningMultiRegionInfo::OutlineRegionInfo RegionInfo : + ClonedOMRI->ORI) { + int CurrentOutlinedRegionCost = ComputeRegionCost(RegionInfo.Region); + + CodeExtractor CE(RegionInfo.Region, &DT, /*AggregateArgs*/ false, + ClonedFuncBFI.get(), &BPI, /* AllowVarargs */ false); + + CE.findInputsOutputs(Inputs, Outputs, Sinks); + +#ifndef NDEBUG + if (TracePartialInlining) { + dbgs() << "inputs: " << Inputs.size() << "\n"; + dbgs() << "outputs: " << Outputs.size() << "\n"; + for (Value *value : Inputs) + dbgs() << "value used in func: " << *value << "\n"; + for (Value *output : Outputs) + dbgs() << "instr used in func: " << *output << "\n"; } +#endif + // Do not extract regions that have live exit variables. 
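// Illustrative sketch, not part of the patch: what a non-empty "Outputs" set
// means for the check that follows. A value defined inside the candidate
// region but used after it is a live-exit value; extracting such a region
// would force the outlined function to hand that value back, so the region is
// skipped unless the ForceLiveExit flag is set. Inst, Def and Uses below are
// invented stand-ins for real IR instructions.
#include <iostream>
#include <set>
#include <string>
#include <vector>

struct Inst {
  std::string Def;               // name this instruction defines
  std::vector<std::string> Uses; // names it reads
};

// Counts region-defined values still used by instructions after the region,
// i.e. the analogue of the Outputs set computed by findInputsOutputs.
static int countLiveExitValues(const std::vector<Inst> &Region,
                               const std::vector<Inst> &AfterRegion) {
  std::set<std::string> DefinedInRegion;
  for (const Inst &I : Region)
    DefinedInRegion.insert(I.Def);
  std::set<std::string> LiveOut;
  for (const Inst &I : AfterRegion)
    for (const std::string &U : I.Uses)
      if (DefinedInRegion.count(U))
        LiveOut.insert(U);
  return static_cast<int>(LiveOut.size());
}

int main() {
  std::vector<Inst> Region{{"x", {}}, {"y", {"x"}}};
  std::vector<Inst> After{{"z", {"y"}}}; // 'y' escapes the region
  std::cout << countLiveExitValues(Region, After) << "\n"; // 1 -> skip region
  return 0;
}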
+ if (Outputs.size() > 0 && !ForceLiveExit) + continue; + + Function *OutlinedFunc = CE.extractCodeRegion(); + + if (OutlinedFunc) { + CallSite OCS = PartialInlinerImpl::getOneCallSiteTo(OutlinedFunc); + BasicBlock *OutliningCallBB = OCS.getInstruction()->getParent(); + assert(OutliningCallBB->getParent() == ClonedFunc); + OutlinedFunctions.push_back(std::make_pair(OutlinedFunc,OutliningCallBB)); + NumColdRegionsOutlined++; + OutlinedRegionCost += CurrentOutlinedRegionCost; + + if (MarkOutlinedColdCC) { + OutlinedFunc->setCallingConv(CallingConv::Cold); + OCS.setCallingConv(CallingConv::Cold); + } + } else + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "ExtractFailed", + &RegionInfo.Region.front()->front()) + << "Failed to extract region at block " + << ore::NV("Block", RegionInfo.Region.front()); + }); + } + + return !OutlinedFunctions.empty(); } -Function *PartialInlinerImpl::FunctionCloner::doFunctionOutlining() { +Function * +PartialInlinerImpl::FunctionCloner::doSingleRegionFunctionOutlining() { // Returns true if the block is to be partial inlined into the caller // (i.e. not to be extracted to the out of line function) auto ToBeInlined = [&, this](BasicBlock *BB) { @@ -744,6 +1186,16 @@ Function *PartialInlinerImpl::FunctionCloner::doFunctionOutlining() { ClonedOI->Entries.end()); }; + assert(ClonedOI && "Expecting OutlineInfo for single region outline"); + // The CodeExtractor needs a dominator tree. + DominatorTree DT; + DT.recalculate(*ClonedFunc); + + // Manually calculate a BlockFrequencyInfo and BranchProbabilityInfo. + LoopInfo LI(DT); + BranchProbabilityInfo BPI(*ClonedFunc, LI); + ClonedFuncBFI.reset(new BlockFrequencyInfo(*ClonedFunc, BPI, LI)); + // Gather up the blocks that we're going to extract. std::vector<BasicBlock *> ToExtract; ToExtract.push_back(ClonedOI->NonReturnBlock); @@ -759,26 +1211,27 @@ Function *PartialInlinerImpl::FunctionCloner::doFunctionOutlining() { OutlinedRegionCost += computeBBInlineCost(&BB); } - // The CodeExtractor needs a dominator tree. - DominatorTree DT; - DT.recalculate(*ClonedFunc); - - // Manually calculate a BlockFrequencyInfo and BranchProbabilityInfo. - LoopInfo LI(DT); - BranchProbabilityInfo BPI(*ClonedFunc, LI); - ClonedFuncBFI.reset(new BlockFrequencyInfo(*ClonedFunc, BPI, LI)); - // Extract the body of the if. 
- OutlinedFunc = CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, - ClonedFuncBFI.get(), &BPI) - .extractCodeRegion(); + Function *OutlinedFunc = + CodeExtractor(ToExtract, &DT, /*AggregateArgs*/ false, + ClonedFuncBFI.get(), &BPI, + /* AllowVarargs */ true) + .extractCodeRegion(); if (OutlinedFunc) { - OutliningCallBB = PartialInlinerImpl::getOneCallSiteTo(OutlinedFunc) - .getInstruction() - ->getParent(); + BasicBlock *OutliningCallBB = + PartialInlinerImpl::getOneCallSiteTo(OutlinedFunc) + .getInstruction() + ->getParent(); assert(OutliningCallBB->getParent() == ClonedFunc); - } + OutlinedFunctions.push_back(std::make_pair(OutlinedFunc, OutliningCallBB)); + } else + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "ExtractFailed", + &ToExtract.front()->front()) + << "Failed to extract region at block " + << ore::NV("Block", ToExtract.front()); + }); return OutlinedFunc; } @@ -789,76 +1242,133 @@ PartialInlinerImpl::FunctionCloner::~FunctionCloner() { ClonedFunc->replaceAllUsesWith(OrigFunc); ClonedFunc->eraseFromParent(); if (!IsFunctionInlined) { - // Remove the function that is speculatively created if there is no + // Remove each function that was speculatively created if there is no // reference. - if (OutlinedFunc) - OutlinedFunc->eraseFromParent(); + for (auto FuncBBPair : OutlinedFunctions) { + Function *Func = FuncBBPair.first; + Func->eraseFromParent(); + } } } -Function *PartialInlinerImpl::unswitchFunction(Function *F) { +std::pair<bool, Function *> PartialInlinerImpl::unswitchFunction(Function *F) { if (F->hasAddressTaken()) - return nullptr; + return {false, nullptr}; // Let inliner handle it if (F->hasFnAttribute(Attribute::AlwaysInline)) - return nullptr; + return {false, nullptr}; if (F->hasFnAttribute(Attribute::NoInline)) - return nullptr; + return {false, nullptr}; if (PSI->isFunctionEntryCold(F)) - return nullptr; + return {false, nullptr}; if (F->user_begin() == F->user_end()) - return nullptr; + return {false, nullptr}; + + auto &ORE = (*GetORE)(*F); + + // Only try to outline cold regions if we have a profile summary, which + // implies we have profiling information. + if (PSI->hasProfileSummary() && F->getEntryCount().hasValue() && + !DisableMultiRegionPartialInline) { + std::unique_ptr<FunctionOutliningMultiRegionInfo> OMRI = + computeOutliningColdRegionsInfo(F); + if (OMRI) { + FunctionCloner Cloner(F, OMRI.get(), ORE); + +#ifndef NDEBUG + if (TracePartialInlining) { + dbgs() << "HotCountThreshold = " << PSI->getHotCountThreshold() << "\n"; + dbgs() << "ColdCountThreshold = " << PSI->getColdCountThreshold() + << "\n"; + } +#endif + bool DidOutline = Cloner.doMultiRegionFunctionOutlining(); + + if (DidOutline) { +#ifndef NDEBUG + if (TracePartialInlining) { + dbgs() << ">>>>>> Outlined (Cloned) Function >>>>>>\n"; + Cloner.ClonedFunc->print(dbgs()); + dbgs() << "<<<<<< Outlined (Cloned) Function <<<<<<\n"; + } +#endif - std::unique_ptr<FunctionOutliningInfo> OI = computeOutliningInfo(F); + if (tryPartialInline(Cloner)) + return {true, nullptr}; + } + } + } + // Fall-thru to regular partial inlining if we: + // i) can't find any cold regions to outline, or + // ii) can't inline the outlined function anywhere. 
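// Illustrative sketch, not part of the patch: the order of decisions that
// unswitchFunction works through above, with the profile queries stubbed out
// as booleans. The struct and function names are invented for the example.
#include <iostream>
#include <string>

struct ProfileState {
  bool HasProfileSummary;   // PSI->hasProfileSummary()
  bool HasEntryCount;       // F->getEntryCount().hasValue()
  bool MultiRegionDisabled; // DisableMultiRegionPartialInline
};

// Names the strategy that would be attempted first for a function.
static std::string chooseOutliningStrategy(const ProfileState &P,
                                           bool FoundColdRegions) {
  if (P.HasProfileSummary && P.HasEntryCount && !P.MultiRegionDisabled &&
      FoundColdRegions)
    return "multi-region cold outlining";
  // Falls through when there is no usable profile, no cold region was found,
  // or the outlined cold code could not be inlined anywhere.
  return "single-region early-return outlining";
}

int main() {
  std::cout << chooseOutliningStrategy({true, true, false}, true) << "\n";
  std::cout << chooseOutliningStrategy({false, false, false}, false) << "\n";
  return 0;
}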
+ std::unique_ptr<FunctionOutliningInfo> OI = computeOutliningInfo(F); if (!OI) - return nullptr; + return {false, nullptr}; - FunctionCloner Cloner(F, OI.get()); + FunctionCloner Cloner(F, OI.get(), ORE); Cloner.NormalizeReturnBlock(); - Function *OutlinedFunction = Cloner.doFunctionOutlining(); + + Function *OutlinedFunction = Cloner.doSingleRegionFunctionOutlining(); + + if (!OutlinedFunction) + return {false, nullptr}; bool AnyInline = tryPartialInline(Cloner); if (AnyInline) - return OutlinedFunction; + return {true, OutlinedFunction}; - return nullptr; + return {false, nullptr}; } bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { - int NonWeightedRcost; - int SizeCost; - - if (Cloner.OutlinedFunc == nullptr) + if (Cloner.OutlinedFunctions.empty()) return false; + int SizeCost = 0; + BlockFrequency WeightedRcost; + int NonWeightedRcost; std::tie(SizeCost, NonWeightedRcost) = computeOutliningCosts(Cloner); - auto RelativeToEntryFreq = getOutliningCallBBRelativeFreq(Cloner); - auto WeightedRcost = BlockFrequency(NonWeightedRcost) * RelativeToEntryFreq; - - // The call sequence to the outlined function is larger than the original - // outlined region size, it does not increase the chances of inlining - // the function with outlining (The inliner usies the size increase to + // Only calculate RelativeToEntryFreq when we are doing single region + // outlining. + BranchProbability RelativeToEntryFreq; + if (Cloner.ClonedOI) { + RelativeToEntryFreq = getOutliningCallBBRelativeFreq(Cloner); + } else + // RelativeToEntryFreq doesn't make sense when we have more than one + // outlined call because each call will have a different relative frequency + // to the entry block. We can consider using the average, but the + // usefulness of that information is questionable. For now, assume we never + // execute the calls to outlined functions. + RelativeToEntryFreq = BranchProbability(0, 1); + + WeightedRcost = BlockFrequency(NonWeightedRcost) * RelativeToEntryFreq; + + // The call sequence(s) to the outlined function(s) are larger than the sum of + // the original outlined region size(s), it does not increase the chances of + // inlining the function with outlining (The inliner uses the size increase to // model the cost of inlining a callee). 
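// Illustrative sketch, not part of the patch: the two checks made around this
// point, with BlockFrequency reduced to a double. The runtime cost of calling
// the outlined code is weighted by how often the outlining call block runs
// relative to entry (taken as 0 for multi-region outlining, per the comment
// above), and partial inlining is rejected when the code removed from the
// caller is smaller than the call sequence that replaces it.
#include <iostream>

static double weightedOutliningCost(int NonWeightedRcost,
                                    double RelativeToEntryFreq) {
  return NonWeightedRcost * RelativeToEntryFreq;
}

static bool outlinedRegionLargeEnough(int OutlinedRegionCost, int SizeCost) {
  return OutlinedRegionCost >= SizeCost;
}

int main() {
  std::cout << weightedOutliningCost(40, 0.25) << "\n";   // 10
  std::cout << weightedOutliningCost(40, 0.0) << "\n";    // multi-region: 0
  std::cout << outlinedRegionLargeEnough(12, 20) << "\n"; // 0 -> too small, bail
  return 0;
}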
if (!SkipCostAnalysis && Cloner.OutlinedRegionCost < SizeCost) { - OptimizationRemarkEmitter ORE(Cloner.OrigFunc); + auto &ORE = (*GetORE)(*Cloner.OrigFunc); DebugLoc DLoc; BasicBlock *Block; std::tie(DLoc, Block) = getOneDebugLoc(Cloner.ClonedFunc); - ORE.emit(OptimizationRemarkAnalysis(DEBUG_TYPE, "OutlineRegionTooSmall", + ORE.emit([&]() { + return OptimizationRemarkAnalysis(DEBUG_TYPE, "OutlineRegionTooSmall", DLoc, Block) << ore::NV("Function", Cloner.OrigFunc) << " not partially inlined into callers (Original Size = " << ore::NV("OutlinedRegionOriginalSize", Cloner.OutlinedRegionCost) << ", Size of call sequence to outlined function = " - << ore::NV("NewSize", SizeCost) << ")"); + << ore::NV("NewSize", SizeCost) << ")"; + }); return false; } @@ -882,18 +1392,26 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { if (IsLimitReached()) continue; - OptimizationRemarkEmitter ORE(CS.getCaller()); - if (!shouldPartialInline(CS, Cloner, WeightedRcost, ORE)) + if (!shouldPartialInline(CS, Cloner, WeightedRcost)) continue; - ORE.emit( - OptimizationRemark(DEBUG_TYPE, "PartiallyInlined", CS.getInstruction()) - << ore::NV("Callee", Cloner.OrigFunc) << " partially inlined into " - << ore::NV("Caller", CS.getCaller())); + auto &ORE = (*GetORE)(*CS.getCaller()); + // Construct remark before doing the inlining, as after successful inlining + // the callsite is removed. + OptimizationRemark OR(DEBUG_TYPE, "PartiallyInlined", CS.getInstruction()); + OR << ore::NV("Callee", Cloner.OrigFunc) << " partially inlined into " + << ore::NV("Caller", CS.getCaller()); InlineFunctionInfo IFI(nullptr, GetAssumptionCache, PSI); - InlineFunction(CS, IFI); + // We can only forward varargs when we outlined a single region, else we + // bail on vararg functions. + if (!InlineFunction(CS, IFI, nullptr, true, + (Cloner.ClonedOI ? 
Cloner.OutlinedFunctions.back().first + : nullptr))) + continue; + + ORE.emit(OR); // Now update the entry count: if (CalleeEntryCountV && CallSiteToProfCountMap.count(User)) { @@ -904,13 +1422,23 @@ bool PartialInlinerImpl::tryPartialInline(FunctionCloner &Cloner) { AnyInline = true; NumPartialInlining++; // Update the stats - NumPartialInlined++; + if (Cloner.ClonedOI) + NumPartialInlined++; + else + NumColdOutlinePartialInlined++; + } if (AnyInline) { Cloner.IsFunctionInlined = true; if (CalleeEntryCount) Cloner.OrigFunc->setEntryCount(CalleeEntryCountV); + auto &ORE = (*GetORE)(*Cloner.OrigFunc); + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "PartiallyInlined", Cloner.OrigFunc) + << "Partially inlined into at least one caller"; + }); + } return AnyInline; @@ -944,8 +1472,10 @@ bool PartialInlinerImpl::run(Module &M) { if (Recursive) continue; - if (Function *NewFunc = unswitchFunction(CurrFunc)) { - Worklist.push_back(NewFunc); + std::pair<bool, Function * > Result = unswitchFunction(CurrFunc); + if (Result.second) + Worklist.push_back(Result.second); + if (Result.first) { Changed = true; } } @@ -954,6 +1484,7 @@ bool PartialInlinerImpl::run(Module &M) { } char PartialInlinerLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(PartialInlinerLegacyPass, "partial-inliner", "Partial Inliner", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) @@ -985,9 +1516,15 @@ PreservedAnalyses PartialInlinerPass::run(Module &M, return FAM.getResult<TargetIRAnalysis>(F); }; + std::function<OptimizationRemarkEmitter &(Function &)> GetORE = + [&FAM](Function &F) -> OptimizationRemarkEmitter & { + return FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); + }; + ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); - if (PartialInlinerImpl(&GetAssumptionCache, &GetTTI, {GetBFI}, PSI).run(M)) + if (PartialInlinerImpl(&GetAssumptionCache, &GetTTI, {GetBFI}, PSI, &GetORE) + .run(M)) return PreservedAnalyses::none(); return PreservedAnalyses::all(); } diff --git a/lib/Transforms/IPO/PassManagerBuilder.cpp b/lib/Transforms/IPO/PassManagerBuilder.cpp index 0b319f6a488b..3855e6245d8e 100644 --- a/lib/Transforms/IPO/PassManagerBuilder.cpp +++ b/lib/Transforms/IPO/PassManagerBuilder.cpp @@ -26,11 +26,9 @@ #include "llvm/Analysis/TypeBasedAliasAnalysis.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/ModuleSummaryIndex.h" #include "llvm/IR/Verifier.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/ManagedStatic.h" -#include "llvm/Target/TargetMachine.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/IPO/ForceFunctionAttrs.h" #include "llvm/Transforms/IPO/FunctionAttrs.h" @@ -94,15 +92,6 @@ static cl::opt<bool> EnableLoopInterchange( "enable-loopinterchange", cl::init(false), cl::Hidden, cl::desc("Enable the new, experimental LoopInterchange Pass")); -static cl::opt<bool> EnableNonLTOGlobalsModRef( - "enable-non-lto-gmr", cl::init(true), cl::Hidden, - cl::desc( - "Enable the GlobalsModRef AliasAnalysis outside of the LTO pipeline.")); - -static cl::opt<bool> EnableLoopLoadElim( - "enable-loop-load-elim", cl::init(true), cl::Hidden, - cl::desc("Enable the LoopLoadElimination Pass")); - static cl::opt<bool> EnablePrepareForThinLTO("prepare-for-thinlto", cl::init(false), cl::Hidden, cl::desc("Enable preparation for ThinLTO.")); @@ -160,7 +149,6 @@ PassManagerBuilder::PassManagerBuilder() { SizeLevel = 0; LibraryInfo = nullptr; Inliner = nullptr; - DisableUnitAtATime = false; DisableUnrollLoops = false; SLPVectorize = 
RunSLPVectorization; LoopVectorize = RunLoopVectorization; @@ -251,6 +239,7 @@ void PassManagerBuilder::addInstructionCombiningPass( void PassManagerBuilder::populateFunctionPassManager( legacy::FunctionPassManager &FPM) { addExtensionsToPM(EP_EarlyAsPossible, FPM); + FPM.add(createEntryExitInstrumenterPass()); // Add LibraryInfo if we have some. if (LibraryInfo) @@ -428,6 +417,14 @@ void PassManagerBuilder::populateModulePassManager( else if (GlobalExtensionsNotEmpty() || !Extensions.empty()) MPM.add(createBarrierNoopPass()); + if (PerformThinLTO) { + // Drop available_externally and unreferenced globals. This is necessary + // with ThinLTO in order to avoid leaving undefined references to dead + // globals in the object file. + MPM.add(createEliminateAvailableExternallyPass()); + MPM.add(createGlobalDCEPass()); + } + addExtensionsToPM(EP_EnabledOnOptLevel0, MPM); // Rename anon globals to be able to export them in the summary. @@ -464,23 +461,25 @@ void PassManagerBuilder::populateModulePassManager( if (PrepareForThinLTOUsingPGOSampleProfile) DisableUnrollLoops = true; - if (!DisableUnitAtATime) { - // Infer attributes about declarations if possible. - MPM.add(createInferFunctionAttrsLegacyPass()); + // Infer attributes about declarations if possible. + MPM.add(createInferFunctionAttrsLegacyPass()); - addExtensionsToPM(EP_ModuleOptimizerEarly, MPM); + addExtensionsToPM(EP_ModuleOptimizerEarly, MPM); - MPM.add(createIPSCCPPass()); // IP SCCP - MPM.add(createGlobalOptimizerPass()); // Optimize out global vars - // Promote any localized global vars. - MPM.add(createPromoteMemoryToRegisterPass()); + if (OptLevel > 2) + MPM.add(createCallSiteSplittingPass()); - MPM.add(createDeadArgEliminationPass()); // Dead argument elimination + MPM.add(createIPSCCPPass()); // IP SCCP + MPM.add(createCalledValuePropagationPass()); + MPM.add(createGlobalOptimizerPass()); // Optimize out global vars + // Promote any localized global vars. + MPM.add(createPromoteMemoryToRegisterPass()); - addInstructionCombiningPass(MPM); // Clean up after IPCP & DAE - addExtensionsToPM(EP_Peephole, MPM); - MPM.add(createCFGSimplificationPass()); // Clean up after IPCP & DAE - } + MPM.add(createDeadArgEliminationPass()); // Dead argument elimination + + addInstructionCombiningPass(MPM); // Clean up after IPCP & DAE + addExtensionsToPM(EP_Peephole, MPM); + MPM.add(createCFGSimplificationPass()); // Clean up after IPCP & DAE // For SamplePGO in ThinLTO compile phase, we do not want to do indirect // call promotion as it will change the CFG too much to make the 2nd @@ -490,21 +489,21 @@ void PassManagerBuilder::populateModulePassManager( if (!PerformThinLTO && !PrepareForThinLTOUsingPGOSampleProfile) addPGOInstrPasses(MPM); - if (EnableNonLTOGlobalsModRef) - // We add a module alias analysis pass here. In part due to bugs in the - // analysis infrastructure this "works" in that the analysis stays alive - // for the entire SCC pass run below. - MPM.add(createGlobalsAAWrapperPass()); + // We add a module alias analysis pass here. In part due to bugs in the + // analysis infrastructure this "works" in that the analysis stays alive + // for the entire SCC pass run below. + MPM.add(createGlobalsAAWrapperPass()); // Start of CallGraph SCC passes. 
- if (!DisableUnitAtATime) - MPM.add(createPruneEHPass()); // Remove dead EH info + MPM.add(createPruneEHPass()); // Remove dead EH info + bool RunInliner = false; if (Inliner) { MPM.add(Inliner); Inliner = nullptr; + RunInliner = true; } - if (!DisableUnitAtATime) - MPM.add(createPostOrderFunctionAttrsLegacyPass()); + + MPM.add(createPostOrderFunctionAttrsLegacyPass()); if (OptLevel > 2) MPM.add(createArgumentPromotionPass()); // Scalarize uninlined fn args @@ -515,11 +514,11 @@ void PassManagerBuilder::populateModulePassManager( // pass manager that we are specifically trying to avoid. To prevent this // we must insert a no-op module pass to reset the pass manager. MPM.add(createBarrierNoopPass()); + if (RunPartialInlining) MPM.add(createPartialInliningPass()); - if (!DisableUnitAtATime && OptLevel > 1 && !PrepareForLTO && - !PrepareForThinLTO) + if (OptLevel > 1 && !PrepareForLTO && !PrepareForThinLTO) // Remove avail extern fns and globals definitions if we aren't // compiling an object file for later LTO. For LTO we want to preserve // these so they are eligible for inlining at link-time. Note if they @@ -531,15 +530,26 @@ void PassManagerBuilder::populateModulePassManager( // and saves running remaining passes on the eliminated functions. MPM.add(createEliminateAvailableExternallyPass()); - if (!DisableUnitAtATime) - MPM.add(createReversePostOrderFunctionAttrsPass()); + MPM.add(createReversePostOrderFunctionAttrsPass()); + + // The inliner performs some kind of dead code elimination as it goes, + // but there are cases that are not really caught by it. We might + // at some point consider teaching the inliner about them, but it + // is OK for now to run GlobalOpt + GlobalDCE in tandem as their + // benefits generally outweight the cost, making the whole pipeline + // faster. + if (RunInliner) { + MPM.add(createGlobalOptimizerPass()); + MPM.add(createGlobalDCEPass()); + } // If we are planning to perform ThinLTO later, let's not bloat the code with // unrolling/vectorization/... now. We'll first run the inliner + CGSCC passes // during ThinLTO and perform the rest of the optimizations afterward. if (PrepareForThinLTO) { - // Reduce the size of the IR as much as possible. - MPM.add(createGlobalOptimizerPass()); + // Ensure we perform any last passes, but do so before renaming anonymous + // globals in case the passes add any. + addExtensionsToPM(EP_OptimizerLast, MPM); // Rename anon globals to be able to export them in the summary. MPM.add(createNameAnonGlobalPass()); return; @@ -560,23 +570,22 @@ void PassManagerBuilder::populateModulePassManager( MPM.add(createLICMPass()); // Hoist loop invariants } - if (EnableNonLTOGlobalsModRef) - // We add a fresh GlobalsModRef run at this point. This is particularly - // useful as the above will have inlined, DCE'ed, and function-attr - // propagated everything. We should at this point have a reasonably minimal - // and richly annotated call graph. By computing aliasing and mod/ref - // information for all local globals here, the late loop passes and notably - // the vectorizer will be able to use them to help recognize vectorizable - // memory operations. - // - // Note that this relies on a bug in the pass manager which preserves - // a module analysis into a function pass pipeline (and throughout it) so - // long as the first function pass doesn't invalidate the module analysis. - // Thus both Float2Int and LoopRotate have to preserve AliasAnalysis for - // this to work. 
Fortunately, it is trivial to preserve AliasAnalysis - // (doing nothing preserves it as it is required to be conservatively - // correct in the face of IR changes). - MPM.add(createGlobalsAAWrapperPass()); + // We add a fresh GlobalsModRef run at this point. This is particularly + // useful as the above will have inlined, DCE'ed, and function-attr + // propagated everything. We should at this point have a reasonably minimal + // and richly annotated call graph. By computing aliasing and mod/ref + // information for all local globals here, the late loop passes and notably + // the vectorizer will be able to use them to help recognize vectorizable + // memory operations. + // + // Note that this relies on a bug in the pass manager which preserves + // a module analysis into a function pass pipeline (and throughout it) so + // long as the first function pass doesn't invalidate the module analysis. + // Thus both Float2Int and LoopRotate have to preserve AliasAnalysis for + // this to work. Fortunately, it is trivial to preserve AliasAnalysis + // (doing nothing preserves it as it is required to be conservatively + // correct in the face of IR changes). + MPM.add(createGlobalsAAWrapperPass()); MPM.add(createFloat2IntPass()); @@ -597,8 +606,7 @@ void PassManagerBuilder::populateModulePassManager( // Eliminate loads by forwarding stores from the previous iteration to loads // of the current iteration. - if (EnableLoopLoadElim) - MPM.add(createLoopLoadEliminationPass()); + MPM.add(createLoopLoadEliminationPass()); // FIXME: Because of #pragma vectorize enable, the passes below are always // inserted in the pipeline, even when the vectorizer doesn't run (ex. when @@ -622,6 +630,13 @@ void PassManagerBuilder::populateModulePassManager( addInstructionCombiningPass(MPM); } + // Cleanup after loop vectorization, etc. Simplification passes like CVP and + // GVN, loop transforms, and others have already run, so it's now better to + // convert to more optimized IR using more aggressive simplify CFG options. + // The extra sinking transform can create larger basic blocks, so do this + // before SLP vectorization. + MPM.add(createCFGSimplificationPass(1, true, true, false, true)); + if (RunSLPAfterLoopVectorization && SLPVectorize) { MPM.add(createSLPVectorizerPass()); // Vectorize parallel scalar chains. if (OptLevel > 1 && ExtraVectorizerPasses) { @@ -630,7 +645,6 @@ void PassManagerBuilder::populateModulePassManager( } addExtensionsToPM(EP_Peephole, MPM); - MPM.add(createLateCFGSimplificationPass()); // Switches to lookup tables addInstructionCombiningPass(MPM); if (!DisableUnrollLoops) { @@ -650,16 +664,14 @@ void PassManagerBuilder::populateModulePassManager( // about pointer alignments. MPM.add(createAlignmentFromAssumptionsPass()); - if (!DisableUnitAtATime) { - // FIXME: We shouldn't bother with this anymore. - MPM.add(createStripDeadPrototypesPass()); // Get rid of dead prototypes + // FIXME: We shouldn't bother with this anymore. + MPM.add(createStripDeadPrototypesPass()); // Get rid of dead prototypes - // GlobalOpt already deletes dead functions and globals, at -O2 try a - // late pass of GlobalDCE. It is capable of deleting dead cycles. - if (OptLevel > 1) { - MPM.add(createGlobalDCEPass()); // Remove dead fns and globals. - MPM.add(createConstantMergePass()); // Merge dup global constants - } + // GlobalOpt already deletes dead functions and globals, at -O2 try a + // late pass of GlobalDCE. It is capable of deleting dead cycles. 
+ if (OptLevel > 1) { + MPM.add(createGlobalDCEPass()); // Remove dead fns and globals. + MPM.add(createConstantMergePass()); // Merge dup global constants } if (MergeFunctions) @@ -673,6 +685,11 @@ void PassManagerBuilder::populateModulePassManager( // Get rid of LCSSA nodes. MPM.add(createInstructionSimplifierPass()); + // This hoists/decomposes div/rem ops. It should run after other sink/hoist + // passes to avoid re-sinking, but before SimplifyCFG because it can allow + // flattening of blocks. + MPM.add(createDivRemPairsPass()); + // LoopSink (and other loop passes since the last simplifyCFG) might have // resulted in single-entry-single-exit or empty blocks. Clean up the CFG. MPM.add(createCFGSimplificationPass()); @@ -695,6 +712,9 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { PM.add(createInferFunctionAttrsLegacyPass()); if (OptLevel > 1) { + // Split call-site with more constrained arguments. + PM.add(createCallSiteSplittingPass()); + // Indirect call promotion. This should promote all the targets that are // left by the earlier promotion pass that promotes intra-module targets. // This two-step promotion is to save the compile time. For LTO, it should @@ -706,6 +726,10 @@ void PassManagerBuilder::addLTOOptimizationPasses(legacy::PassManagerBase &PM) { // opens opportunities for globalopt (and inlining) by substituting function // pointers passed as arguments to direct uses of functions. PM.add(createIPSCCPPass()); + + // Attach metadata to indirect call sites indicating the set of functions + // they may target at run-time. This should follow IPSCCP. + PM.add(createCalledValuePropagationPass()); } // Infer attributes about definitions. The readnone attribute in particular is @@ -936,8 +960,7 @@ LLVMPassManagerBuilderSetSizeLevel(LLVMPassManagerBuilderRef PMB, void LLVMPassManagerBuilderSetDisableUnitAtATime(LLVMPassManagerBuilderRef PMB, LLVMBool Value) { - PassManagerBuilder *Builder = unwrap(PMB); - Builder->DisableUnitAtATime = Value; + // NOTE: The DisableUnitAtATime switch has been removed. 
} void diff --git a/lib/Transforms/IPO/PruneEH.cpp b/lib/Transforms/IPO/PruneEH.cpp index 3fd59847a005..46b088189040 100644 --- a/lib/Transforms/IPO/PruneEH.cpp +++ b/lib/Transforms/IPO/PruneEH.cpp @@ -24,7 +24,6 @@ #include "llvm/IR/Function.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" diff --git a/lib/Transforms/IPO/SampleProfile.cpp b/lib/Transforms/IPO/SampleProfile.cpp index 6baada2c1ae1..f0e781b9d923 100644 --- a/lib/Transforms/IPO/SampleProfile.cpp +++ b/lib/Transforms/IPO/SampleProfile.cpp @@ -23,39 +23,65 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/SampleProfile.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" #include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/PostDominators.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DebugInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DebugLoc.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" -#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" -#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/ValueSymbolTable.h" #include "llvm/Pass.h" #include "llvm/ProfileData/InstrProf.h" +#include "llvm/ProfileData/SampleProf.h" #include "llvm/ProfileData/SampleProfReader.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/ErrorOr.h" -#include "llvm/Support/Format.h" +#include "llvm/Support/GenericDomTree.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Utils/CallPromotionUtils.h" #include "llvm/Transforms/Utils/Cloning.h" -#include <cctype> +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <functional> +#include <limits> +#include <map> +#include <memory> +#include <string> +#include <system_error> +#include <utility> +#include <vector> using namespace llvm; using namespace sampleprof; @@ -67,34 +93,39 @@ using namespace sampleprof; static cl::opt<std::string> SampleProfileFile( "sample-profile-file", cl::init(""), cl::value_desc("filename"), cl::desc("Profile file loaded by -sample-profile"), cl::Hidden); + static cl::opt<unsigned> SampleProfileMaxPropagateIterations( "sample-profile-max-propagate-iterations", cl::init(100), cl::desc("Maximum number of iterations to go through when propagating " "sample block/edge weights through the CFG.")); + static cl::opt<unsigned> SampleProfileRecordCoverage( 
"sample-profile-check-record-coverage", cl::init(0), cl::value_desc("N"), cl::desc("Emit a warning if less than N% of records in the input profile " "are matched to the IR.")); + static cl::opt<unsigned> SampleProfileSampleCoverage( "sample-profile-check-sample-coverage", cl::init(0), cl::value_desc("N"), cl::desc("Emit a warning if less than N% of samples in the input profile " "are matched to the IR.")); + static cl::opt<double> SampleProfileHotThreshold( "sample-profile-inline-hot-threshold", cl::init(0.1), cl::value_desc("N"), cl::desc("Inlined functions that account for more than N% of all samples " "collected in the parent function, will be inlined again.")); namespace { -typedef DenseMap<const BasicBlock *, uint64_t> BlockWeightMap; -typedef DenseMap<const BasicBlock *, const BasicBlock *> EquivalenceClassMap; -typedef std::pair<const BasicBlock *, const BasicBlock *> Edge; -typedef DenseMap<Edge, uint64_t> EdgeWeightMap; -typedef DenseMap<const BasicBlock *, SmallVector<const BasicBlock *, 8>> - BlockEdgeMap; + +using BlockWeightMap = DenseMap<const BasicBlock *, uint64_t>; +using EquivalenceClassMap = DenseMap<const BasicBlock *, const BasicBlock *>; +using Edge = std::pair<const BasicBlock *, const BasicBlock *>; +using EdgeWeightMap = DenseMap<Edge, uint64_t>; +using BlockEdgeMap = + DenseMap<const BasicBlock *, SmallVector<const BasicBlock *, 8>>; class SampleCoverageTracker { public: - SampleCoverageTracker() : SampleCoverage(), TotalUsedSamples(0) {} + SampleCoverageTracker() = default; bool markSamplesUsed(const FunctionSamples *FS, uint32_t LineOffset, uint32_t Discriminator, uint64_t Samples); @@ -103,15 +134,16 @@ public: unsigned countBodyRecords(const FunctionSamples *FS) const; uint64_t getTotalUsedSamples() const { return TotalUsedSamples; } uint64_t countBodySamples(const FunctionSamples *FS) const; + void clear() { SampleCoverage.clear(); TotalUsedSamples = 0; } private: - typedef std::map<LineLocation, unsigned> BodySampleCoverageMap; - typedef DenseMap<const FunctionSamples *, BodySampleCoverageMap> - FunctionSamplesCoverageMap; + using BodySampleCoverageMap = std::map<LineLocation, unsigned>; + using FunctionSamplesCoverageMap = + DenseMap<const FunctionSamples *, BodySampleCoverageMap>; /// Coverage map for sampling records. /// @@ -135,7 +167,7 @@ private: /// and all the inlined callsites. Strictly, we should have a map of counters /// keyed by FunctionSamples pointers, but these stats are cleared after /// every function, so we just need to keep a single counter. - uint64_t TotalUsedSamples; + uint64_t TotalUsedSamples = 0; }; /// \brief Sample profile pass. @@ -145,29 +177,31 @@ private: /// profile information found in that file. 
class SampleProfileLoader { public: - SampleProfileLoader(StringRef Name = SampleProfileFile) - : DT(nullptr), PDT(nullptr), LI(nullptr), ACT(nullptr), Reader(), - Samples(nullptr), Filename(Name), ProfileIsValid(false), - TotalCollectedSamples(0) {} + SampleProfileLoader( + StringRef Name, bool IsThinLTOPreLink, + std::function<AssumptionCache &(Function &)> GetAssumptionCache, + std::function<TargetTransformInfo &(Function &)> GetTargetTransformInfo) + : GetAC(GetAssumptionCache), GetTTI(GetTargetTransformInfo), + Filename(Name), IsThinLTOPreLink(IsThinLTOPreLink) {} bool doInitialization(Module &M); - bool runOnModule(Module &M); - void setACT(AssumptionCacheTracker *A) { ACT = A; } + bool runOnModule(Module &M, ModuleAnalysisManager *AM); void dump() { Reader->dump(); } protected: - bool runOnFunction(Function &F); + bool runOnFunction(Function &F, ModuleAnalysisManager *AM); unsigned getFunctionLoc(Function &F); bool emitAnnotations(Function &F); ErrorOr<uint64_t> getInstWeight(const Instruction &I); ErrorOr<uint64_t> getBlockWeight(const BasicBlock *BB); const FunctionSamples *findCalleeFunctionSamples(const Instruction &I) const; std::vector<const FunctionSamples *> - findIndirectCallFunctionSamples(const Instruction &I) const; + findIndirectCallFunctionSamples(const Instruction &I, uint64_t &Sum) const; const FunctionSamples *findFunctionSamples(const Instruction &I) const; + bool inlineCallInstruction(Instruction *I); bool inlineHotFunctions(Function &F, - DenseSet<GlobalValue::GUID> &ImportGUIDs); + DenseSet<GlobalValue::GUID> &InlinedGUIDs); void printEdgeWeight(raw_ostream &OS, Edge E); void printBlockWeight(raw_ostream &OS, const BasicBlock *BB) const; void printBlockEquivalence(raw_ostream &OS, const BasicBlock *BB); @@ -222,7 +256,8 @@ protected: std::unique_ptr<PostDomTreeBase<BasicBlock>> PDT; std::unique_ptr<LoopInfo> LI; - AssumptionCacheTracker *ACT; + std::function<AssumptionCache &(Function &)> GetAC; + std::function<TargetTransformInfo &(Function &)> GetTTI; /// \brief Predecessors for each basic block in the CFG. BlockEdgeMap Predecessors; @@ -236,19 +271,28 @@ protected: std::unique_ptr<SampleProfileReader> Reader; /// \brief Samples collected for the body of this function. - FunctionSamples *Samples; + FunctionSamples *Samples = nullptr; /// \brief Name of the profile file to load. std::string Filename; /// \brief Flag indicating whether the profile input loaded successfully. - bool ProfileIsValid; + bool ProfileIsValid = false; + + /// \brief Flag indicating if the pass is invoked in ThinLTO compile phase. + /// + /// In this phase, in annotation, we should not promote indirect calls. + /// Instead, we will mark GUIDs that needs to be annotated to the function. + bool IsThinLTOPreLink; /// \brief Total number of samples collected in this profile. /// /// This is the sum of all the samples collected in all the functions executed /// at runtime. - uint64_t TotalCollectedSamples; + uint64_t TotalCollectedSamples = 0; + + /// \brief Optimization Remark Emitter used to emit diagnostic remarks. 
+ OptimizationRemarkEmitter *ORE = nullptr; }; class SampleProfileLoaderLegacyPass : public ModulePass { @@ -256,8 +300,15 @@ public: // Class identification, replacement for typeinfo static char ID; - SampleProfileLoaderLegacyPass(StringRef Name = SampleProfileFile) - : ModulePass(ID), SampleLoader(Name) { + SampleProfileLoaderLegacyPass(StringRef Name = SampleProfileFile, + bool IsThinLTOPreLink = false) + : ModulePass(ID), SampleLoader(Name, IsThinLTOPreLink, + [&](Function &F) -> AssumptionCache & { + return ACT->getAssumptionCache(F); + }, + [&](Function &F) -> TargetTransformInfo & { + return TTIWP->getTTI(F); + }) { initializeSampleProfileLoaderLegacyPassPass( *PassRegistry::getPassRegistry()); } @@ -267,17 +318,23 @@ public: bool doInitialization(Module &M) override { return SampleLoader.doInitialization(M); } + StringRef getPassName() const override { return "Sample profile pass"; } bool runOnModule(Module &M) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); } private: SampleProfileLoader SampleLoader; + AssumptionCacheTracker *ACT = nullptr; + TargetTransformInfoWrapperPass *TTIWP = nullptr; }; +} // end anonymous namespace + /// Return true if the given callsite is hot wrt to its caller. /// /// Functions that were inlined in the original binary will be represented @@ -292,8 +349,8 @@ private: /// /// If that fraction is larger than the default given by /// SampleProfileHotThreshold, the callsite will be inlined again. -bool callsiteIsHot(const FunctionSamples *CallerFS, - const FunctionSamples *CallsiteFS) { +static bool callsiteIsHot(const FunctionSamples *CallerFS, + const FunctionSamples *CallsiteFS) { if (!CallsiteFS) return false; // The callsite was not inlined in the original binary. @@ -309,7 +366,6 @@ bool callsiteIsHot(const FunctionSamples *CallerFS, (double)CallsiteTotalSamples / (double)ParentTotalSamples * 100.0; return PercentSamples >= SampleProfileHotThreshold; } -} /// Mark as used the sample record for the given function samples at /// (LineOffset, Discriminator). @@ -423,6 +479,7 @@ unsigned SampleProfileLoader::getOffset(const DILocation *DIL) const { 0xffff; } +#ifndef NDEBUG /// \brief Print the weight of edge \p E on stream \p OS. /// /// \param OS Stream to emit the output to. @@ -453,6 +510,7 @@ void SampleProfileLoader::printBlockWeight(raw_ostream &OS, uint64_t W = (I == BlockWeights.end() ? 0 : I->second); OS << "weight[" << BB->getName() << "]: " << W << "\n"; } +#endif /// \brief Get the weight for an instruction. /// @@ -480,10 +538,12 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) { if (isa<BranchInst>(Inst) || isa<IntrinsicInst>(Inst)) return std::error_code(); - // If a call/invoke instruction is inlined in profile, but not inlined here, + // If a direct call/invoke instruction is inlined in profile + // (findCalleeFunctionSamples returns non-empty result), but not inlined here, // it means that the inlined callsite has no sample, thus the call // instruction should have 0 count. 
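// Illustrative sketch, not part of the patch: how a body sample is keyed, the
// way getInstWeight looks it up below, using a plain std::map in place of
// FunctionSamples. The sample values are made up for illustration.
#include <cstdint>
#include <iostream>
#include <map>
#include <utility>

// (line offset from the function's header line, debug-location discriminator)
using LineLoc = std::pair<uint32_t, uint32_t>;

// Returns the recorded sample count for an instruction location, or -1 when
// the profile has no record for it (the std::error_code case in the pass).
static int64_t lookupSamples(const std::map<LineLoc, uint64_t> &BodySamples,
                             uint32_t LineOffset, uint32_t Discriminator) {
  auto It = BodySamples.find({LineOffset, Discriminator});
  return It == BodySamples.end() ? -1 : static_cast<int64_t>(It->second);
}

int main() {
  std::map<LineLoc, uint64_t> BodySamples{{{2, 0}, 120}, {{5, 1}, 7}};
  std::cout << lookupSamples(BodySamples, 2, 0) << "\n"; // 120
  std::cout << lookupSamples(BodySamples, 9, 0) << "\n"; // -1: no record
  return 0;
}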
if ((isa<CallInst>(Inst) || isa<InvokeInst>(Inst)) && + !ImmutableCallSite(&Inst).isIndirectCall() && findCalleeFunctionSamples(Inst)) return 0; @@ -495,13 +555,18 @@ ErrorOr<uint64_t> SampleProfileLoader::getInstWeight(const Instruction &Inst) { bool FirstMark = CoverageTracker.markSamplesUsed(FS, LineOffset, Discriminator, R.get()); if (FirstMark) { - const Function *F = Inst.getParent()->getParent(); - LLVMContext &Ctx = F->getContext(); - emitOptimizationRemark( - Ctx, DEBUG_TYPE, *F, DLoc, - Twine("Applied ") + Twine(*R) + - " samples from profile (offset: " + Twine(LineOffset) + - ((Discriminator) ? Twine(".") + Twine(Discriminator) : "") + ")"); + ORE->emit([&]() { + OptimizationRemarkAnalysis Remark(DEBUG_TYPE, "AppliedSamples", &Inst); + Remark << "Applied " << ore::NV("NumSamples", *R); + Remark << " samples from profile (offset: "; + Remark << ore::NV("LineOffset", LineOffset); + if (Discriminator) { + Remark << "."; + Remark << ore::NV("Discriminator", Discriminator); + } + Remark << ")"; + return Remark; + }); } DEBUG(dbgs() << " " << DLoc.getLine() << "." << DIL->getBaseDiscriminator() << ":" << Inst @@ -588,10 +653,11 @@ SampleProfileLoader::findCalleeFunctionSamples(const Instruction &Inst) const { } /// Returns a vector of FunctionSamples that are the indirect call targets -/// of \p Inst. The vector is sorted by the total number of samples. +/// of \p Inst. The vector is sorted by the total number of samples. Stores +/// the total call count of the indirect call in \p Sum. std::vector<const FunctionSamples *> SampleProfileLoader::findIndirectCallFunctionSamples( - const Instruction &Inst) const { + const Instruction &Inst, uint64_t &Sum) const { const DILocation *DIL = Inst.getDebugLoc(); std::vector<const FunctionSamples *> R; @@ -603,16 +669,25 @@ SampleProfileLoader::findIndirectCallFunctionSamples( if (FS == nullptr) return R; + uint32_t LineOffset = getOffset(DIL); + uint32_t Discriminator = DIL->getBaseDiscriminator(); + + auto T = FS->findCallTargetMapAt(LineOffset, Discriminator); + Sum = 0; + if (T) + for (const auto &T_C : T.get()) + Sum += T_C.second; if (const FunctionSamplesMap *M = FS->findFunctionSamplesMapAt( LineLocation(getOffset(DIL), DIL->getBaseDiscriminator()))) { - if (M->size() == 0) + if (M->empty()) return R; for (const auto &NameFS : *M) { + Sum += NameFS.second.getEntrySamples(); R.push_back(&NameFS.second); } std::sort(R.begin(), R.end(), [](const FunctionSamples *L, const FunctionSamples *R) { - return L->getTotalSamples() > R->getTotalSamples(); + return L->getEntrySamples() > R->getEntrySamples(); }); } return R; @@ -650,6 +725,39 @@ SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { return FS; } +bool SampleProfileLoader::inlineCallInstruction(Instruction *I) { + assert(isa<CallInst>(I) || isa<InvokeInst>(I)); + CallSite CS(I); + Function *CalledFunction = CS.getCalledFunction(); + assert(CalledFunction); + DebugLoc DLoc = I->getDebugLoc(); + BasicBlock *BB = I->getParent(); + InlineParams Params = getInlineParams(); + Params.ComputeFullInlineCost = true; + // Checks if there is anything in the reachable portion of the callee at + // this callsite that makes this inlining potentially illegal. Need to + // set ComputeFullInlineCost, otherwise getInlineCost may return early + // when cost exceeds threshold without checking all IRs in the callee. + // The acutal cost does not matter because we only checks isNever() to + // see if it is legal to inline the callsite. 
+ InlineCost Cost = getInlineCost(CS, Params, GetTTI(*CalledFunction), GetAC, + None, nullptr, nullptr); + if (Cost.isNever()) { + ORE->emit(OptimizationRemark(DEBUG_TYPE, "Not inline", DLoc, BB) + << "incompatible inlining"); + return false; + } + InlineFunctionInfo IFI(nullptr, &GetAC); + if (InlineFunction(CS, IFI)) { + // The call to InlineFunction erases I, so we can't pass it here. + ORE->emit(OptimizationRemark(DEBUG_TYPE, "HotInline", DLoc, BB) + << "inlined hot callee '" << ore::NV("Callee", CalledFunction) + << "' into '" << ore::NV("Caller", BB->getParent()) << "'"); + return true; + } + return false; +} + /// \brief Iteratively inline hot callsites of a function. /// /// Iteratively traverse all callsites of the function \p F, and find if @@ -659,17 +767,14 @@ SampleProfileLoader::findFunctionSamples(const Instruction &Inst) const { /// it to direct call. Each indirect call is limited with a single target. /// /// \param F function to perform iterative inlining. -/// \param ImportGUIDs a set to be updated to include all GUIDs that come -/// from a different module but inlined in the profiled binary. +/// \param InlinedGUIDs a set to be updated to include all GUIDs that are +/// inlined in the profiled binary. /// /// \returns True if there is any inline happened. bool SampleProfileLoader::inlineHotFunctions( - Function &F, DenseSet<GlobalValue::GUID> &ImportGUIDs) { + Function &F, DenseSet<GlobalValue::GUID> &InlinedGUIDs) { DenseSet<Instruction *> PromotedInsns; bool Changed = false; - LLVMContext &Ctx = F.getContext(); - std::function<AssumptionCache &(Function &)> GetAssumptionCache = [&]( - Function &F) -> AssumptionCache & { return ACT->getAssumptionCache(F); }; while (true) { bool LocalChanged = false; SmallVector<Instruction *, 10> CIS; @@ -690,57 +795,59 @@ bool SampleProfileLoader::inlineHotFunctions( } } for (auto I : CIS) { - InlineFunctionInfo IFI(nullptr, ACT ? &GetAssumptionCache : nullptr); Function *CalledFunction = CallSite(I).getCalledFunction(); // Do not inline recursive calls. if (CalledFunction == &F) continue; - Instruction *DI = I; - if (!CalledFunction && !PromotedInsns.count(I) && - CallSite(I).isIndirectCall()) - for (const auto *FS : findIndirectCallFunctionSamples(*I)) { + if (CallSite(I).isIndirectCall()) { + if (PromotedInsns.count(I)) + continue; + uint64_t Sum; + for (const auto *FS : findIndirectCallFunctionSamples(*I, Sum)) { + if (IsThinLTOPreLink) { + FS->findInlinedFunctions(InlinedGUIDs, F.getParent(), + Samples->getTotalSamples() * + SampleProfileHotThreshold / 100); + continue; + } auto CalleeFunctionName = FS->getName(); // If it is a recursive call, we do not inline it as it could bloat // the code exponentially. There is way to better handle this, e.g. // clone the caller first, and inline the cloned caller if it is - // recursive. As llvm does not inline recursive calls, we will simply - // ignore it instead of handling it explicitly. + // recursive. As llvm does not inline recursive calls, we will + // simply ignore it instead of handling it explicitly. if (CalleeFunctionName == F.getName()) continue; + const char *Reason = "Callee function not available"; auto R = SymbolMap.find(CalleeFunctionName); - if (R == SymbolMap.end()) - continue; - CalledFunction = R->getValue(); - if (CalledFunction && isLegalToPromote(I, CalledFunction, &Reason)) { - // The indirect target was promoted and inlined in the profile, as a - // result, we do not have profile info for the branch probability. 
- // We set the probability to 80% taken to indicate that the static - // call is likely taken. - DI = dyn_cast<Instruction>( - promoteIndirectCall(I, CalledFunction, 80, 100, false) - ->stripPointerCasts()); + if (R != SymbolMap.end() && R->getValue() && + !R->getValue()->isDeclaration() && + R->getValue()->getSubprogram() && + isLegalToPromote(CallSite(I), R->getValue(), &Reason)) { + uint64_t C = FS->getEntrySamples(); + Instruction *DI = + pgo::promoteIndirectCall(I, R->getValue(), C, Sum, false, ORE); + Sum -= C; PromotedInsns.insert(I); + // If profile mismatches, we should not attempt to inline DI. + if ((isa<CallInst>(DI) || isa<InvokeInst>(DI)) && + inlineCallInstruction(DI)) + LocalChanged = true; } else { - DEBUG(dbgs() << "\nFailed to promote indirect call to " - << CalleeFunctionName << " because " << Reason - << "\n"); - continue; + DEBUG(dbgs() + << "\nFailed to promote indirect call to " + << CalleeFunctionName << " because " << Reason << "\n"); } } - if (!CalledFunction || !CalledFunction->getSubprogram()) { - findCalleeFunctionSamples(*I)->findImportedFunctions( - ImportGUIDs, F.getParent(), + } else if (CalledFunction && CalledFunction->getSubprogram() && + !CalledFunction->isDeclaration()) { + if (inlineCallInstruction(I)) + LocalChanged = true; + } else if (IsThinLTOPreLink) { + findCalleeFunctionSamples(*I)->findInlinedFunctions( + InlinedGUIDs, F.getParent(), Samples->getTotalSamples() * SampleProfileHotThreshold / 100); - continue; - } - DebugLoc DLoc = I->getDebugLoc(); - if (InlineFunction(CallSite(DI), IFI)) { - LocalChanged = true; - emitOptimizationRemark(Ctx, DEBUG_TYPE, F, DLoc, - Twine("inlined hot callee '") + - CalledFunction->getName() + "' into '" + - F.getName() + "'"); } } if (LocalChanged) { @@ -1076,24 +1183,20 @@ void SampleProfileLoader::buildEdges(Function &F) { } } -/// Sorts the CallTargetMap \p M by count in descending order and stores the -/// sorted result in \p Sorted. Returns the total counts. -static uint64_t SortCallTargets(SmallVector<InstrProfValueData, 2> &Sorted, - const SampleRecord::CallTargetMap &M) { - Sorted.clear(); - uint64_t Sum = 0; - for (auto I = M.begin(); I != M.end(); ++I) { - Sum += I->getValue(); - Sorted.push_back({Function::getGUID(I->getKey()), I->getValue()}); - } - std::sort(Sorted.begin(), Sorted.end(), +/// Returns the sorted CallTargetMap \p M by count in descending order. 
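// Illustrative sketch, not part of the patch: the ordering that SortCallTargets
// establishes further below -- descending by count, ties broken by the value
// (standing in for the callee GUID), also descending -- applied to a stand-in
// record type.
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

struct ValueData { uint64_t Value; uint64_t Count; }; // stand-in for InstrProfValueData

static void sortCallTargets(std::vector<ValueData> &Records) {
  std::sort(Records.begin(), Records.end(),
            [](const ValueData &L, const ValueData &R) {
              if (L.Count == R.Count)
                return L.Value > R.Value;
              return L.Count > R.Count;
            });
}

int main() {
  std::vector<ValueData> Targets{{101, 5}, {202, 40}, {303, 5}};
  sortCallTargets(Targets);
  for (const ValueData &T : Targets)
    std::cout << T.Value << ":" << T.Count << " "; // 202:40 303:5 101:5
  std::cout << "\n";
  return 0;
}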
+static SmallVector<InstrProfValueData, 2> SortCallTargets( + const SampleRecord::CallTargetMap &M) { + SmallVector<InstrProfValueData, 2> R; + for (auto I = M.begin(); I != M.end(); ++I) + R.push_back({Function::getGUID(I->getKey()), I->getValue()}); + std::sort(R.begin(), R.end(), [](const InstrProfValueData &L, const InstrProfValueData &R) { if (L.Count == R.Count) return L.Value > R.Value; else return L.Count > R.Count; }); - return Sum; + return R; } /// \brief Propagate weights into edges @@ -1184,10 +1287,12 @@ void SampleProfileLoader::propagateWeights(Function &F) { if (!FS) continue; auto T = FS->findCallTargetMapAt(LineOffset, Discriminator); - if (!T || T.get().size() == 0) + if (!T || T.get().empty()) continue; - SmallVector<InstrProfValueData, 2> SortedCallTargets; - uint64_t Sum = SortCallTargets(SortedCallTargets, T.get()); + SmallVector<InstrProfValueData, 2> SortedCallTargets = + SortCallTargets(T.get()); + uint64_t Sum; + findIndirectCallFunctionSamples(I, Sum); annotateValueSite(*I.getParent()->getParent()->getParent(), I, SortedCallTargets, Sum, IPVK_IndirectCallTarget, SortedCallTargets.size()); @@ -1211,7 +1316,7 @@ void SampleProfileLoader::propagateWeights(Function &F) { << ".\n"); SmallVector<uint32_t, 4> Weights; uint32_t MaxWeight = 0; - DebugLoc MaxDestLoc; + Instruction *MaxDestInst; for (unsigned I = 0; I < TI->getNumSuccessors(); ++I) { BasicBlock *Succ = TI->getSuccessor(I); Edge E = std::make_pair(BB, Succ); @@ -1230,7 +1335,7 @@ void SampleProfileLoader::propagateWeights(Function &F) { if (Weight != 0) { if (Weight > MaxWeight) { MaxWeight = Weight; - MaxDestLoc = Succ->getFirstNonPHIOrDbgOrLifetime()->getDebugLoc(); + MaxDestInst = Succ->getFirstNonPHIOrDbgOrLifetime(); } } } @@ -1243,15 +1348,13 @@ void SampleProfileLoader::propagateWeights(Function &F) { // weights, the second pass does not need to set it. if (MaxWeight > 0 && !TI->extractProfTotalWeight(TempWeight)) { DEBUG(dbgs() << "SUCCESS. Found non-zero weights.\n"); - TI->setMetadata(llvm::LLVMContext::MD_prof, + TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); - emitOptimizationRemark( - Ctx, DEBUG_TYPE, F, MaxDestLoc, - Twine("most popular destination for conditional branches at ") + - ((BranchLoc) ? Twine(BranchLoc->getFilename() + ":" + - Twine(BranchLoc.getLine()) + ":" + - Twine(BranchLoc.getCol())) - : Twine("<UNKNOWN LOCATION>"))); + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "PopularDest", MaxDestInst) + << "most popular destination for conditional branches at " + << ore::NV("CondBranchesLoc", BranchLoc); + }); } else { DEBUG(dbgs() << "SKIPPED. All branch weights are zero.\n"); } @@ -1351,18 +1454,19 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { DEBUG(dbgs() << "Line number for the first instruction in " << F.getName() << ": " << getFunctionLoc(F) << "\n"); - DenseSet<GlobalValue::GUID> ImportGUIDs; - Changed |= inlineHotFunctions(F, ImportGUIDs); + DenseSet<GlobalValue::GUID> InlinedGUIDs; + Changed |= inlineHotFunctions(F, InlinedGUIDs); // Compute basic block weights. Changed |= computeBlockWeights(F); if (Changed) { // Add an entry count to the function using the samples gathered at the - // function entry. Also sets the GUIDs that comes from a different - // module but inlined in the profiled binary. This is aiming at making - // the IR match the profiled binary before annotation. - F.setEntryCount(Samples->getHeadSamples() + 1, &ImportGUIDs); + // function entry. + // Sets the GUIDs that are inlined in the profiled binary. 
This is used + // for ThinLink to make correct liveness analysis, and also make the IR + // match the profiled binary before annotation. + F.setEntryCount(Samples->getHeadSamples() + 1, &InlinedGUIDs); // Compute dominance and loop info needed for propagation. computeDominanceAndLoopInfo(F); @@ -1404,9 +1508,11 @@ bool SampleProfileLoader::emitAnnotations(Function &F) { } char SampleProfileLoaderLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(SampleProfileLoaderLegacyPass, "sample-profile", "Sample Profile loader", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(SampleProfileLoaderLegacyPass, "sample-profile", "Sample Profile loader", false, false) @@ -1431,7 +1537,7 @@ ModulePass *llvm::createSampleProfileLoaderPass(StringRef Name) { return new SampleProfileLoaderLegacyPass(Name); } -bool SampleProfileLoader::runOnModule(Module &M) { +bool SampleProfileLoader::runOnModule(Module &M, ModuleAnalysisManager *AM) { if (!ProfileIsValid) return false; @@ -1463,7 +1569,7 @@ bool SampleProfileLoader::runOnModule(Module &M) { for (auto &F : M) if (!F.isDeclaration()) { clearFunctionData(); - retval |= runOnFunction(F); + retval |= runOnFunction(F, AM); } if (M.getProfileSummary() == nullptr) M.setProfileSummary(Reader->getSummary().getMD(M.getContext())); @@ -1471,13 +1577,23 @@ bool SampleProfileLoader::runOnModule(Module &M) { } bool SampleProfileLoaderLegacyPass::runOnModule(Module &M) { - // FIXME: pass in AssumptionCache correctly for the new pass manager. - SampleLoader.setACT(&getAnalysis<AssumptionCacheTracker>()); - return SampleLoader.runOnModule(M); + ACT = &getAnalysis<AssumptionCacheTracker>(); + TTIWP = &getAnalysis<TargetTransformInfoWrapperPass>(); + return SampleLoader.runOnModule(M, nullptr); } -bool SampleProfileLoader::runOnFunction(Function &F) { +bool SampleProfileLoader::runOnFunction(Function &F, ModuleAnalysisManager *AM) { F.setEntryCount(0); + std::unique_ptr<OptimizationRemarkEmitter> OwnedORE; + if (AM) { + auto &FAM = + AM->getResult<FunctionAnalysisManagerModuleProxy>(*F.getParent()) + .getManager(); + ORE = &FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); + } else { + OwnedORE = make_unique<OptimizationRemarkEmitter>(&F); + ORE = OwnedORE.get(); + } Samples = Reader->getSamplesFor(F); if (Samples && !Samples->empty()) return emitAnnotations(F); @@ -1486,13 +1602,23 @@ bool SampleProfileLoader::runOnFunction(Function &F) { PreservedAnalyses SampleProfileLoaderPass::run(Module &M, ModuleAnalysisManager &AM) { + FunctionAnalysisManager &FAM = + AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + + auto GetAssumptionCache = [&](Function &F) -> AssumptionCache & { + return FAM.getResult<AssumptionAnalysis>(F); + }; + auto GetTTI = [&](Function &F) -> TargetTransformInfo & { + return FAM.getResult<TargetIRAnalysis>(F); + }; SampleProfileLoader SampleLoader( - ProfileFileName.empty() ? SampleProfileFile : ProfileFileName); + ProfileFileName.empty() ? 
SampleProfileFile : ProfileFileName, + IsThinLTOPreLink, GetAssumptionCache, GetTTI); SampleLoader.doInitialization(M); - if (!SampleLoader.runOnModule(M)) + if (!SampleLoader.runOnModule(M, &AM)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); diff --git a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp index 8ef6bb652309..945133074059 100644 --- a/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp +++ b/lib/Transforms/IPO/ThinLTOBitcodeWriter.cpp @@ -19,7 +19,6 @@ #include "llvm/IR/Module.h" #include "llvm/IR/PassManager.h" #include "llvm/Pass.h" -#include "llvm/Support/FileSystem.h" #include "llvm/Support/ScopedPrinter.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" @@ -40,9 +39,17 @@ void promoteInternals(Module &ExportM, Module &ImportM, StringRef ModuleId, continue; auto Name = ExportGV.getName(); - GlobalValue *ImportGV = ImportM.getNamedValue(Name); - if ((!ImportGV || ImportGV->use_empty()) && !PromoteExtra.count(&ExportGV)) - continue; + GlobalValue *ImportGV = nullptr; + if (!PromoteExtra.count(&ExportGV)) { + ImportGV = ImportM.getNamedValue(Name); + if (!ImportGV) + continue; + ImportGV->removeDeadConstantUsers(); + if (ImportGV->use_empty()) { + ImportGV->eraseFromParent(); + continue; + } + } std::string NewName = (Name + ModuleId).str(); @@ -141,7 +148,9 @@ void simplifyExternals(Module &M) { continue; } - if (!F.isDeclaration() || F.getFunctionType() == EmptyFT) + if (!F.isDeclaration() || F.getFunctionType() == EmptyFT || + // Changing the type of an intrinsic may invalidate the IR. + F.getName().startswith("llvm.")) continue; Function *NewF = @@ -376,15 +385,14 @@ void splitAndWriteThinLTOBitcode( W.writeStrtab(); OS << Buffer; - // If a minimized bitcode module was requested for the thin link, - // strip the debug info (the merged module was already stripped above) - // and write it to the given OS. + // If a minimized bitcode module was requested for the thin link, only + // the information that is needed by thin link will be written in the + // given OS (the merged module will be written as usual). if (ThinLinkOS) { Buffer.clear(); BitcodeWriter W2(Buffer); StripDebugInfo(M); - W2.writeModule(&M, /*ShouldPreserveUseListOrder=*/false, &Index, - /*GenerateHash=*/false, &ModHash); + W2.writeThinLinkBitcode(&M, Index, ModHash); W2.writeModule(MergedM.get(), /*ShouldPreserveUseListOrder=*/false, &MergedMIndex); W2.writeSymtab(); @@ -420,14 +428,11 @@ void writeThinLTOBitcode(raw_ostream &OS, raw_ostream *ThinLinkOS, ModuleHash ModHash = {{0}}; WriteBitcodeToFile(&M, OS, /*ShouldPreserveUseListOrder=*/false, Index, /*GenerateHash=*/true, &ModHash); - // If a minimized bitcode module was requested for the thin link, - // strip the debug info and write it to the given OS. - if (ThinLinkOS) { - StripDebugInfo(M); - WriteBitcodeToFile(&M, *ThinLinkOS, /*ShouldPreserveUseListOrder=*/false, - Index, - /*GenerateHash=*/false, &ModHash); - } + // If a minimized bitcode module was requested for the thin link, only + // the information that is needed by thin link will be written in the + // given OS. 
+ if (ThinLinkOS && Index) + WriteThinLinkBitcodeToFile(&M, *ThinLinkOS, *Index, ModHash); } class WriteThinLTOBitcode : public ModulePass { diff --git a/lib/Transforms/IPO/WholeProgramDevirt.cpp b/lib/Transforms/IPO/WholeProgramDevirt.cpp index 00769cd63229..ec56f0cde25d 100644 --- a/lib/Transforms/IPO/WholeProgramDevirt.cpp +++ b/lib/Transforms/IPO/WholeProgramDevirt.cpp @@ -51,14 +51,13 @@ #include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TypeMetadataUtils.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalVariable.h" @@ -275,18 +274,39 @@ struct VirtualCallSite { // of that field for details. unsigned *NumUnsafeUses; - void emitRemark(const Twine &OptName, const Twine &TargetName) { + void + emitRemark(const StringRef OptName, const StringRef TargetName, + function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) { Function *F = CS.getCaller(); - emitOptimizationRemark( - F->getContext(), DEBUG_TYPE, *F, - CS.getInstruction()->getDebugLoc(), - OptName + ": devirtualized a call to " + TargetName); + DebugLoc DLoc = CS->getDebugLoc(); + BasicBlock *Block = CS.getParent(); + + // In the new pass manager, we can request the optimization + // remark emitter pass on a per-function-basis, which the + // OREGetter will do for us. + // In the old pass manager, this is harder, so we just build + // a optimization remark emitter on the fly, when we need it. 
+ std::unique_ptr<OptimizationRemarkEmitter> OwnedORE; + OptimizationRemarkEmitter *ORE; + if (OREGetter) + ORE = &OREGetter(F); + else { + OwnedORE = make_unique<OptimizationRemarkEmitter>(F); + ORE = OwnedORE.get(); + } + + using namespace ore; + ORE->emit(OptimizationRemark(DEBUG_TYPE, OptName, DLoc, Block) + << NV("Optimization", OptName) << ": devirtualized a call to " + << NV("FunctionName", TargetName)); } - void replaceAndErase(const Twine &OptName, const Twine &TargetName, - bool RemarksEnabled, Value *New) { + void replaceAndErase( + const StringRef OptName, const StringRef TargetName, bool RemarksEnabled, + function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, + Value *New) { if (RemarksEnabled) - emitRemark(OptName, TargetName); + emitRemark(OptName, TargetName, OREGetter); CS->replaceAllUsesWith(New); if (auto II = dyn_cast<InvokeInst>(CS.getInstruction())) { BranchInst::Create(II->getNormalDest(), CS.getInstruction()); @@ -383,6 +403,7 @@ struct DevirtModule { IntegerType *IntPtrTy; bool RemarksEnabled; + function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter; MapVector<VTableSlot, VTableSlotInfo> CallSlots; @@ -397,6 +418,7 @@ struct DevirtModule { std::map<CallInst *, unsigned> NumUnsafeUsesForTypeTest; DevirtModule(Module &M, function_ref<AAResults &(Function &)> AARGetter, + function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter, ModuleSummaryIndex *ExportSummary, const ModuleSummaryIndex *ImportSummary) : M(M), AARGetter(AARGetter), ExportSummary(ExportSummary), @@ -405,7 +427,7 @@ struct DevirtModule { Int32Ty(Type::getInt32Ty(M.getContext())), Int64Ty(Type::getInt64Ty(M.getContext())), IntPtrTy(M.getDataLayout().getIntPtrType(M.getContext(), 0)), - RemarksEnabled(areRemarksEnabled()) { + RemarksEnabled(areRemarksEnabled()), OREGetter(OREGetter) { assert(!(ExportSummary && ImportSummary)); } @@ -444,16 +466,23 @@ struct DevirtModule { std::string getGlobalName(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name); + bool shouldExportConstantsAsAbsoluteSymbols(); + // This function is called during the export phase to create a symbol // definition containing information about the given vtable slot and list of // arguments. void exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name, Constant *C); + void exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name, + uint32_t Const, uint32_t &Storage); // This function is called during the import phase to create a reference to // the symbol definition created during the export phase. Constant *importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, - StringRef Name, unsigned AbsWidth = 0); + StringRef Name); + Constant *importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, + StringRef Name, IntegerType *IntTy, + uint32_t Storage); void applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, bool IsOne, Constant *UniqueMemberAddr); @@ -482,8 +511,9 @@ struct DevirtModule { // Lower the module using the action and summary passed as command line // arguments. For testing purposes only. 
- static bool runForTesting(Module &M, - function_ref<AAResults &(Function &)> AARGetter); + static bool runForTesting( + Module &M, function_ref<AAResults &(Function &)> AARGetter, + function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter); }; struct WholeProgramDevirt : public ModulePass { @@ -508,9 +538,14 @@ struct WholeProgramDevirt : public ModulePass { bool runOnModule(Module &M) override { if (skipModule(M)) return false; + + auto OREGetter = function_ref<OptimizationRemarkEmitter &(Function *)>(); + if (UseCommandLine) - return DevirtModule::runForTesting(M, LegacyAARGetter(*this)); - return DevirtModule(M, LegacyAARGetter(*this), ExportSummary, ImportSummary) + return DevirtModule::runForTesting(M, LegacyAARGetter(*this), OREGetter); + + return DevirtModule(M, LegacyAARGetter(*this), OREGetter, ExportSummary, + ImportSummary) .run(); } @@ -542,13 +577,17 @@ PreservedAnalyses WholeProgramDevirtPass::run(Module &M, auto AARGetter = [&](Function &F) -> AAResults & { return FAM.getResult<AAManager>(F); }; - if (!DevirtModule(M, AARGetter, nullptr, nullptr).run()) + auto OREGetter = [&](Function *F) -> OptimizationRemarkEmitter & { + return FAM.getResult<OptimizationRemarkEmitterAnalysis>(*F); + }; + if (!DevirtModule(M, AARGetter, OREGetter, nullptr, nullptr).run()) return PreservedAnalyses::all(); return PreservedAnalyses::none(); } bool DevirtModule::runForTesting( - Module &M, function_ref<AAResults &(Function &)> AARGetter) { + Module &M, function_ref<AAResults &(Function &)> AARGetter, + function_ref<OptimizationRemarkEmitter &(Function *)> OREGetter) { ModuleSummaryIndex Summary; // Handle the command-line summary arguments. This code is for testing @@ -566,7 +605,7 @@ bool DevirtModule::runForTesting( bool Changed = DevirtModule( - M, AARGetter, + M, AARGetter, OREGetter, ClSummaryAction == PassSummaryAction::Export ? &Summary : nullptr, ClSummaryAction == PassSummaryAction::Import ? &Summary : nullptr) .run(); @@ -684,7 +723,7 @@ void DevirtModule::applySingleImplDevirt(VTableSlotInfo &SlotInfo, auto Apply = [&](CallSiteInfo &CSInfo) { for (auto &&VCallSite : CSInfo.CallSites) { if (RemarksEnabled) - VCallSite.emitRemark("single-impl", TheFn->getName()); + VCallSite.emitRemark("single-impl", TheFn->getName(), OREGetter); VCallSite.CS.setCalledFunction(ConstantExpr::getBitCast( TheFn, VCallSite.CS.getCalledValue()->getType())); // This use is no longer unsafe. @@ -724,9 +763,24 @@ bool DevirtModule::trySingleImplDevirt( // to make it visible to thin LTO objects. We can only get here during the // ThinLTO export phase. if (TheFn->hasLocalLinkage()) { + std::string NewName = (TheFn->getName() + "$merged").str(); + + // Since we are renaming the function, any comdats with the same name must + // also be renamed. This is required when targeting COFF, as the comdat name + // must match one of the names of the symbols in the comdat. 
+ if (Comdat *C = TheFn->getComdat()) { + if (C->getName() == TheFn->getName()) { + Comdat *NewC = M.getOrInsertComdat(NewName); + NewC->setSelectionKind(C->getSelectionKind()); + for (GlobalObject &GO : M.global_objects()) + if (GO.getComdat() == C) + GO.setComdat(NewC); + } + } + TheFn->setLinkage(GlobalValue::ExternalLinkage); TheFn->setVisibility(GlobalValue::HiddenVisibility); - TheFn->setName(TheFn->getName() + "$merged"); + TheFn->setName(NewName); } Res->TheKind = WholeProgramDevirtResolution::SingleImpl; @@ -769,7 +823,7 @@ void DevirtModule::applyUniformRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, uint64_t TheRetVal) { for (auto Call : CSInfo.CallSites) Call.replaceAndErase( - "uniform-ret-val", FnName, RemarksEnabled, + "uniform-ret-val", FnName, RemarksEnabled, OREGetter, ConstantInt::get(cast<IntegerType>(Call.CS.getType()), TheRetVal)); CSInfo.markDevirt(); } @@ -808,6 +862,12 @@ std::string DevirtModule::getGlobalName(VTableSlot Slot, return OS.str(); } +bool DevirtModule::shouldExportConstantsAsAbsoluteSymbols() { + Triple T(M.getTargetTriple()); + return (T.getArch() == Triple::x86 || T.getArch() == Triple::x86_64) && + T.getObjectFormat() == Triple::ELF; +} + void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, StringRef Name, Constant *C) { GlobalAlias *GA = GlobalAlias::create(Int8Ty, 0, GlobalValue::ExternalLinkage, @@ -815,27 +875,55 @@ void DevirtModule::exportGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, GA->setVisibility(GlobalValue::HiddenVisibility); } +void DevirtModule::exportConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, + StringRef Name, uint32_t Const, + uint32_t &Storage) { + if (shouldExportConstantsAsAbsoluteSymbols()) { + exportGlobal( + Slot, Args, Name, + ConstantExpr::getIntToPtr(ConstantInt::get(Int32Ty, Const), Int8PtrTy)); + return; + } + + Storage = Const; +} + Constant *DevirtModule::importGlobal(VTableSlot Slot, ArrayRef<uint64_t> Args, - StringRef Name, unsigned AbsWidth) { + StringRef Name) { Constant *C = M.getOrInsertGlobal(getGlobalName(Slot, Args, Name), Int8Ty); auto *GV = dyn_cast<GlobalVariable>(C); + if (GV) + GV->setVisibility(GlobalValue::HiddenVisibility); + return C; +} + +Constant *DevirtModule::importConstant(VTableSlot Slot, ArrayRef<uint64_t> Args, + StringRef Name, IntegerType *IntTy, + uint32_t Storage) { + if (!shouldExportConstantsAsAbsoluteSymbols()) + return ConstantInt::get(IntTy, Storage); + + Constant *C = importGlobal(Slot, Args, Name); + auto *GV = cast<GlobalVariable>(C->stripPointerCasts()); + C = ConstantExpr::getPtrToInt(C, IntTy); + // We only need to set metadata if the global is newly created, in which // case it would not have hidden visibility. - if (!GV || GV->getVisibility() == GlobalValue::HiddenVisibility) + if (GV->getMetadata(LLVMContext::MD_absolute_symbol)) return C; - GV->setVisibility(GlobalValue::HiddenVisibility); auto SetAbsRange = [&](uint64_t Min, uint64_t Max) { auto *MinC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Min)); auto *MaxC = ConstantAsMetadata::get(ConstantInt::get(IntPtrTy, Max)); GV->setMetadata(LLVMContext::MD_absolute_symbol, MDNode::get(M.getContext(), {MinC, MaxC})); }; + unsigned AbsWidth = IntTy->getBitWidth(); if (AbsWidth == IntPtrTy->getBitWidth()) SetAbsRange(~0ull, ~0ull); // Full set. 
- else if (AbsWidth) + else SetAbsRange(0, 1ull << AbsWidth); - return GV; + return C; } void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, @@ -843,10 +931,12 @@ void DevirtModule::applyUniqueRetValOpt(CallSiteInfo &CSInfo, StringRef FnName, Constant *UniqueMemberAddr) { for (auto &&Call : CSInfo.CallSites) { IRBuilder<> B(Call.CS.getInstruction()); - Value *Cmp = B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, - Call.VTable, UniqueMemberAddr); + Value *Cmp = + B.CreateICmp(IsOne ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, + B.CreateBitCast(Call.VTable, Int8PtrTy), UniqueMemberAddr); Cmp = B.CreateZExt(Cmp, Call.CS->getType()); - Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, Cmp); + Call.replaceAndErase("unique-ret-val", FnName, RemarksEnabled, OREGetter, + Cmp); } CSInfo.markDevirt(); } @@ -909,17 +999,19 @@ void DevirtModule::applyVirtualConstProp(CallSiteInfo &CSInfo, StringRef FnName, for (auto Call : CSInfo.CallSites) { auto *RetType = cast<IntegerType>(Call.CS.getType()); IRBuilder<> B(Call.CS.getInstruction()); - Value *Addr = B.CreateGEP(Int8Ty, Call.VTable, Byte); + Value *Addr = + B.CreateGEP(Int8Ty, B.CreateBitCast(Call.VTable, Int8PtrTy), Byte); if (RetType->getBitWidth() == 1) { Value *Bits = B.CreateLoad(Addr); Value *BitsAndBit = B.CreateAnd(Bits, Bit); auto IsBitSet = B.CreateICmpNE(BitsAndBit, ConstantInt::get(Int8Ty, 0)); Call.replaceAndErase("virtual-const-prop-1-bit", FnName, RemarksEnabled, - IsBitSet); + OREGetter, IsBitSet); } else { Value *ValAddr = B.CreateBitCast(Addr, RetType->getPointerTo()); Value *Val = B.CreateLoad(RetType, ValAddr); - Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, Val); + Call.replaceAndErase("virtual-const-prop", FnName, RemarksEnabled, + OREGetter, Val); } } CSInfo.markDevirt(); @@ -1007,18 +1099,18 @@ bool DevirtModule::tryVirtualConstProp( for (auto &&Target : TargetsForSlot) Target.WasDevirt = true; - Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte); - Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); if (CSByConstantArg.second.isExported()) { ResByArg->TheKind = WholeProgramDevirtResolution::ByArg::VirtualConstProp; - exportGlobal(Slot, CSByConstantArg.first, "byte", - ConstantExpr::getIntToPtr(ByteConst, Int8PtrTy)); - exportGlobal(Slot, CSByConstantArg.first, "bit", - ConstantExpr::getIntToPtr(BitConst, Int8PtrTy)); + exportConstant(Slot, CSByConstantArg.first, "byte", OffsetByte, + ResByArg->Byte); + exportConstant(Slot, CSByConstantArg.first, "bit", 1ULL << OffsetBit, + ResByArg->Bit); } // Rewrite each call to a load from OffsetByte/OffsetBit. 
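// Illustrative sketch, not part of this change: once tryVirtualConstProp has
// arranged for each target's constant return value to live at a fixed offset
// in its vtable, the devirtualized call becomes a load at that offset from
// the vtable pointer, roughly:
#include <cstdint>

static uint8_t loadUniformReturnValue(const uint8_t *VTable,
                                      int64_t OffsetByte) {
  // Stands in for the GEP + load that applyVirtualConstProp emits; the 1-bit
  // variant additionally masks the loaded byte with OffsetBit.
  return VTable[OffsetByte];
}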
+ Constant *ByteConst = ConstantInt::get(Int32Ty, OffsetByte); + Constant *BitConst = ConstantInt::get(Int8Ty, 1ULL << OffsetBit); applyVirtualConstProp(CSByConstantArg.second, TargetsForSlot[0].Fn->getName(), ByteConst, BitConst); } @@ -1112,8 +1204,7 @@ void DevirtModule::scanTypeTestUsers(Function *TypeTestFunc, Value *Ptr = CI->getArgOperand(0)->stripPointerCasts(); if (SeenPtrs.insert(Ptr).second) { for (DevirtCallSite Call : DevirtCalls) { - CallSlots[{TypeId, Call.Offset}].addCallSite(CI->getArgOperand(0), - Call.CS, nullptr); + CallSlots[{TypeId, Call.Offset}].addCallSite(Ptr, Call.CS, nullptr); } } } @@ -1250,10 +1341,10 @@ void DevirtModule::importResolution(VTableSlot Slot, VTableSlotInfo &SlotInfo) { break; } case WholeProgramDevirtResolution::ByArg::VirtualConstProp: { - Constant *Byte = importGlobal(Slot, CSByConstantArg.first, "byte", 32); - Byte = ConstantExpr::getPtrToInt(Byte, Int32Ty); - Constant *Bit = importGlobal(Slot, CSByConstantArg.first, "bit", 8); - Bit = ConstantExpr::getPtrToInt(Bit, Int8Ty); + Constant *Byte = importConstant(Slot, CSByConstantArg.first, "byte", + Int32Ty, ResByArg.Byte); + Constant *Bit = importConstant(Slot, CSByConstantArg.first, "bit", Int8Ty, + ResByArg.Bit); applyVirtualConstProp(CSByConstantArg.second, "", Byte, Bit); } default: @@ -1406,9 +1497,24 @@ bool DevirtModule::run() { // Generate remarks for each devirtualized function. for (const auto &DT : DevirtTargets) { Function *F = DT.second; - DISubprogram *SP = F->getSubprogram(); - emitOptimizationRemark(F->getContext(), DEBUG_TYPE, *F, SP, - Twine("devirtualized ") + F->getName()); + + // In the new pass manager, we can request the optimization + // remark emitter pass on a per-function-basis, which the + // OREGetter will do for us. + // In the old pass manager, this is harder, so we just build + // a optimization remark emitter on the fly, when we need it. + std::unique_ptr<OptimizationRemarkEmitter> OwnedORE; + OptimizationRemarkEmitter *ORE; + if (OREGetter) + ORE = &OREGetter(F); + else { + OwnedORE = make_unique<OptimizationRemarkEmitter>(F); + ORE = OwnedORE.get(); + } + + using namespace ore; + ORE->emit(OptimizationRemark(DEBUG_TYPE, "Devirtualized", F) + << "devirtualized " << NV("FunctionName", F->getName())); } } diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp index 809471cfd74f..688897644848 100644 --- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp +++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp @@ -12,12 +12,26 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/IR/DataLayout.h" -#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/AlignOf.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/KnownBits.h" +#include <cassert> +#include <utility> using namespace llvm; using namespace PatternMatch; @@ -39,10 +53,15 @@ namespace { // is expensive. 
In order to avoid the cost of the constructor, we should // reuse some instances whenever possible. The pre-created instances // FAddCombine::Add[0-5] embodies this idea. - // - FAddendCoef() : IsFp(false), BufHasFpVal(false), IntVal(0) {} + FAddendCoef() = default; ~FAddendCoef(); + // If possible, don't define operator+/operator- etc because these + // operators inevitably call FAddendCoef's constructor which is not cheap. + void operator=(const FAddendCoef &A); + void operator+=(const FAddendCoef &A); + void operator*=(const FAddendCoef &S); + void set(short C) { assert(!insaneIntVal(C) && "Insane coefficient"); IsFp = false; IntVal = C; @@ -55,12 +74,6 @@ namespace { bool isZero() const { return isInt() ? !IntVal : getFpVal().isZero(); } Value *getValue(Type *) const; - // If possible, don't define operator+/operator- etc because these - // operators inevitably call FAddendCoef's constructor which is not cheap. - void operator=(const FAddendCoef &A); - void operator+=(const FAddendCoef &A); - void operator*=(const FAddendCoef &S); - bool isOne() const { return isInt() && IntVal == 1; } bool isTwo() const { return isInt() && IntVal == 2; } bool isMinusOne() const { return isInt() && IntVal == -1; } @@ -68,10 +81,12 @@ namespace { private: bool insaneIntVal(int V) { return V > 4 || V < -4; } + APFloat *getFpValPtr() - { return reinterpret_cast<APFloat*>(&FpValBuf.buffer[0]); } + { return reinterpret_cast<APFloat *>(&FpValBuf.buffer[0]); } + const APFloat *getFpValPtr() const - { return reinterpret_cast<const APFloat*>(&FpValBuf.buffer[0]); } + { return reinterpret_cast<const APFloat *>(&FpValBuf.buffer[0]); } const APFloat &getFpVal() const { assert(IsFp && BufHasFpVal && "Incorret state"); @@ -94,17 +109,16 @@ namespace { // from an *SIGNED* integer. APFloat createAPFloatFromInt(const fltSemantics &Sem, int Val); - private: - bool IsFp; + bool IsFp = false; // True iff FpValBuf contains an instance of APFloat. - bool BufHasFpVal; + bool BufHasFpVal = false; // The integer coefficient of an individual addend is either 1 or -1, // and we try to simplify at most 4 addends from neighboring at most // two instructions. So the range of <IntVal> falls in [-4, 4]. APInt // is overkill of this end. - short IntVal; + short IntVal = 0; AlignedCharArrayUnion<APFloat> FpValBuf; }; @@ -112,10 +126,14 @@ namespace { /// FAddend is used to represent floating-point addend. An addend is /// represented as <C, V>, where the V is a symbolic value, and C is a /// constant coefficient. A constant addend is represented as <C, 0>. - /// class FAddend { public: - FAddend() : Val(nullptr) {} + FAddend() = default; + + void operator+=(const FAddend &T) { + assert((Val == T.Val) && "Symbolic-values disagree"); + Coeff += T.Coeff; + } Value *getSymVal() const { return Val; } const FAddendCoef &getCoef() const { return Coeff; } @@ -146,16 +164,11 @@ namespace { /// splitted is the addend itself. unsigned drillAddendDownOneStep(FAddend &Addend0, FAddend &Addend1) const; - void operator+=(const FAddend &T) { - assert((Val == T.Val) && "Symbolic-values disagree"); - Coeff += T.Coeff; - } - private: void Scale(const FAddendCoef& ScaleAmt) { Coeff *= ScaleAmt; } // This addend has the value of "Coeff * Val". 
- Value *Val; + Value *Val = nullptr; FAddendCoef Coeff; }; @@ -164,11 +177,12 @@ namespace { /// class FAddCombine { public: - FAddCombine(InstCombiner::BuilderTy &B) : Builder(B), Instr(nullptr) {} + FAddCombine(InstCombiner::BuilderTy &B) : Builder(B) {} + Value *simplify(Instruction *FAdd); private: - typedef SmallVector<const FAddend*, 4> AddendVect; + using AddendVect = SmallVector<const FAddend *, 4>; Value *simplifyFAdd(AddendVect& V, unsigned InstrQuota); @@ -179,6 +193,7 @@ namespace { /// Return the number of instructions needed to emit the N-ary addition. unsigned calcInstrNumber(const AddendVect& Vect); + Value *createFSub(Value *Opnd0, Value *Opnd1); Value *createFAdd(Value *Opnd0, Value *Opnd1); Value *createFMul(Value *Opnd0, Value *Opnd1); @@ -187,9 +202,6 @@ namespace { Value *createNaryFAdd(const AddendVect& Opnds, unsigned InstrQuota); void createInstPostProc(Instruction *NewInst, bool NoNumber = false); - InstCombiner::BuilderTy &Builder; - Instruction *Instr; - // Debugging stuff are clustered here. #ifndef NDEBUG unsigned CreateInstrNum; @@ -199,9 +211,12 @@ namespace { void initCreateInstNum() {} void incCreateInstNum() {} #endif + + InstCombiner::BuilderTy &Builder; + Instruction *Instr = nullptr; }; -} // anonymous namespace +} // end anonymous namespace //===----------------------------------------------------------------------===// // @@ -332,7 +347,6 @@ Value *FAddendCoef::getValue(Type *Ty) const { // 0 +/- 0 <0, NULL> (corner case) // // Legend: A and B are not constant, C is constant -// unsigned FAddend::drillValueDownOneStep (Value *Val, FAddend &Addend0, FAddend &Addend1) { Instruction *I = nullptr; @@ -396,7 +410,6 @@ unsigned FAddend::drillValueDownOneStep // Try to break *this* addend into two addends. e.g. Suppose this addend is // <2.3, V>, and V = X + Y, by calling this function, we obtain two addends, // i.e. <2.3, X> and <2.3, Y>. -// unsigned FAddend::drillAddendDownOneStep (FAddend &Addend0, FAddend &Addend1) const { if (isConstant()) @@ -421,7 +434,6 @@ unsigned FAddend::drillAddendDownOneStep // ------------------------------------------------------- // (x * y) +/- (x * z) x * (y +/- z) // (y / x) +/- (z / x) (y +/- z) / x -// Value *FAddCombine::performFactorization(Instruction *I) { assert((I->getOpcode() == Instruction::FAdd || I->getOpcode() == Instruction::FSub) && "Expect add/sub"); @@ -447,7 +459,6 @@ Value *FAddCombine::performFactorization(Instruction *I) { // ---------------------------------------------- // (x*y) +/- (x*z) x y z // (y/x) +/- (z/x) x y z - // Value *Factor = nullptr; Value *AddSub0 = nullptr, *AddSub1 = nullptr; @@ -471,7 +482,7 @@ Value *FAddCombine::performFactorization(Instruction *I) { return nullptr; FastMathFlags Flags; - Flags.setUnsafeAlgebra(); + Flags.setFast(); if (I0) Flags &= I->getFastMathFlags(); if (I1) Flags &= I->getFastMathFlags(); @@ -500,7 +511,7 @@ Value *FAddCombine::performFactorization(Instruction *I) { } Value *FAddCombine::simplify(Instruction *I) { - assert(I->hasUnsafeAlgebra() && "Should be in unsafe mode"); + assert(I->isFast() && "Expected 'fast' instruction"); // Currently we are not able to handle vector type. if (I->getType()->isVectorTy()) @@ -599,7 +610,6 @@ Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) { // desirable to reside at the top of the resulting expression tree. Placing // constant close to supper-expr(s) will potentially reveal some optimization // opportunities in super-expr(s). 
- // const FAddend *ConstAdd = nullptr; // Simplified addends are placed <SimpVect>. @@ -608,7 +618,6 @@ Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) { // The outer loop works on one symbolic-value at a time. Suppose the input // addends are : <a1, x>, <b1, y>, <a2, x>, <c1, z>, <b2, y>, ... // The symbolic-values will be processed in this order: x, y, z. - // for (unsigned SymIdx = 0; SymIdx < AddendNum; SymIdx++) { const FAddend *ThisAddend = Addends[SymIdx]; @@ -626,7 +635,6 @@ Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) { // example, if the symbolic value "y" is being processed, the inner loop // will collect two addends "<b1,y>" and "<b2,Y>". These two addends will // be later on folded into "<b1+b2, y>". - // for (unsigned SameSymIdx = SymIdx + 1; SameSymIdx < AddendNum; SameSymIdx++) { const FAddend *T = Addends[SameSymIdx]; @@ -681,7 +689,7 @@ Value *FAddCombine::createNaryFAdd assert(!Opnds.empty() && "Expect at least one addend"); // Step 1: Check if the # of instructions needed exceeds the quota. - // + unsigned InstrNeeded = calcInstrNumber(Opnds); if (InstrNeeded > InstrQuota) return nullptr; @@ -726,10 +734,10 @@ Value *FAddCombine::createNaryFAdd LastVal = createFNeg(LastVal); } - #ifndef NDEBUG - assert(CreateInstrNum == InstrNeeded && - "Inconsistent in instruction numbers"); - #endif +#ifndef NDEBUG + assert(CreateInstrNum == InstrNeeded && + "Inconsistent in instruction numbers"); +#endif return LastVal; } @@ -950,9 +958,25 @@ static Value *checkForNegativeOperand(BinaryOperator &I, return nullptr; } -static Instruction *foldAddWithConstant(BinaryOperator &Add, - InstCombiner::BuilderTy &Builder) { +Instruction *InstCombiner::foldAddWithConstant(BinaryOperator &Add) { Value *Op0 = Add.getOperand(0), *Op1 = Add.getOperand(1); + Constant *Op1C; + if (!match(Op1, m_Constant(Op1C))) + return nullptr; + + if (Instruction *NV = foldOpWithConstantIntoOperand(Add)) + return NV; + + Value *X; + // zext(bool) + C -> bool ? C + 1 : C + if (match(Op0, m_ZExt(m_Value(X))) && + X->getType()->getScalarSizeInBits() == 1) + return SelectInst::Create(X, AddOne(Op1C), Op1); + + // ~X + C --> (C-1) - X + if (match(Op0, m_Not(m_Value(X)))) + return BinaryOperator::CreateSub(SubOne(Op1C), X); + const APInt *C; if (!match(Op1, m_APInt(C))) return nullptr; @@ -968,21 +992,17 @@ static Instruction *foldAddWithConstant(BinaryOperator &Add, return BinaryOperator::CreateXor(Op0, Op1); } - Value *X; - const APInt *C2; - Type *Ty = Add.getType(); - // Is this add the last step in a convoluted sext? // add(zext(xor i16 X, -32768), -32768) --> sext X + Type *Ty = Add.getType(); + const APInt *C2; if (match(Op0, m_ZExt(m_Xor(m_Value(X), m_APInt(C2)))) && C2->isMinSignedValue() && C2->sext(Ty->getScalarSizeInBits()) == *C) return CastInst::Create(Instruction::SExt, X, Ty); // (add (zext (add nuw X, C2)), C) --> (zext (add nuw X, C2 + C)) - // FIXME: This should check hasOneUse to not increase the instruction count? 
- if (C->isNegative() && - match(Op0, m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C2)))) && - C->sge(-C2->sext(C->getBitWidth()))) { + if (match(Op0, m_OneUse(m_ZExt(m_NUWAdd(m_Value(X), m_APInt(C2))))) && + C->isNegative() && C->sge(-C2->sext(C->getBitWidth()))) { Constant *NewC = ConstantInt::get(X->getType(), *C2 + C->trunc(C2->getBitWidth())); return new ZExtInst(Builder.CreateNUWAdd(X, NewC), Ty); @@ -1013,34 +1033,29 @@ static Instruction *foldAddWithConstant(BinaryOperator &Add, Instruction *InstCombiner::visitAdd(BinaryOperator &I) { bool Changed = SimplifyAssociativeOrCommutative(I); - Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); - if (Value *V = SimplifyVectorOp(I)) return replaceInstUsesWith(I, V); + Value *LHS = I.getOperand(0), *RHS = I.getOperand(1); if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(), SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - // (A*B)+(A*C) -> A*(B+C) etc + // (A*B)+(A*C) -> A*(B+C) etc if (Value *V = SimplifyUsingDistributiveLaws(I)) return replaceInstUsesWith(I, V); - if (Instruction *X = foldAddWithConstant(I, Builder)) + if (Instruction *X = foldAddWithConstant(I)) return X; // FIXME: This should be moved into the above helper function to allow these - // transforms for splat vectors. + // transforms for general constant or constant splat vectors. + Type *Ty = I.getType(); if (ConstantInt *CI = dyn_cast<ConstantInt>(RHS)) { - // zext(bool) + C -> bool ? C + 1 : C - if (ZExtInst *ZI = dyn_cast<ZExtInst>(LHS)) - if (ZI->getSrcTy()->isIntegerTy(1)) - return SelectInst::Create(ZI->getOperand(0), AddOne(CI), CI); - Value *XorLHS = nullptr; ConstantInt *XorRHS = nullptr; if (match(LHS, m_Xor(m_Value(XorLHS), m_ConstantInt(XorRHS)))) { - uint32_t TySizeBits = I.getType()->getScalarSizeInBits(); + unsigned TySizeBits = Ty->getScalarSizeInBits(); const APInt &RHSVal = CI->getValue(); unsigned ExtendAmt = 0; // If we have ADD(XOR(AND(X, 0xFF), 0x80), 0xF..F80), it's a sext. 
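// Illustrative sketch, not part of this change: the context above and the
// hunk below recognize the sign-extension idiom
// add(xor(and(X, 0xFF), 0x80), 0xFFFFFF80) and rewrite it as a shl/ashr pair.
// Both helpers sign-extend the low 8 bits of a 32-bit value (assuming the
// usual two's-complement int32_t):
#include <cstdint>

static int32_t sextLowByteViaXorAdd(uint32_t X) {
  return static_cast<int32_t>(((X & 0xFFu) ^ 0x80u) + 0xFFFFFF80u);
}

static int32_t sextLowByteViaShifts(uint32_t X) {
  return static_cast<int32_t>(X << 24) >> 24; // what the transform emits
}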
@@ -1059,7 +1074,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } if (ExtendAmt) { - Constant *ShAmt = ConstantInt::get(I.getType(), ExtendAmt); + Constant *ShAmt = ConstantInt::get(Ty, ExtendAmt); Value *NewShl = Builder.CreateShl(XorLHS, ShAmt, "sext"); return BinaryOperator::CreateAShr(NewShl, ShAmt); } @@ -1080,38 +1095,30 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { } } - if (isa<Constant>(RHS)) - if (Instruction *NV = foldOpWithConstantIntoOperand(I)) - return NV; - - if (I.getType()->isIntOrIntVectorTy(1)) + if (Ty->isIntOrIntVectorTy(1)) return BinaryOperator::CreateXor(LHS, RHS); // X + X --> X << 1 if (LHS == RHS) { - BinaryOperator *New = - BinaryOperator::CreateShl(LHS, ConstantInt::get(I.getType(), 1)); - New->setHasNoSignedWrap(I.hasNoSignedWrap()); - New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); - return New; + auto *Shl = BinaryOperator::CreateShl(LHS, ConstantInt::get(Ty, 1)); + Shl->setHasNoSignedWrap(I.hasNoSignedWrap()); + Shl->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); + return Shl; } - // -A + B --> B - A - // -A + -B --> -(A + B) - if (Value *LHSV = dyn_castNegVal(LHS)) { - if (!isa<Constant>(RHS)) - if (Value *RHSV = dyn_castNegVal(RHS)) { - Value *NewAdd = Builder.CreateAdd(LHSV, RHSV, "sum"); - return BinaryOperator::CreateNeg(NewAdd); - } + Value *A, *B; + if (match(LHS, m_Neg(m_Value(A)))) { + // -A + -B --> -(A + B) + if (match(RHS, m_Neg(m_Value(B)))) + return BinaryOperator::CreateNeg(Builder.CreateAdd(A, B)); - return BinaryOperator::CreateSub(RHS, LHSV); + // -A + B --> B - A + return BinaryOperator::CreateSub(RHS, A); } // A + -B --> A - B - if (!isa<Constant>(RHS)) - if (Value *V = dyn_castNegVal(RHS)) - return BinaryOperator::CreateSub(LHS, V); + if (match(RHS, m_Neg(m_Value(B)))) + return BinaryOperator::CreateSub(LHS, B); if (Value *V = checkForNegativeOperand(I, Builder)) return replaceInstUsesWith(I, V); @@ -1120,12 +1127,6 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (haveNoCommonBitsSet(LHS, RHS, DL, &AC, &I, &DT)) return BinaryOperator::CreateOr(LHS, RHS); - if (Constant *CRHS = dyn_cast<Constant>(RHS)) { - Value *X; - if (match(LHS, m_Not(m_Value(X)))) // ~X + C --> (C-1) - X - return BinaryOperator::CreateSub(SubOne(CRHS), X); - } - // FIXME: We already did a check for ConstantInt RHS above this. // FIXME: Is this pattern covered by another fold? No regression tests fail on // removal. @@ -1187,12 +1188,12 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (LHSConv->hasOneUse()) { Constant *CI = ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); - if (ConstantExpr::getSExt(CI, I.getType()) == RHSC && + if (ConstantExpr::getSExt(CI, Ty) == RHSC && willNotOverflowSignedAdd(LHSConv->getOperand(0), CI, I)) { // Insert the new, smaller add. Value *NewAdd = Builder.CreateNSWAdd(LHSConv->getOperand(0), CI, "addconv"); - return new SExtInst(NewAdd, I.getType()); + return new SExtInst(NewAdd, Ty); } } } @@ -1210,7 +1211,7 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // Insert the new integer add. 
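// Illustrative sketch, not part of this change: the (sext X) + (sext Y)
// narrowing just below relies on the identity that, when the narrow add
// cannot overflow (the willNotOverflowSignedAdd precondition), widening
// before or after the add gives the same result, e.g. for i8 -> i32:
#include <cstdint>

static bool narrowAddMatchesWideAdd(int8_t X, int8_t Y) {
  // Holds whenever X + Y stays within int8_t range.
  int32_t WideAdd = static_cast<int32_t>(X) + static_cast<int32_t>(Y);
  int32_t NarrowThenWiden =
      static_cast<int32_t>(static_cast<int8_t>(X + Y));
  return WideAdd == NarrowThenWiden;
}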
Value *NewAdd = Builder.CreateNSWAdd(LHSConv->getOperand(0), RHSConv->getOperand(0), "addconv"); - return new SExtInst(NewAdd, I.getType()); + return new SExtInst(NewAdd, Ty); } } } @@ -1223,12 +1224,12 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { if (LHSConv->hasOneUse()) { Constant *CI = ConstantExpr::getTrunc(RHSC, LHSConv->getOperand(0)->getType()); - if (ConstantExpr::getZExt(CI, I.getType()) == RHSC && + if (ConstantExpr::getZExt(CI, Ty) == RHSC && willNotOverflowUnsignedAdd(LHSConv->getOperand(0), CI, I)) { // Insert the new, smaller add. Value *NewAdd = Builder.CreateNUWAdd(LHSConv->getOperand(0), CI, "addconv"); - return new ZExtInst(NewAdd, I.getType()); + return new ZExtInst(NewAdd, Ty); } } } @@ -1246,41 +1247,35 @@ Instruction *InstCombiner::visitAdd(BinaryOperator &I) { // Insert the new integer add. Value *NewAdd = Builder.CreateNUWAdd( LHSConv->getOperand(0), RHSConv->getOperand(0), "addconv"); - return new ZExtInst(NewAdd, I.getType()); + return new ZExtInst(NewAdd, Ty); } } } // (add (xor A, B) (and A, B)) --> (or A, B) - { - Value *A = nullptr, *B = nullptr; - if (match(RHS, m_Xor(m_Value(A), m_Value(B))) && - match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) - return BinaryOperator::CreateOr(A, B); - - if (match(LHS, m_Xor(m_Value(A), m_Value(B))) && - match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) - return BinaryOperator::CreateOr(A, B); - } + if (match(LHS, m_Xor(m_Value(A), m_Value(B))) && + match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateOr(A, B); + + // (add (and A, B) (xor A, B)) --> (or A, B) + if (match(RHS, m_Xor(m_Value(A), m_Value(B))) && + match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) + return BinaryOperator::CreateOr(A, B); // (add (or A, B) (and A, B)) --> (add A, B) - { - Value *A = nullptr, *B = nullptr; - if (match(RHS, m_Or(m_Value(A), m_Value(B))) && - match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) { - auto *New = BinaryOperator::CreateAdd(A, B); - New->setHasNoSignedWrap(I.hasNoSignedWrap()); - New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); - return New; - } + if (match(LHS, m_Or(m_Value(A), m_Value(B))) && + match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) { + I.setOperand(0, A); + I.setOperand(1, B); + return &I; + } - if (match(LHS, m_Or(m_Value(A), m_Value(B))) && - match(RHS, m_c_And(m_Specific(A), m_Specific(B)))) { - auto *New = BinaryOperator::CreateAdd(A, B); - New->setHasNoSignedWrap(I.hasNoSignedWrap()); - New->setHasNoUnsignedWrap(I.hasNoUnsignedWrap()); - return New; - } + // (add (and A, B) (or A, B)) --> (add A, B) + if (match(RHS, m_Or(m_Value(A), m_Value(B))) && + match(LHS, m_c_And(m_Specific(A), m_Specific(B)))) { + I.setOperand(0, A); + I.setOperand(1, B); + return &I; } // TODO(jingyue): Consider willNotOverflowSignedAdd and @@ -1387,32 +1382,11 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { } } - // select C, 0, B + select C, A, 0 -> select C, A, B - { - Value *A1, *B1, *C1, *A2, *B2, *C2; - if (match(LHS, m_Select(m_Value(C1), m_Value(A1), m_Value(B1))) && - match(RHS, m_Select(m_Value(C2), m_Value(A2), m_Value(B2)))) { - if (C1 == C2) { - Constant *Z1=nullptr, *Z2=nullptr; - Value *A, *B, *C=C1; - if (match(A1, m_AnyZero()) && match(B2, m_AnyZero())) { - Z1 = dyn_cast<Constant>(A1); A = A2; - Z2 = dyn_cast<Constant>(B2); B = B1; - } else if (match(B1, m_AnyZero()) && match(A2, m_AnyZero())) { - Z1 = dyn_cast<Constant>(B1); B = B2; - Z2 = dyn_cast<Constant>(A2); A = A1; - } - - if (Z1 && Z2 && - (I.hasNoSignedZeros() || - 
(Z1->isNegativeZeroValue() && Z2->isNegativeZeroValue()))) { - return SelectInst::Create(C, A, B); - } - } - } - } + // Handle specials cases for FAdd with selects feeding the operation + if (Value *V = SimplifySelectsFeedingBinaryOp(I, LHS, RHS)) + return replaceInstUsesWith(I, V); - if (I.hasUnsafeAlgebra()) { + if (I.isFast()) { if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); } @@ -1423,7 +1397,6 @@ Instruction *InstCombiner::visitFAdd(BinaryOperator &I) { /// Optimize pointer differences into the same array into a size. Consider: /// &A[10] - &A[0]: we should compile this to "10". LHS/RHS are the pointer /// operands to the ptrtoint instructions for the LHS/RHS of the subtract. -/// Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, Type *Ty) { // If LHS is a gep based on RHS or RHS is a gep based on LHS, we can optimize @@ -1465,12 +1438,31 @@ Value *InstCombiner::OptimizePointerDifference(Value *LHS, Value *RHS, } } - // Avoid duplicating the arithmetic if GEP2 has non-constant indices and - // multiple users. - if (!GEP1 || - (GEP2 && !GEP2->hasAllConstantIndices() && !GEP2->hasOneUse())) + if (!GEP1) + // No GEP found. return nullptr; + if (GEP2) { + // (gep X, ...) - (gep X, ...) + // + // Avoid duplicating the arithmetic if there are more than one non-constant + // indices between the two GEPs and either GEP has a non-constant index and + // multiple users. If zero non-constant index, the result is a constant and + // there is no duplication. If one non-constant index, the result is an add + // or sub with a constant, which is no larger than the original code, and + // there's no duplicated arithmetic, even if either GEP has multiple + // users. If more than one non-constant indices combined, as long as the GEP + // with at least one non-constant index doesn't have multiple users, there + // is no duplication. + unsigned NumNonConstantIndices1 = GEP1->countNonConstantIndices(); + unsigned NumNonConstantIndices2 = GEP2->countNonConstantIndices(); + if (NumNonConstantIndices1 + NumNonConstantIndices2 > 1 && + ((NumNonConstantIndices1 > 0 && !GEP1->hasOneUse()) || + (NumNonConstantIndices2 > 0 && !GEP2->hasOneUse()))) { + return nullptr; + } + } + // Emit the offset of the GEP and an intptr_t. Value *Result = EmitGEPOffset(GEP1); @@ -1528,8 +1520,13 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return BinaryOperator::CreateNot(Op1); if (Constant *C = dyn_cast<Constant>(Op0)) { + Value *X; + // C - zext(bool) -> bool ? 
C - 1 : C + if (match(Op1, m_ZExt(m_Value(X))) && + X->getType()->getScalarSizeInBits() == 1) + return SelectInst::Create(X, SubOne(C), C); + // C - ~X == X + (1+C) - Value *X = nullptr; if (match(Op1, m_Not(m_Value(X)))) return BinaryOperator::CreateAdd(X, AddOne(C)); @@ -1600,7 +1597,7 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { return BinaryOperator::CreateNeg(Y); } - // (sub (or A, B) (xor A, B)) --> (and A, B) + // (sub (or A, B), (xor A, B)) --> (and A, B) { Value *A, *B; if (match(Op1, m_Xor(m_Value(A), m_Value(B))) && @@ -1626,7 +1623,6 @@ Instruction *InstCombiner::visitSub(BinaryOperator &I) { Builder.CreateSub(Z, Y, Op1->getName())); // (X - (X & Y)) --> (X & ~Y) - // if (match(Op1, m_c_And(m_Value(Y), m_Specific(Op0)))) return BinaryOperator::CreateAnd(Op0, Builder.CreateNot(Y, Y->getName() + ".not")); @@ -1741,7 +1737,11 @@ Instruction *InstCombiner::visitFSub(BinaryOperator &I) { } } - if (I.hasUnsafeAlgebra()) { + // Handle specials cases for FSub with selects feeding the operation + if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) + return replaceInstUsesWith(I, V); + + if (I.isFast()) { if (Value *V = FAddCombine(Builder).simplify(&I)) return replaceInstUsesWith(I, V); } diff --git a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp index fdc9c373b95e..2364202e5b69 100644 --- a/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp +++ b/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp @@ -12,11 +12,11 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Transforms/Utils/CmpInstAnalysis.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; using namespace PatternMatch; @@ -120,35 +120,9 @@ Instruction *InstCombiner::OptAndOp(BinaryOperator *Op, ConstantInt *AndRHS, BinaryOperator &TheAnd) { Value *X = Op->getOperand(0); - Constant *Together = nullptr; - if (!Op->isShift()) - Together = ConstantExpr::getAnd(AndRHS, OpRHS); switch (Op->getOpcode()) { default: break; - case Instruction::Xor: - if (Op->hasOneUse()) { - // (X ^ C1) & C2 --> (X & C2) ^ (C1&C2) - Value *And = Builder.CreateAnd(X, AndRHS); - And->takeName(Op); - return BinaryOperator::CreateXor(And, Together); - } - break; - case Instruction::Or: - if (Op->hasOneUse()){ - ConstantInt *TogetherCI = dyn_cast<ConstantInt>(Together); - if (TogetherCI && !TogetherCI->isZero()){ - // (X | C1) & C2 --> (X & (C2^(C1&C2))) | C1 - // NOTE: This reduces the number of bits set in the & mask, which - // can expose opportunities for store narrowing. - Together = ConstantExpr::getXor(AndRHS, Together); - Value *And = Builder.CreateAnd(X, Together); - And->takeName(Op); - return BinaryOperator::CreateOr(And, OpRHS); - } - } - - break; case Instruction::Add: if (Op->hasOneUse()) { // Adding a one to a single bit bit-field should be turned into an XOR @@ -182,64 +156,6 @@ Instruction *InstCombiner::OptAndOp(BinaryOperator *Op, } } break; - - case Instruction::Shl: { - // We know that the AND will not produce any of the bits shifted in, so if - // the anded constant includes them, clear them now! 
- // - uint32_t BitWidth = AndRHS->getType()->getBitWidth(); - uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); - APInt ShlMask(APInt::getHighBitsSet(BitWidth, BitWidth-OpRHSVal)); - ConstantInt *CI = Builder.getInt(AndRHS->getValue() & ShlMask); - - if (CI->getValue() == ShlMask) - // Masking out bits that the shift already masks. - return replaceInstUsesWith(TheAnd, Op); // No need for the and. - - if (CI != AndRHS) { // Reducing bits set in and. - TheAnd.setOperand(1, CI); - return &TheAnd; - } - break; - } - case Instruction::LShr: { - // We know that the AND will not produce any of the bits shifted in, so if - // the anded constant includes them, clear them now! This only applies to - // unsigned shifts, because a signed shr may bring in set bits! - // - uint32_t BitWidth = AndRHS->getType()->getBitWidth(); - uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); - APInt ShrMask(APInt::getLowBitsSet(BitWidth, BitWidth - OpRHSVal)); - ConstantInt *CI = Builder.getInt(AndRHS->getValue() & ShrMask); - - if (CI->getValue() == ShrMask) - // Masking out bits that the shift already masks. - return replaceInstUsesWith(TheAnd, Op); - - if (CI != AndRHS) { - TheAnd.setOperand(1, CI); // Reduce bits set in and cst. - return &TheAnd; - } - break; - } - case Instruction::AShr: - // Signed shr. - // See if this is shifting in some sign extension, then masking it out - // with an and. - if (Op->hasOneUse()) { - uint32_t BitWidth = AndRHS->getType()->getBitWidth(); - uint32_t OpRHSVal = OpRHS->getLimitedValue(BitWidth); - APInt ShrMask(APInt::getLowBitsSet(BitWidth, BitWidth - OpRHSVal)); - Constant *C = Builder.getInt(AndRHS->getValue() & ShrMask); - if (C == AndRHS) { // Masking out bits shifted in. - // (Val ashr C1) & C2 -> (Val lshr C1) & C2 - // Make the argument unsigned. - Value *ShVal = Op->getOperand(0); - ShVal = Builder.CreateLShr(ShVal, OpRHS, Op->getName()); - return BinaryOperator::CreateAnd(ShVal, AndRHS, TheAnd.getName()); - } - } - break; } return nullptr; } @@ -376,6 +292,18 @@ static unsigned conjugateICmpMask(unsigned Mask) { return NewMask; } +// Adapts the external decomposeBitTestICmp for local use. +static bool decomposeBitTestICmp(Value *LHS, Value *RHS, CmpInst::Predicate &Pred, + Value *&X, Value *&Y, Value *&Z) { + APInt Mask; + if (!llvm::decomposeBitTestICmp(LHS, RHS, Pred, X, Mask)) + return false; + + Y = ConstantInt::get(X->getType(), Mask); + Z = ConstantInt::get(X->getType(), 0); + return true; +} + /// Handle (icmp(A & B) ==/!= C) &/| (icmp(A & D) ==/!= E). /// Return the set of pattern classes (from MaskedICmpType) that both LHS and /// RHS satisfy. @@ -384,10 +312,9 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, ICmpInst *RHS, ICmpInst::Predicate &PredL, ICmpInst::Predicate &PredR) { - if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) - return 0; - // vectors are not (yet?) supported - if (LHS->getOperand(0)->getType()->isVectorTy()) + // vectors are not (yet?) supported. Don't support pointers either. + if (!LHS->getOperand(0)->getType()->isIntegerTy() || + !RHS->getOperand(0)->getType()->isIntegerTy()) return 0; // Here comes the tricky part: @@ -400,24 +327,18 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *L2 = LHS->getOperand(1); Value *L11, *L12, *L21, *L22; // Check whether the icmp can be decomposed into a bit test. 
- if (decomposeBitTestICmp(LHS, PredL, L11, L12, L2)) { + if (decomposeBitTestICmp(L1, L2, PredL, L11, L12, L2)) { L21 = L22 = L1 = nullptr; } else { // Look for ANDs in the LHS icmp. - if (!L1->getType()->isIntegerTy()) { - // You can icmp pointers, for example. They really aren't masks. - L11 = L12 = nullptr; - } else if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) { + if (!match(L1, m_And(m_Value(L11), m_Value(L12)))) { // Any icmp can be viewed as being trivially masked; if it allows us to // remove one, it's worth it. L11 = L1; L12 = Constant::getAllOnesValue(L1->getType()); } - if (!L2->getType()->isIntegerTy()) { - // You can icmp pointers, for example. They really aren't masks. - L21 = L22 = nullptr; - } else if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) { + if (!match(L2, m_And(m_Value(L21), m_Value(L22)))) { L21 = L2; L22 = Constant::getAllOnesValue(L2->getType()); } @@ -431,7 +352,7 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, Value *R2 = RHS->getOperand(1); Value *R11, *R12; bool Ok = false; - if (decomposeBitTestICmp(RHS, PredR, R11, R12, R2)) { + if (decomposeBitTestICmp(R1, R2, PredR, R11, R12, R2)) { if (R11 == L11 || R11 == L12 || R11 == L21 || R11 == L22) { A = R11; D = R12; @@ -444,7 +365,7 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, E = R2; R1 = nullptr; Ok = true; - } else if (R1->getType()->isIntegerTy()) { + } else { if (!match(R1, m_And(m_Value(R11), m_Value(R12)))) { // As before, model no mask as a trivial mask if it'll let us do an // optimization. @@ -470,7 +391,7 @@ static unsigned getMaskedTypeForICmpPair(Value *&A, Value *&B, Value *&C, return 0; // Look for ANDs on the right side of the RHS icmp. - if (!Ok && R2->getType()->isIntegerTy()) { + if (!Ok) { if (!match(R2, m_And(m_Value(R11), m_Value(R12)))) { R11 = R2; R12 = Constant::getAllOnesValue(R2->getType()); @@ -980,17 +901,15 @@ Value *InstCombiner::foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, return nullptr; } -/// Optimize (fcmp)&(fcmp). NOTE: Unlike the rest of instcombine, this returns -/// a Value which should already be inserted into the function. -Value *InstCombiner::foldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { - Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); - Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); - FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); +Value *InstCombiner::foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd) { + Value *LHS0 = LHS->getOperand(0), *LHS1 = LHS->getOperand(1); + Value *RHS0 = RHS->getOperand(0), *RHS1 = RHS->getOperand(1); + FCmpInst::Predicate PredL = LHS->getPredicate(), PredR = RHS->getPredicate(); - if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) { + if (LHS0 == RHS1 && RHS0 == LHS1) { // Swap RHS operands to match LHS. - Op1CC = FCmpInst::getSwappedPredicate(Op1CC); - std::swap(Op1LHS, Op1RHS); + PredR = FCmpInst::getSwappedPredicate(PredR); + std::swap(RHS0, RHS1); } // Simplify (fcmp cc0 x, y) & (fcmp cc1 x, y). 
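// Illustrative sketch, not part of this change: the decomposeBitTestICmp
// adapter above lets sign tests join the masked-icmp pairing by restating
// them as masked compares, e.g. for 32-bit X the predicate of
// icmp slt X, 0 is the same as (X & 0x80000000) != 0:
#include <cstdint>

static bool signTestAsMaskedCompare(int32_t X) {
  // Equivalent to X < 0, written as the bit test the pairing logic expects.
  return (static_cast<uint32_t>(X) & 0x80000000u) != 0;
}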
@@ -1002,31 +921,30 @@ Value *InstCombiner::foldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { // bool(R & CC0) && bool(R & CC1) // = bool((R & CC0) & (R & CC1)) // = bool(R & (CC0 & CC1)) <= by re-association, commutation, and idempotency - if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) - return getFCmpValue(getFCmpCode(Op0CC) & getFCmpCode(Op1CC), Op0LHS, Op0RHS, - Builder); + // + // Since (R & CC0) and (R & CC1) are either R or 0, we actually have this: + // bool(R & CC0) || bool(R & CC1) + // = bool((R & CC0) | (R & CC1)) + // = bool(R & (CC0 | CC1)) <= by reversed distribution (contribution? ;) + if (LHS0 == RHS0 && LHS1 == RHS1) { + unsigned FCmpCodeL = getFCmpCode(PredL); + unsigned FCmpCodeR = getFCmpCode(PredR); + unsigned NewPred = IsAnd ? FCmpCodeL & FCmpCodeR : FCmpCodeL | FCmpCodeR; + return getFCmpValue(NewPred, LHS0, LHS1, Builder); + } - if (LHS->getPredicate() == FCmpInst::FCMP_ORD && - RHS->getPredicate() == FCmpInst::FCMP_ORD) { - if (LHS->getOperand(0)->getType() != RHS->getOperand(0)->getType()) + if ((PredL == FCmpInst::FCMP_ORD && PredR == FCmpInst::FCMP_ORD && IsAnd) || + (PredL == FCmpInst::FCMP_UNO && PredR == FCmpInst::FCMP_UNO && !IsAnd)) { + if (LHS0->getType() != RHS0->getType()) return nullptr; - // (fcmp ord x, c) & (fcmp ord y, c) -> (fcmp ord x, y) - if (ConstantFP *LHSC = dyn_cast<ConstantFP>(LHS->getOperand(1))) - if (ConstantFP *RHSC = dyn_cast<ConstantFP>(RHS->getOperand(1))) { - // If either of the constants are nans, then the whole thing returns - // false. - if (LHSC->getValueAPF().isNaN() || RHSC->getValueAPF().isNaN()) - return Builder.getFalse(); - return Builder.CreateFCmpORD(LHS->getOperand(0), RHS->getOperand(0)); - } - - // Handle vector zeros. This occurs because the canonical form of - // "fcmp ord x,x" is "fcmp ord x, 0". - if (isa<ConstantAggregateZero>(LHS->getOperand(1)) && - isa<ConstantAggregateZero>(RHS->getOperand(1))) - return Builder.CreateFCmpORD(LHS->getOperand(0), RHS->getOperand(0)); - return nullptr; + // FCmp canonicalization ensures that (fcmp ord/uno X, X) and + // (fcmp ord/uno X, C) will be transformed to (fcmp X, 0.0). + if (match(LHS1, m_Zero()) && LHS1 == RHS1) + // Ignore the constants because they are obviously not NANs: + // (fcmp ord x, 0.0) & (fcmp ord y, 0.0) -> (fcmp ord x, y) + // (fcmp uno x, 0.0) | (fcmp uno y, 0.0) -> (fcmp uno x, y) + return Builder.CreateFCmp(PredL, LHS0, RHS0); } return nullptr; @@ -1069,30 +987,24 @@ bool InstCombiner::shouldOptimizeCast(CastInst *CI) { if (isEliminableCastPair(PrecedingCI, CI)) return false; - // If this is a vector sext from a compare, then we don't want to break the - // idiom where each element of the extended vector is either zero or all ones. - if (CI->getOpcode() == Instruction::SExt && - isa<CmpInst>(CastSrc) && CI->getDestTy()->isVectorTy()) - return false; - return true; } /// Fold {and,or,xor} (cast X), C. static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast, InstCombiner::BuilderTy &Builder) { - Constant *C; - if (!match(Logic.getOperand(1), m_Constant(C))) + Constant *C = dyn_cast<Constant>(Logic.getOperand(1)); + if (!C) return nullptr; auto LogicOpc = Logic.getOpcode(); Type *DestTy = Logic.getType(); Type *SrcTy = Cast->getSrcTy(); - // Move the logic operation ahead of a zext if the constant is unchanged in - // the smaller source type. Performing the logic in a smaller type may provide - // more information to later folds, and the smaller logic instruction may be - // cheaper (particularly in the case of vectors). 
+ // Move the logic operation ahead of a zext or sext if the constant is + // unchanged in the smaller source type. Performing the logic in a smaller + // type may provide more information to later folds, and the smaller logic + // instruction may be cheaper (particularly in the case of vectors). Value *X; if (match(Cast, m_OneUse(m_ZExt(m_Value(X))))) { Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy); @@ -1104,6 +1016,16 @@ static Instruction *foldLogicCastConstant(BinaryOperator &Logic, CastInst *Cast, } } + if (match(Cast, m_OneUse(m_SExt(m_Value(X))))) { + Constant *TruncC = ConstantExpr::getTrunc(C, SrcTy); + Constant *SextTruncC = ConstantExpr::getSExt(TruncC, DestTy); + if (SextTruncC == C) { + // LogicOpc (sext X), C --> sext (LogicOpc X, C) + Value *NewOp = Builder.CreateBinOp(LogicOpc, X, TruncC); + return new SExtInst(NewOp, DestTy); + } + } + return nullptr; } @@ -1167,38 +1089,9 @@ Instruction *InstCombiner::foldCastedBitwiseLogic(BinaryOperator &I) { // cast is otherwise not optimizable. This happens for vector sexts. FCmpInst *FCmp0 = dyn_cast<FCmpInst>(Cast0Src); FCmpInst *FCmp1 = dyn_cast<FCmpInst>(Cast1Src); - if (FCmp0 && FCmp1) { - Value *Res = LogicOpc == Instruction::And ? foldAndOfFCmps(FCmp0, FCmp1) - : foldOrOfFCmps(FCmp0, FCmp1); - if (Res) - return CastInst::Create(CastOpcode, Res, DestTy); - return nullptr; - } - - return nullptr; -} - -static Instruction *foldBoolSextMaskToSelect(BinaryOperator &I) { - Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - // Canonicalize SExt or Not to the LHS - if (match(Op1, m_SExt(m_Value())) || match(Op1, m_Not(m_Value()))) { - std::swap(Op0, Op1); - } - - // Fold (and (sext bool to A), B) --> (select bool, B, 0) - Value *X = nullptr; - if (match(Op0, m_SExt(m_Value(X))) && X->getType()->isIntOrIntVectorTy(1)) { - Value *Zero = Constant::getNullValue(Op1->getType()); - return SelectInst::Create(X, Op1, Zero); - } - - // Fold (and ~(sext bool to A), B) --> (select bool, 0, B) - if (match(Op0, m_Not(m_SExt(m_Value(X)))) && - X->getType()->isIntOrIntVectorTy(1)) { - Value *Zero = Constant::getNullValue(Op0->getType()); - return SelectInst::Create(X, Zero, Op1); - } + if (FCmp0 && FCmp1) + if (Value *R = foldLogicOfFCmps(FCmp0, FCmp1, LogicOpc == Instruction::And)) + return CastInst::Create(CastOpcode, R, DestTy); return nullptr; } @@ -1284,14 +1177,61 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (Value *V = SimplifyBSwap(I, Builder)) return replaceInstUsesWith(I, V); - if (match(Op1, m_One())) { - // (1 << x) & 1 --> zext(x == 0) - // (1 >> x) & 1 --> zext(x == 0) - Value *X; - if (match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X))))) { + const APInt *C; + if (match(Op1, m_APInt(C))) { + Value *X, *Y; + if (match(Op0, m_OneUse(m_LogicalShift(m_One(), m_Value(X)))) && + C->isOneValue()) { + // (1 << X) & 1 --> zext(X == 0) + // (1 >> X) & 1 --> zext(X == 0) Value *IsZero = Builder.CreateICmpEQ(X, ConstantInt::get(I.getType(), 0)); return new ZExtInst(IsZero, I.getType()); } + + const APInt *XorC; + if (match(Op0, m_OneUse(m_Xor(m_Value(X), m_APInt(XorC))))) { + // (X ^ C1) & C2 --> (X & C2) ^ (C1&C2) + Constant *NewC = ConstantInt::get(I.getType(), *C & *XorC); + Value *And = Builder.CreateAnd(X, Op1); + And->takeName(Op0); + return BinaryOperator::CreateXor(And, NewC); + } + + const APInt *OrC; + if (match(Op0, m_OneUse(m_Or(m_Value(X), m_APInt(OrC))))) { + // (X | C1) & C2 --> (X & C2^(C1&C2)) | (C1&C2) + // NOTE: This reduces the number of bits set in the & mask, which + // can expose 
opportunities for store narrowing for scalars. + // NOTE: SimplifyDemandedBits should have already removed bits from C1 + // that aren't set in C2. Meaning we can replace (C1&C2) with C1 in + // above, but this feels safer. + APInt Together = *C & *OrC; + Value *And = Builder.CreateAnd(X, ConstantInt::get(I.getType(), + Together ^ *C)); + And->takeName(Op0); + return BinaryOperator::CreateOr(And, ConstantInt::get(I.getType(), + Together)); + } + + // If the mask is only needed on one incoming arm, push the 'and' op up. + if (match(Op0, m_OneUse(m_Xor(m_Value(X), m_Value(Y)))) || + match(Op0, m_OneUse(m_Or(m_Value(X), m_Value(Y))))) { + APInt NotAndMask(~(*C)); + BinaryOperator::BinaryOps BinOp = cast<BinaryOperator>(Op0)->getOpcode(); + if (MaskedValueIsZero(X, NotAndMask, 0, &I)) { + // Not masking anything out for the LHS, move mask to RHS. + // and ({x}or X, Y), C --> {x}or X, (and Y, C) + Value *NewRHS = Builder.CreateAnd(Y, Op1, Y->getName() + ".masked"); + return BinaryOperator::Create(BinOp, X, NewRHS); + } + if (!isa<Constant>(Y) && MaskedValueIsZero(Y, NotAndMask, 0, &I)) { + // Not masking anything out for the RHS, move mask to LHS. + // and ({x}or X, Y), C --> {x}or (and X, C), Y + Value *NewLHS = Builder.CreateAnd(X, Op1, X->getName() + ".masked"); + return BinaryOperator::Create(BinOp, NewLHS, Y); + } + } + } if (ConstantInt *AndRHS = dyn_cast<ConstantInt>(Op1)) { @@ -1299,34 +1239,6 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { // Optimize a variety of ((val OP C1) & C2) combinations... if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { - Value *Op0LHS = Op0I->getOperand(0); - Value *Op0RHS = Op0I->getOperand(1); - switch (Op0I->getOpcode()) { - default: break; - case Instruction::Xor: - case Instruction::Or: { - // If the mask is only needed on one incoming arm, push it up. - if (!Op0I->hasOneUse()) break; - - APInt NotAndRHS(~AndRHSMask); - if (MaskedValueIsZero(Op0LHS, NotAndRHS, 0, &I)) { - // Not masking anything out for the LHS, move to RHS. - Value *NewRHS = Builder.CreateAnd(Op0RHS, AndRHS, - Op0RHS->getName()+".masked"); - return BinaryOperator::Create(Op0I->getOpcode(), Op0LHS, NewRHS); - } - if (!isa<Constant>(Op0RHS) && - MaskedValueIsZero(Op0RHS, NotAndRHS, 0, &I)) { - // Not masking anything out for the RHS, move to LHS. - Value *NewLHS = Builder.CreateAnd(Op0LHS, AndRHS, - Op0LHS->getName()+".masked"); - return BinaryOperator::Create(Op0I->getOpcode(), NewLHS, Op0RHS); - } - - break; - } - } - // ((C1 OP zext(X)) & C2) -> zext((C1-X) & C2) if C2 fits in the bitwidth // of X and OP behaves well when given trunc(C1) and X. switch (Op0I->getOpcode()) { @@ -1343,6 +1255,7 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { if (AndRHSMask.isIntN(X->getType()->getScalarSizeInBits())) { auto *TruncC1 = ConstantExpr::getTrunc(C1, X->getType()); Value *BinOp; + Value *Op0LHS = Op0I->getOperand(0); if (isa<ZExtInst>(Op0LHS)) BinOp = Builder.CreateBinOp(Op0I->getOpcode(), X, TruncC1); else @@ -1467,17 +1380,22 @@ Instruction *InstCombiner::visitAnd(BinaryOperator &I) { } } - // If and'ing two fcmp, try combine them into one. 
if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) - if (Value *Res = foldAndOfFCmps(LHS, RHS)) + if (Value *Res = foldLogicOfFCmps(LHS, RHS, true)) return replaceInstUsesWith(I, Res); if (Instruction *CastedAnd = foldCastedBitwiseLogic(I)) return CastedAnd; - if (Instruction *Select = foldBoolSextMaskToSelect(I)) - return Select; + // and(sext(A), B) / and(B, sext(A)) --> A ? B : 0, where A is i1 or <N x i1>. + Value *A; + if (match(Op0, m_OneUse(m_SExt(m_Value(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(A, Op1, Constant::getNullValue(I.getType())); + if (match(Op1, m_OneUse(m_SExt(m_Value(A)))) && + A->getType()->isIntOrIntVectorTy(1)) + return SelectInst::Create(A, Op0, Constant::getNullValue(I.getType())); return Changed ? &I : nullptr; } @@ -1567,8 +1485,9 @@ static Value *getSelectCondition(Value *A, Value *B, // If both operands are constants, see if the constants are inverse bitmasks. Constant *AC, *BC; if (match(A, m_Constant(AC)) && match(B, m_Constant(BC)) && - areInverseVectorBitmasks(AC, BC)) - return ConstantExpr::getTrunc(AC, CmpInst::makeCmpResultType(Ty)); + areInverseVectorBitmasks(AC, BC)) { + return Builder.CreateZExtOrTrunc(AC, CmpInst::makeCmpResultType(Ty)); + } // If both operands are xor'd with constants using the same sexted boolean // operand, see if the constants are inverse bitmasks. @@ -1832,120 +1751,6 @@ Value *InstCombiner::foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, return nullptr; } -/// Optimize (fcmp)|(fcmp). NOTE: Unlike the rest of instcombine, this returns -/// a Value which should already be inserted into the function. -Value *InstCombiner::foldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS) { - Value *Op0LHS = LHS->getOperand(0), *Op0RHS = LHS->getOperand(1); - Value *Op1LHS = RHS->getOperand(0), *Op1RHS = RHS->getOperand(1); - FCmpInst::Predicate Op0CC = LHS->getPredicate(), Op1CC = RHS->getPredicate(); - - if (Op0LHS == Op1RHS && Op0RHS == Op1LHS) { - // Swap RHS operands to match LHS. - Op1CC = FCmpInst::getSwappedPredicate(Op1CC); - std::swap(Op1LHS, Op1RHS); - } - - // Simplify (fcmp cc0 x, y) | (fcmp cc1 x, y). - // This is a similar transformation to the one in FoldAndOfFCmps. - // - // Since (R & CC0) and (R & CC1) are either R or 0, we actually have this: - // bool(R & CC0) || bool(R & CC1) - // = bool((R & CC0) | (R & CC1)) - // = bool(R & (CC0 | CC1)) <= by reversed distribution (contribution? ;) - if (Op0LHS == Op1LHS && Op0RHS == Op1RHS) - return getFCmpValue(getFCmpCode(Op0CC) | getFCmpCode(Op1CC), Op0LHS, Op0RHS, - Builder); - - if (LHS->getPredicate() == FCmpInst::FCMP_UNO && - RHS->getPredicate() == FCmpInst::FCMP_UNO && - LHS->getOperand(0)->getType() == RHS->getOperand(0)->getType()) { - if (ConstantFP *LHSC = dyn_cast<ConstantFP>(LHS->getOperand(1))) - if (ConstantFP *RHSC = dyn_cast<ConstantFP>(RHS->getOperand(1))) { - // If either of the constants are nans, then the whole thing returns - // true. - if (LHSC->getValueAPF().isNaN() || RHSC->getValueAPF().isNaN()) - return Builder.getTrue(); - - // Otherwise, no need to compare the two constants, compare the - // rest. - return Builder.CreateFCmpUNO(LHS->getOperand(0), RHS->getOperand(0)); - } - - // Handle vector zeros. This occurs because the canonical form of - // "fcmp uno x,x" is "fcmp uno x, 0". 
- if (isa<ConstantAggregateZero>(LHS->getOperand(1)) && - isa<ConstantAggregateZero>(RHS->getOperand(1))) - return Builder.CreateFCmpUNO(LHS->getOperand(0), RHS->getOperand(0)); - - return nullptr; - } - - return nullptr; -} - -/// This helper function folds: -/// -/// ((A | B) & C1) | (B & C2) -/// -/// into: -/// -/// (A & C1) | B -/// -/// when the XOR of the two constants is "all ones" (-1). -static Instruction *FoldOrWithConstants(BinaryOperator &I, Value *Op, - Value *A, Value *B, Value *C, - InstCombiner::BuilderTy &Builder) { - ConstantInt *CI1 = dyn_cast<ConstantInt>(C); - if (!CI1) return nullptr; - - Value *V1 = nullptr; - ConstantInt *CI2 = nullptr; - if (!match(Op, m_And(m_Value(V1), m_ConstantInt(CI2)))) return nullptr; - - APInt Xor = CI1->getValue() ^ CI2->getValue(); - if (!Xor.isAllOnesValue()) return nullptr; - - if (V1 == A || V1 == B) { - Value *NewOp = Builder.CreateAnd((V1 == A) ? B : A, CI1); - return BinaryOperator::CreateOr(NewOp, V1); - } - - return nullptr; -} - -/// \brief This helper function folds: -/// -/// ((A ^ B) & C1) | (B & C2) -/// -/// into: -/// -/// (A & C1) ^ B -/// -/// when the XOR of the two constants is "all ones" (-1). -static Instruction *FoldXorWithConstants(BinaryOperator &I, Value *Op, - Value *A, Value *B, Value *C, - InstCombiner::BuilderTy &Builder) { - ConstantInt *CI1 = dyn_cast<ConstantInt>(C); - if (!CI1) - return nullptr; - - Value *V1 = nullptr; - ConstantInt *CI2 = nullptr; - if (!match(Op, m_And(m_Value(V1), m_ConstantInt(CI2)))) - return nullptr; - - APInt Xor = CI1->getValue() ^ CI2->getValue(); - if (!Xor.isAllOnesValue()) - return nullptr; - - if (V1 == A || V1 == B) { - Value *NewOp = Builder.CreateAnd(V1 == A ? B : A, CI1); - return BinaryOperator::CreateXor(NewOp, V1); - } - - return nullptr; -} - // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches // here. We should standardize that construct where it is needed or choose some // other way to ensure that commutated variants of patterns are not missed. 
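Illustrative aside (not part of the patch): the deleted FoldOrWithConstants/FoldXorWithConstants helpers applied when the two masks xor to all-ones (i.e. C1 == ~C2); the replacement logic in visitOr below performs the same folds inline with commutative matchers. The identities being generalized can be sanity-checked with plain integers; all names in this sketch are hypothetical:

  #include <cassert>
  #include <cstdint>

  // Checks the two folds, valid whenever C1 == ~C2:
  //   ((A | B) & C1) | (B & C2)  ==  (A & C1) | B
  //   ((A ^ B) & C1) | (B & C2)  ==  (A & C1) ^ B
  static void checkInverseMaskFolds(uint32_t A, uint32_t B, uint32_t C1) {
    const uint32_t C2 = ~C1;
    assert((((A | B) & C1) | (B & C2)) == ((A & C1) | B));
    assert((((A ^ B) & C1) | (B & C2)) == ((A & C1) ^ B));
  }

  int main() {
    checkInverseMaskFolds(0x12345678u, 0x9abcdef0u, 0x00000001u); // the 1 / -2 shape in the old comments
    checkInverseMaskFolds(0xdeadbeefu, 0x0badf00du, 0x00ff00ffu); // an arbitrary inverse mask pair
    return 0;
  }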
@@ -2011,10 +1816,10 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { Value *C = nullptr, *D = nullptr; if (match(Op0, m_And(m_Value(A), m_Value(C))) && match(Op1, m_And(m_Value(B), m_Value(D)))) { - Value *V1 = nullptr, *V2 = nullptr; ConstantInt *C1 = dyn_cast<ConstantInt>(C); ConstantInt *C2 = dyn_cast<ConstantInt>(D); if (C1 && C2) { // (A & C1)|(B & C2) + Value *V1 = nullptr, *V2 = nullptr; if ((C1->getValue() & C2->getValue()).isNullValue()) { // ((V | N) & C1) | (V & C2) --> (V|N) & (C1|C2) // iff (C1&C2) == 0 and (N&~C1) == 0 @@ -2046,6 +1851,24 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { Builder.getInt(C1->getValue()|C2->getValue())); } } + + if (C1->getValue() == ~C2->getValue()) { + Value *X; + + // ((X|B)&C1)|(B&C2) -> (X&C1) | B iff C1 == ~C2 + if (match(A, m_c_Or(m_Value(X), m_Specific(B)))) + return BinaryOperator::CreateOr(Builder.CreateAnd(X, C1), B); + // (A&C2)|((X|A)&C1) -> (X&C2) | A iff C1 == ~C2 + if (match(B, m_c_Or(m_Specific(A), m_Value(X)))) + return BinaryOperator::CreateOr(Builder.CreateAnd(X, C2), A); + + // ((X^B)&C1)|(B&C2) -> (X&C1) ^ B iff C1 == ~C2 + if (match(A, m_c_Xor(m_Value(X), m_Specific(B)))) + return BinaryOperator::CreateXor(Builder.CreateAnd(X, C1), B); + // (A&C2)|((X^A)&C1) -> (X&C2) ^ A iff C1 == ~C2 + if (match(B, m_c_Xor(m_Specific(A), m_Value(X)))) + return BinaryOperator::CreateXor(Builder.CreateAnd(X, C2), A); + } } // Don't try to form a select if it's unlikely that we'll get rid of at @@ -2070,27 +1893,6 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { if (Value *V = matchSelectFromAndOr(D, B, C, A, Builder)) return replaceInstUsesWith(I, V); } - - // ((A|B)&1)|(B&-2) -> (A&1) | B - if (match(A, m_c_Or(m_Value(V1), m_Specific(B)))) { - if (Instruction *Ret = FoldOrWithConstants(I, Op1, V1, B, C, Builder)) - return Ret; - } - // (B&-2)|((A|B)&1) -> (A&1) | B - if (match(B, m_c_Or(m_Specific(A), m_Value(V1)))) { - if (Instruction *Ret = FoldOrWithConstants(I, Op0, A, V1, D, Builder)) - return Ret; - } - // ((A^B)&1)|(B&-2) -> (A&1) ^ B - if (match(A, m_c_Xor(m_Value(V1), m_Specific(B)))) { - if (Instruction *Ret = FoldXorWithConstants(I, Op1, V1, B, C, Builder)) - return Ret; - } - // (B&-2)|((A^B)&1) -> (A&1) ^ B - if (match(B, m_c_Xor(m_Specific(A), m_Value(V1)))) { - if (Instruction *Ret = FoldXorWithConstants(I, Op0, A, V1, D, Builder)) - return Ret; - } } // (A ^ B) | ((B ^ C) ^ A) -> (A ^ B) | C @@ -2182,10 +1984,9 @@ Instruction *InstCombiner::visitOr(BinaryOperator &I) { } } - // (fcmp uno x, c) | (fcmp uno y, c) -> (fcmp uno x, y) if (FCmpInst *LHS = dyn_cast<FCmpInst>(I.getOperand(0))) if (FCmpInst *RHS = dyn_cast<FCmpInst>(I.getOperand(1))) - if (Value *Res = foldOrOfFCmps(LHS, RHS)) + if (Value *Res = foldLogicOfFCmps(LHS, RHS, false)) return replaceInstUsesWith(I, Res); if (Instruction *CastedOr = foldCastedBitwiseLogic(I)) @@ -2434,60 +2235,51 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { return replaceInstUsesWith(I, Op0); } - if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) { - // fold (xor(zext(cmp)), 1) and (xor(sext(cmp)), -1) to ext(!cmp). 
- if (CastInst *Op0C = dyn_cast<CastInst>(Op0)) { - if (CmpInst *CI = dyn_cast<CmpInst>(Op0C->getOperand(0))) { - if (CI->hasOneUse() && Op0C->hasOneUse()) { - Instruction::CastOps Opcode = Op0C->getOpcode(); - if ((Opcode == Instruction::ZExt || Opcode == Instruction::SExt) && - (RHSC == ConstantExpr::getCast(Opcode, Builder.getTrue(), - Op0C->getDestTy()))) { - CI->setPredicate(CI->getInversePredicate()); - return CastInst::Create(Opcode, CI, Op0C->getType()); - } + { + const APInt *RHSC; + if (match(Op1, m_APInt(RHSC))) { + Value *X; + const APInt *C; + if (match(Op0, m_Sub(m_APInt(C), m_Value(X)))) { + // ~(c-X) == X-c-1 == X+(-c-1) + if (RHSC->isAllOnesValue()) { + Constant *NewC = ConstantInt::get(I.getType(), -(*C) - 1); + return BinaryOperator::CreateAdd(X, NewC); + } + if (RHSC->isSignMask()) { + // (C - X) ^ signmask -> (C + signmask - X) + Constant *NewC = ConstantInt::get(I.getType(), *C + *RHSC); + return BinaryOperator::CreateSub(NewC, X); + } + } else if (match(Op0, m_Add(m_Value(X), m_APInt(C)))) { + // ~(X-c) --> (-c-1)-X + if (RHSC->isAllOnesValue()) { + Constant *NewC = ConstantInt::get(I.getType(), -(*C) - 1); + return BinaryOperator::CreateSub(NewC, X); } + if (RHSC->isSignMask()) { + // (X + C) ^ signmask -> (X + C + signmask) + Constant *NewC = ConstantInt::get(I.getType(), *C + *RHSC); + return BinaryOperator::CreateAdd(X, NewC); + } + } + + // (X|C1)^C2 -> X^(C1^C2) iff X&~C1 == 0 + if (match(Op0, m_Or(m_Value(X), m_APInt(C))) && + MaskedValueIsZero(X, *C, 0, &I)) { + Constant *NewC = ConstantInt::get(I.getType(), *C ^ *RHSC); + Worklist.Add(cast<Instruction>(Op0)); + I.setOperand(0, X); + I.setOperand(1, NewC); + return &I; } } + } + if (ConstantInt *RHSC = dyn_cast<ConstantInt>(Op1)) { if (BinaryOperator *Op0I = dyn_cast<BinaryOperator>(Op0)) { - // ~(c-X) == X-c-1 == X+(-c-1) - if (Op0I->getOpcode() == Instruction::Sub && RHSC->isMinusOne()) - if (Constant *Op0I0C = dyn_cast<Constant>(Op0I->getOperand(0))) { - Constant *NegOp0I0C = ConstantExpr::getNeg(Op0I0C); - return BinaryOperator::CreateAdd(Op0I->getOperand(1), - SubOne(NegOp0I0C)); - } - if (ConstantInt *Op0CI = dyn_cast<ConstantInt>(Op0I->getOperand(1))) { - if (Op0I->getOpcode() == Instruction::Add) { - // ~(X-c) --> (-c-1)-X - if (RHSC->isMinusOne()) { - Constant *NegOp0CI = ConstantExpr::getNeg(Op0CI); - return BinaryOperator::CreateSub(SubOne(NegOp0CI), - Op0I->getOperand(0)); - } else if (RHSC->getValue().isSignMask()) { - // (X + C) ^ signmask -> (X + C + signmask) - Constant *C = Builder.getInt(RHSC->getValue() + Op0CI->getValue()); - return BinaryOperator::CreateAdd(Op0I->getOperand(0), C); - - } - } else if (Op0I->getOpcode() == Instruction::Or) { - // (X|C1)^C2 -> X^(C1|C2) iff X&~C1 == 0 - if (MaskedValueIsZero(Op0I->getOperand(0), Op0CI->getValue(), - 0, &I)) { - Constant *NewRHS = ConstantExpr::getOr(Op0CI, RHSC); - // Anything in both C1 and C2 is known to be zero, remove it from - // NewRHS. 
- Constant *CommonBits = ConstantExpr::getAnd(Op0CI, RHSC); - NewRHS = ConstantExpr::getAnd(NewRHS, - ConstantExpr::getNot(CommonBits)); - Worklist.Add(Op0I); - I.setOperand(0, Op0I->getOperand(0)); - I.setOperand(1, NewRHS); - return &I; - } - } else if (Op0I->getOpcode() == Instruction::LShr) { + if (Op0I->getOpcode() == Instruction::LShr) { // ((X^C1) >> C2) ^ C3 -> (X>>C2) ^ ((C1>>C2)^C3) // E1 = "X ^ C1" BinaryOperator *E1; @@ -2605,5 +2397,25 @@ Instruction *InstCombiner::visitXor(BinaryOperator &I) { if (Instruction *CastedXor = foldCastedBitwiseLogic(I)) return CastedXor; + // Canonicalize the shifty way to code absolute value to the common pattern. + // There are 4 potential commuted variants. Move the 'ashr' candidate to Op1. + // We're relying on the fact that we only do this transform when the shift has + // exactly 2 uses and the add has exactly 1 use (otherwise, we might increase + // instructions). + if (Op0->getNumUses() == 2) + std::swap(Op0, Op1); + + const APInt *ShAmt; + Type *Ty = I.getType(); + if (match(Op1, m_AShr(m_Value(A), m_APInt(ShAmt))) && + Op1->getNumUses() == 2 && *ShAmt == Ty->getScalarSizeInBits() - 1 && + match(Op0, m_OneUse(m_c_Add(m_Specific(A), m_Specific(Op1))))) { + // B = ashr i32 A, 31 ; smear the sign bit + // xor (add A, B), B ; add -1 and flip bits if negative + // --> (A < 0) ? -A : A + Value *Cmp = Builder.CreateICmpSLT(A, ConstantInt::getNullValue(Ty)); + return SelectInst::Create(Cmp, Builder.CreateNeg(A), A); + } + return Changed ? &I : nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineCalls.cpp b/lib/Transforms/InstCombine/InstCombineCalls.cpp index 391c430dab75..aa055121e710 100644 --- a/lib/Transforms/InstCombine/InstCombineCalls.cpp +++ b/lib/Transforms/InstCombine/InstCombineCalls.cpp @@ -16,16 +16,20 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Twine.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" @@ -40,18 +44,26 @@ #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Statepoint.h" #include "llvm/IR/Type.h" +#include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/Support/AtomicOrdering.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SimplifyLibCalls.h" #include <algorithm> #include <cassert> #include <cstdint> #include <cstring> +#include <utility> #include <vector> using namespace llvm; @@ -94,8 +106,8 @@ static Constant *getNegativeIsTrueBoolVec(ConstantDataVector *V) { return ConstantVector::get(BoolVec); } -Instruction *InstCombiner::SimplifyElementUnorderedAtomicMemCpy( - ElementUnorderedAtomicMemCpyInst *AMI) { +Instruction * 
+InstCombiner::SimplifyElementUnorderedAtomicMemCpy(AtomicMemCpyInst *AMI) { // Try to unfold this intrinsic into sequence of explicit atomic loads and // stores. // First check that number of elements is compile time constant. @@ -515,7 +527,7 @@ static Value *simplifyX86varShift(const IntrinsicInst &II, // If all elements out of range or UNDEF, return vector of zeros/undefs. // ArithmeticShift should only hit this if they are all UNDEF. auto OutOfRange = [&](int Idx) { return (Idx < 0) || (BitWidth <= Idx); }; - if (all_of(ShiftAmts, OutOfRange)) { + if (llvm::all_of(ShiftAmts, OutOfRange)) { SmallVector<Constant *, 8> ConstantVec; for (int Idx : ShiftAmts) { if (Idx < 0) { @@ -1094,72 +1106,6 @@ static Value *simplifyX86vpermv(const IntrinsicInst &II, return Builder.CreateShuffleVector(V1, V2, ShuffleMask); } -/// The shuffle mask for a perm2*128 selects any two halves of two 256-bit -/// source vectors, unless a zero bit is set. If a zero bit is set, -/// then ignore that half of the mask and clear that half of the vector. -static Value *simplifyX86vperm2(const IntrinsicInst &II, - InstCombiner::BuilderTy &Builder) { - auto *CInt = dyn_cast<ConstantInt>(II.getArgOperand(2)); - if (!CInt) - return nullptr; - - VectorType *VecTy = cast<VectorType>(II.getType()); - ConstantAggregateZero *ZeroVector = ConstantAggregateZero::get(VecTy); - - // The immediate permute control byte looks like this: - // [1:0] - select 128 bits from sources for low half of destination - // [2] - ignore - // [3] - zero low half of destination - // [5:4] - select 128 bits from sources for high half of destination - // [6] - ignore - // [7] - zero high half of destination - - uint8_t Imm = CInt->getZExtValue(); - - bool LowHalfZero = Imm & 0x08; - bool HighHalfZero = Imm & 0x80; - - // If both zero mask bits are set, this was just a weird way to - // generate a zero vector. - if (LowHalfZero && HighHalfZero) - return ZeroVector; - - // If 0 or 1 zero mask bits are set, this is a simple shuffle. - unsigned NumElts = VecTy->getNumElements(); - unsigned HalfSize = NumElts / 2; - SmallVector<uint32_t, 8> ShuffleMask(NumElts); - - // The high bit of the selection field chooses the 1st or 2nd operand. - bool LowInputSelect = Imm & 0x02; - bool HighInputSelect = Imm & 0x20; - - // The low bit of the selection field chooses the low or high half - // of the selected operand. - bool LowHalfSelect = Imm & 0x01; - bool HighHalfSelect = Imm & 0x10; - - // Determine which operand(s) are actually in use for this instruction. - Value *V0 = LowInputSelect ? II.getArgOperand(1) : II.getArgOperand(0); - Value *V1 = HighInputSelect ? II.getArgOperand(1) : II.getArgOperand(0); - - // If needed, replace operands based on zero mask. - V0 = LowHalfZero ? ZeroVector : V0; - V1 = HighHalfZero ? ZeroVector : V1; - - // Permute low half of result. - unsigned StartIndex = LowHalfSelect ? HalfSize : 0; - for (unsigned i = 0; i < HalfSize; ++i) - ShuffleMask[i] = StartIndex + i; - - // Permute high half of result. - StartIndex = HighHalfSelect ? HalfSize : 0; - StartIndex += NumElts; - for (unsigned i = 0; i < HalfSize; ++i) - ShuffleMask[i + HalfSize] = StartIndex + i; - - return Builder.CreateShuffleVector(V0, V1, ShuffleMask); -} - /// Decode XOP integer vector comparison intrinsics. static Value *simplifyX86vpcom(const IntrinsicInst &II, InstCombiner::BuilderTy &Builder, @@ -1650,7 +1596,6 @@ static Instruction *SimplifyNVVMIntrinsic(IntrinsicInst *II, InstCombiner &IC) { // IntrinsicInstr with target-generic LLVM IR. 
const SimplifyAction Action = [II]() -> SimplifyAction { switch (II->getIntrinsicID()) { - // NVVM intrinsics that map directly to LLVM intrinsics. case Intrinsic::nvvm_ceil_d: return {Intrinsic::ceil, FTZ_Any}; @@ -1932,7 +1877,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { if (Changed) return II; } - if (auto *AMI = dyn_cast<ElementUnorderedAtomicMemCpyInst>(II)) { + if (auto *AMI = dyn_cast<AtomicMemCpyInst>(II)) { if (Constant *C = dyn_cast<Constant>(AMI->getLength())) if (C->isNullValue()) return eraseInstFromFunction(*AMI); @@ -2072,7 +2017,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } case Intrinsic::fmuladd: { // Canonicalize fast fmuladd to the separate fmul + fadd. - if (II->hasUnsafeAlgebra()) { + if (II->isFast()) { BuilderTy::FastMathFlagGuard Guard(Builder); Builder.setFastMathFlags(II->getFastMathFlags()); Value *Mul = Builder.CreateFMul(II->getArgOperand(0), @@ -2248,6 +2193,52 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; + case Intrinsic::x86_bmi_bextr_32: + case Intrinsic::x86_bmi_bextr_64: + case Intrinsic::x86_tbm_bextri_u32: + case Intrinsic::x86_tbm_bextri_u64: + // If the RHS is a constant we can try some simplifications. + if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(1))) { + uint64_t Shift = C->getZExtValue(); + uint64_t Length = (Shift >> 8) & 0xff; + Shift &= 0xff; + unsigned BitWidth = II->getType()->getIntegerBitWidth(); + // If the length is 0 or the shift is out of range, replace with zero. + if (Length == 0 || Shift >= BitWidth) + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0)); + // If the LHS is also a constant, we can completely constant fold this. + if (auto *InC = dyn_cast<ConstantInt>(II->getArgOperand(0))) { + uint64_t Result = InC->getZExtValue() >> Shift; + if (Length > BitWidth) + Length = BitWidth; + Result &= maskTrailingOnes<uint64_t>(Length); + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result)); + } + // TODO should we turn this into 'and' if shift is 0? Or 'shl' if we + // are only masking bits that a shift already cleared? + } + break; + + case Intrinsic::x86_bmi_bzhi_32: + case Intrinsic::x86_bmi_bzhi_64: + // If the RHS is a constant we can try some simplifications. + if (auto *C = dyn_cast<ConstantInt>(II->getArgOperand(1))) { + uint64_t Index = C->getZExtValue() & 0xff; + unsigned BitWidth = II->getType()->getIntegerBitWidth(); + if (Index >= BitWidth) + return replaceInstUsesWith(CI, II->getArgOperand(0)); + if (Index == 0) + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), 0)); + // If the LHS is also a constant, we can completely constant fold this. + if (auto *InC = dyn_cast<ConstantInt>(II->getArgOperand(0))) { + uint64_t Result = InC->getZExtValue(); + Result &= maskTrailingOnes<uint64_t>(Index); + return replaceInstUsesWith(CI, ConstantInt::get(II->getType(), Result)); + } + // TODO should we convert this to an AND if the RHS is constant? 
+ } + break; + case Intrinsic::x86_vcvtph2ps_128: case Intrinsic::x86_vcvtph2ps_256: { auto Arg = II->getArgOperand(0); @@ -2333,11 +2324,10 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::x86_sse2_pmovmskb_128: case Intrinsic::x86_avx_movmsk_pd_256: case Intrinsic::x86_avx_movmsk_ps_256: - case Intrinsic::x86_avx2_pmovmskb: { + case Intrinsic::x86_avx2_pmovmskb: if (Value *V = simplifyX86movmsk(*II)) return replaceInstUsesWith(*II, V); break; - } case Intrinsic::x86_sse_comieq_ss: case Intrinsic::x86_sse_comige_ss: @@ -2972,14 +2962,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { } break; - case Intrinsic::x86_avx_vperm2f128_pd_256: - case Intrinsic::x86_avx_vperm2f128_ps_256: - case Intrinsic::x86_avx_vperm2f128_si_256: - case Intrinsic::x86_avx2_vperm2i128: - if (Value *V = simplifyX86vperm2(*II, Builder)) - return replaceInstUsesWith(*II, V); - break; - case Intrinsic::x86_avx_maskload_ps: case Intrinsic::x86_avx_maskload_pd: case Intrinsic::x86_avx_maskload_ps_256: @@ -3399,7 +3381,6 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return II; break; - } case Intrinsic::amdgcn_fmed3: { // Note this does not preserve proper sNaN behavior if IEEE-mode is enabled @@ -3560,6 +3541,21 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { break; } + case Intrinsic::amdgcn_wqm_vote: { + // wqm_vote is identity when the argument is constant. + if (!isa<Constant>(II->getArgOperand(0))) + break; + + return replaceInstUsesWith(*II, II->getArgOperand(0)); + } + case Intrinsic::amdgcn_kill: { + const ConstantInt *C = dyn_cast<ConstantInt>(II->getArgOperand(0)); + if (!C || !C->getZExtValue()) + break; + + // amdgcn.kill(i1 1) is a no-op + return eraseInstFromFunction(CI); + } case Intrinsic::stackrestore: { // If the save is right next to the restore, remove the restore. This can // happen when variable allocas are DCE'd. @@ -3611,7 +3607,8 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { case Intrinsic::lifetime_start: // Asan needs to poison memory to detect invalid access which is possible // even for empty lifetime range. - if (II->getFunction()->hasFnAttribute(Attribute::SanitizeAddress)) + if (II->getFunction()->hasFnAttribute(Attribute::SanitizeAddress) || + II->getFunction()->hasFnAttribute(Attribute::SanitizeHWAddress)) break; if (removeTriviallyEmptyRange(*II, Intrinsic::lifetime_start, @@ -3697,7 +3694,7 @@ Instruction *InstCombiner::visitCallInst(CallInst &CI) { return replaceInstUsesWith(*II, ConstantPointerNull::get(PT)); // isKnownNonNull -> nonnull attribute - if (isKnownNonNullAt(DerivedPtr, II, &DT)) + if (isKnownNonZero(DerivedPtr, DL, 0, &AC, II, &DT)) II->addAttribute(AttributeList::ReturnIndex, Attribute::NonNull); } @@ -3740,7 +3737,6 @@ Instruction *InstCombiner::visitFenceInst(FenceInst &FI) { } // InvokeInst simplification -// Instruction *InstCombiner::visitInvokeInst(InvokeInst &II) { return visitCallSite(&II); } @@ -3784,7 +3780,7 @@ Instruction *InstCombiner::tryOptimizeCall(CallInst *CI) { auto InstCombineRAUW = [this](Instruction *From, Value *With) { replaceInstUsesWith(*From, With); }; - LibCallSimplifier Simplifier(DL, &TLI, InstCombineRAUW); + LibCallSimplifier Simplifier(DL, &TLI, ORE, InstCombineRAUW); if (Value *With = Simplifier.optimizeCall(CI)) { ++NumSimplified; return CI->use_empty() ? 
CI : replaceInstUsesWith(*CI, With); @@ -3853,7 +3849,6 @@ static IntrinsicInst *findInitTrampolineFromBB(IntrinsicInst *AdjustTramp, // Given a call to llvm.adjust.trampoline, find and return the corresponding // call to llvm.init.trampoline if the call to the trampoline can be optimized // to a direct call to a function. Otherwise return NULL. -// static IntrinsicInst *findInitTrampoline(Value *Callee) { Callee = Callee->stripPointerCasts(); IntrinsicInst *AdjustTramp = dyn_cast<IntrinsicInst>(Callee); @@ -3886,7 +3881,7 @@ Instruction *InstCombiner::visitCallSite(CallSite CS) { for (Value *V : CS.args()) { if (V->getType()->isPointerTy() && !CS.paramHasAttr(ArgNo, Attribute::NonNull) && - isKnownNonNullAt(V, CS.getInstruction(), &DT)) + isKnownNonZero(V, DL, 0, &AC, CS.getInstruction(), &DT)) ArgNos.push_back(ArgNo); ArgNo++; } @@ -4021,7 +4016,6 @@ bool InstCombiner::transformConstExprCastCall(CallSite CS) { // Okay, this is a cast from a function to a different type. Unless doing so // would cause a type conversion of one of our arguments, change this call to // be a direct call with arguments casted to the appropriate types. - // FunctionType *FT = Callee->getFunctionType(); Type *OldRetTy = Caller->getType(); Type *NewRetTy = FT->getReturnType(); diff --git a/lib/Transforms/InstCombine/InstCombineCasts.cpp b/lib/Transforms/InstCombine/InstCombineCasts.cpp index dfdfd3e9da84..178c8eaf2502 100644 --- a/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -235,8 +235,8 @@ Instruction::CastOps InstCombiner::isEliminableCastPair(const CastInst *CI1, Type *MidTy = CI1->getDestTy(); Type *DstTy = CI2->getDestTy(); - Instruction::CastOps firstOp = Instruction::CastOps(CI1->getOpcode()); - Instruction::CastOps secondOp = Instruction::CastOps(CI2->getOpcode()); + Instruction::CastOps firstOp = CI1->getOpcode(); + Instruction::CastOps secondOp = CI2->getOpcode(); Type *SrcIntPtrTy = SrcTy->isPtrOrPtrVectorTy() ? DL.getIntPtrType(SrcTy) : nullptr; Type *MidIntPtrTy = @@ -346,29 +346,50 @@ static bool canEvaluateTruncated(Value *V, Type *Ty, InstCombiner &IC, } break; } - case Instruction::Shl: + case Instruction::Shl: { // If we are truncating the result of this SHL, and if it's a shift of a // constant amount, we can always perform a SHL in a smaller type. - if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { uint32_t BitWidth = Ty->getScalarSizeInBits(); - if (CI->getLimitedValue(BitWidth) < BitWidth) + if (Amt->getLimitedValue(BitWidth) < BitWidth) return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } break; - case Instruction::LShr: + } + case Instruction::LShr: { // If this is a truncate of a logical shr, we can truncate it to a smaller // lshr iff we know that the bits we would otherwise be shifting in are // already zeros. 
- if (ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); uint32_t BitWidth = Ty->getScalarSizeInBits(); if (IC.MaskedValueIsZero(I->getOperand(0), APInt::getHighBitsSet(OrigBitWidth, OrigBitWidth-BitWidth), 0, CxtI) && - CI->getLimitedValue(BitWidth) < BitWidth) { + Amt->getLimitedValue(BitWidth) < BitWidth) { return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); } } break; + } + case Instruction::AShr: { + // If this is a truncate of an arithmetic shr, we can truncate it to a + // smaller ashr iff we know that all the bits from the sign bit of the + // original type and the sign bit of the truncate type are similar. + // TODO: It is enough to check that the bits we would be shifting in are + // similar to sign bit of the truncate type. + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { + uint32_t OrigBitWidth = OrigTy->getScalarSizeInBits(); + uint32_t BitWidth = Ty->getScalarSizeInBits(); + if (Amt->getLimitedValue(BitWidth) < BitWidth && + OrigBitWidth - BitWidth < + IC.ComputeNumSignBits(I->getOperand(0), 0, CxtI)) + return canEvaluateTruncated(I->getOperand(0), Ty, IC, CxtI); + } + break; + } case Instruction::Trunc: // trunc(trunc(x)) -> trunc(x) return true; @@ -443,24 +464,130 @@ static Instruction *foldVecTruncToExtElt(TruncInst &Trunc, InstCombiner &IC) { return ExtractElementInst::Create(VecInput, IC.Builder.getInt32(Elt)); } -/// Try to narrow the width of bitwise logic instructions with constants. -Instruction *InstCombiner::shrinkBitwiseLogic(TruncInst &Trunc) { +/// Rotate left/right may occur in a wider type than necessary because of type +/// promotion rules. Try to narrow all of the component instructions. +Instruction *InstCombiner::narrowRotate(TruncInst &Trunc) { + assert((isa<VectorType>(Trunc.getSrcTy()) || + shouldChangeType(Trunc.getSrcTy(), Trunc.getType())) && + "Don't narrow to an illegal scalar type"); + + // First, find an or'd pair of opposite shifts with the same shifted operand: + // trunc (or (lshr ShVal, ShAmt0), (shl ShVal, ShAmt1)) + Value *Or0, *Or1; + if (!match(Trunc.getOperand(0), m_OneUse(m_Or(m_Value(Or0), m_Value(Or1))))) + return nullptr; + + Value *ShVal, *ShAmt0, *ShAmt1; + if (!match(Or0, m_OneUse(m_LogicalShift(m_Value(ShVal), m_Value(ShAmt0)))) || + !match(Or1, m_OneUse(m_LogicalShift(m_Specific(ShVal), m_Value(ShAmt1))))) + return nullptr; + + auto ShiftOpcode0 = cast<BinaryOperator>(Or0)->getOpcode(); + auto ShiftOpcode1 = cast<BinaryOperator>(Or1)->getOpcode(); + if (ShiftOpcode0 == ShiftOpcode1) + return nullptr; + + // The shift amounts must add up to the narrow bit width. + Value *ShAmt; + bool SubIsOnLHS; + Type *DestTy = Trunc.getType(); + unsigned NarrowWidth = DestTy->getScalarSizeInBits(); + if (match(ShAmt0, + m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), m_Specific(ShAmt1))))) { + ShAmt = ShAmt1; + SubIsOnLHS = true; + } else if (match(ShAmt1, m_OneUse(m_Sub(m_SpecificInt(NarrowWidth), + m_Specific(ShAmt0))))) { + ShAmt = ShAmt0; + SubIsOnLHS = false; + } else { + return nullptr; + } + + // The shifted value must have high zeros in the wide type. Typically, this + // will be a zext, but it could also be the result of an 'and' or 'shift'. 
+ unsigned WideWidth = Trunc.getSrcTy()->getScalarSizeInBits(); + APInt HiBitMask = APInt::getHighBitsSet(WideWidth, WideWidth - NarrowWidth); + if (!MaskedValueIsZero(ShVal, HiBitMask, 0, &Trunc)) + return nullptr; + + // We have an unnecessarily wide rotate! + // trunc (or (lshr ShVal, ShAmt), (shl ShVal, BitWidth - ShAmt)) + // Narrow it down to eliminate the zext/trunc: + // or (lshr trunc(ShVal), ShAmt0'), (shl trunc(ShVal), ShAmt1') + Value *NarrowShAmt = Builder.CreateTrunc(ShAmt, DestTy); + Value *NegShAmt = Builder.CreateNeg(NarrowShAmt); + + // Mask both shift amounts to ensure there's no UB from oversized shifts. + Constant *MaskC = ConstantInt::get(DestTy, NarrowWidth - 1); + Value *MaskedShAmt = Builder.CreateAnd(NarrowShAmt, MaskC); + Value *MaskedNegShAmt = Builder.CreateAnd(NegShAmt, MaskC); + + // Truncate the original value and use narrow ops. + Value *X = Builder.CreateTrunc(ShVal, DestTy); + Value *NarrowShAmt0 = SubIsOnLHS ? MaskedNegShAmt : MaskedShAmt; + Value *NarrowShAmt1 = SubIsOnLHS ? MaskedShAmt : MaskedNegShAmt; + Value *NarrowSh0 = Builder.CreateBinOp(ShiftOpcode0, X, NarrowShAmt0); + Value *NarrowSh1 = Builder.CreateBinOp(ShiftOpcode1, X, NarrowShAmt1); + return BinaryOperator::CreateOr(NarrowSh0, NarrowSh1); +} + +/// Try to narrow the width of math or bitwise logic instructions by pulling a +/// truncate ahead of binary operators. +/// TODO: Transforms for truncated shifts should be moved into here. +Instruction *InstCombiner::narrowBinOp(TruncInst &Trunc) { Type *SrcTy = Trunc.getSrcTy(); Type *DestTy = Trunc.getType(); - if (isa<IntegerType>(SrcTy) && !shouldChangeType(SrcTy, DestTy)) + if (!isa<VectorType>(SrcTy) && !shouldChangeType(SrcTy, DestTy)) return nullptr; - BinaryOperator *LogicOp; - Constant *C; - if (!match(Trunc.getOperand(0), m_OneUse(m_BinOp(LogicOp))) || - !LogicOp->isBitwiseLogicOp() || - !match(LogicOp->getOperand(1), m_Constant(C))) + BinaryOperator *BinOp; + if (!match(Trunc.getOperand(0), m_OneUse(m_BinOp(BinOp)))) return nullptr; - // trunc (logic X, C) --> logic (trunc X, C') - Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); - Value *NarrowOp0 = Builder.CreateTrunc(LogicOp->getOperand(0), DestTy); - return BinaryOperator::Create(LogicOp->getOpcode(), NarrowOp0, NarrowC); + Value *BinOp0 = BinOp->getOperand(0); + Value *BinOp1 = BinOp->getOperand(1); + switch (BinOp->getOpcode()) { + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: { + Constant *C; + if (match(BinOp0, m_Constant(C))) { + // trunc (binop C, X) --> binop (trunc C', X) + Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); + Value *TruncX = Builder.CreateTrunc(BinOp1, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), NarrowC, TruncX); + } + if (match(BinOp1, m_Constant(C))) { + // trunc (binop X, C) --> binop (trunc X, C') + Constant *NarrowC = ConstantExpr::getTrunc(C, DestTy); + Value *TruncX = Builder.CreateTrunc(BinOp0, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), TruncX, NarrowC); + } + Value *X; + if (match(BinOp0, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) { + // trunc (binop (ext X), Y) --> binop X, (trunc Y) + Value *NarrowOp1 = Builder.CreateTrunc(BinOp1, DestTy); + return BinaryOperator::Create(BinOp->getOpcode(), X, NarrowOp1); + } + if (match(BinOp1, m_ZExtOrSExt(m_Value(X))) && X->getType() == DestTy) { + // trunc (binop Y, (ext X)) --> binop (trunc Y), X + Value *NarrowOp0 = Builder.CreateTrunc(BinOp0, DestTy); + 
return BinaryOperator::Create(BinOp->getOpcode(), NarrowOp0, X); + } + break; + } + + default: break; + } + + if (Instruction *NarrowOr = narrowRotate(Trunc)) + return NarrowOr; + + return nullptr; } /// Try to narrow the width of a splat shuffle. This could be generalized to any @@ -616,7 +743,7 @@ Instruction *InstCombiner::visitTrunc(TruncInst &CI) { } } - if (Instruction *I = shrinkBitwiseLogic(CI)) + if (Instruction *I = narrowBinOp(CI)) return I; if (Instruction *I = shrinkSplatShuffle(CI, Builder)) @@ -655,13 +782,13 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, // If we are just checking for a icmp eq of a single bit and zext'ing it // to an integer, then shift the bit to the appropriate place and then // cast to integer to avoid the comparison. - if (ConstantInt *Op1C = dyn_cast<ConstantInt>(ICI->getOperand(1))) { - const APInt &Op1CV = Op1C->getValue(); + const APInt *Op1CV; + if (match(ICI->getOperand(1), m_APInt(Op1CV))) { // zext (x <s 0) to i32 --> x>>u31 true if signbit set. // zext (x >s -1) to i32 --> (x>>u31)^1 true if signbit clear. - if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV.isNullValue()) || - (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV.isAllOnesValue())) { + if ((ICI->getPredicate() == ICmpInst::ICMP_SLT && Op1CV->isNullValue()) || + (ICI->getPredicate() == ICmpInst::ICMP_SGT && Op1CV->isAllOnesValue())) { if (!DoTransform) return ICI; Value *In = ICI->getOperand(0); @@ -687,7 +814,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, // zext (X != 0) to i32 --> X>>1 iff X has only the 2nd bit set. // zext (X != 1) to i32 --> X^1 iff X has only the low bit set. // zext (X != 2) to i32 --> (X>>1)^1 iff X has only the 2nd bit set. - if ((Op1CV.isNullValue() || Op1CV.isPowerOf2()) && + if ((Op1CV->isNullValue() || Op1CV->isPowerOf2()) && // This only works for EQ and NE ICI->isEquality()) { // If Op1C some other power of two, convert: @@ -698,12 +825,10 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, if (!DoTransform) return ICI; bool isNE = ICI->getPredicate() == ICmpInst::ICMP_NE; - if (!Op1CV.isNullValue() && (Op1CV != KnownZeroMask)) { + if (!Op1CV->isNullValue() && (*Op1CV != KnownZeroMask)) { // (X&4) == 2 --> false // (X&4) != 2 --> true - Constant *Res = ConstantInt::get(Type::getInt1Ty(CI.getContext()), - isNE); - Res = ConstantExpr::getZExt(Res, CI.getType()); + Constant *Res = ConstantInt::get(CI.getType(), isNE); return replaceInstUsesWith(CI, Res); } @@ -716,7 +841,7 @@ Instruction *InstCombiner::transformZExtICmp(ICmpInst *ICI, ZExtInst &CI, In->getName() + ".lobit"); } - if (!Op1CV.isNullValue() == isNE) { // Toggle the low bit. + if (!Op1CV->isNullValue() == isNE) { // Toggle the low bit. Constant *One = ConstantInt::get(In->getType(), 1); In = Builder.CreateXor(In, One); } @@ -833,17 +958,23 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, unsigned VSize = V->getType()->getScalarSizeInBits(); if (IC.MaskedValueIsZero(I->getOperand(1), APInt::getHighBitsSet(VSize, BitsToClear), - 0, CxtI)) + 0, CxtI)) { + // If this is an And instruction and all of the BitsToClear are + // known to be zero we can reset BitsToClear. + if (Opc == Instruction::And) + BitsToClear = 0; return true; + } } // Otherwise, we don't know how to analyze this BitsToClear case yet. return false; - case Instruction::Shl: + case Instruction::Shl: { // We can promote shl(x, cst) if we can promote x. 
Since shl overwrites the // upper bits we can reduce BitsToClear by the shift amount. - if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; uint64_t ShiftAmt = Amt->getZExtValue(); @@ -851,10 +982,12 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, return true; } return false; - case Instruction::LShr: + } + case Instruction::LShr: { // We can promote lshr(x, cst) if we can promote x. This requires the // ultimate 'and' to clear out the high zero bits we're clearing out though. - if (ConstantInt *Amt = dyn_cast<ConstantInt>(I->getOperand(1))) { + const APInt *Amt; + if (match(I->getOperand(1), m_APInt(Amt))) { if (!canEvaluateZExtd(I->getOperand(0), Ty, BitsToClear, IC, CxtI)) return false; BitsToClear += Amt->getZExtValue(); @@ -864,6 +997,7 @@ static bool canEvaluateZExtd(Value *V, Type *Ty, unsigned &BitsToClear, } // Cannot promote variable LSHR. return false; + } case Instruction::Select: if (!canEvaluateZExtd(I->getOperand(1), Ty, Tmp, IC, CxtI) || !canEvaluateZExtd(I->getOperand(2), Ty, BitsToClear, IC, CxtI) || diff --git a/lib/Transforms/InstCombine/InstCombineCompares.cpp b/lib/Transforms/InstCombine/InstCombineCompares.cpp index a8faaecb5c34..3bc7fae77cb1 100644 --- a/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -17,9 +17,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/VectorUtils.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/GetElementPtrTypeIterator.h" @@ -37,77 +35,30 @@ using namespace PatternMatch; STATISTIC(NumSel, "Number of select opts"); -static ConstantInt *extractElement(Constant *V, Constant *Idx) { - return cast<ConstantInt>(ConstantExpr::getExtractElement(V, Idx)); -} - -static bool hasAddOverflow(ConstantInt *Result, - ConstantInt *In1, ConstantInt *In2, - bool IsSigned) { - if (!IsSigned) - return Result->getValue().ult(In1->getValue()); - - if (In2->isNegative()) - return Result->getValue().sgt(In1->getValue()); - return Result->getValue().slt(In1->getValue()); -} - /// Compute Result = In1+In2, returning true if the result overflowed for this /// type. 
-static bool addWithOverflow(Constant *&Result, Constant *In1, - Constant *In2, bool IsSigned = false) { - Result = ConstantExpr::getAdd(In1, In2); - - if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) { - for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { - Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i); - if (hasAddOverflow(extractElement(Result, Idx), - extractElement(In1, Idx), - extractElement(In2, Idx), - IsSigned)) - return true; - } - return false; - } - - return hasAddOverflow(cast<ConstantInt>(Result), - cast<ConstantInt>(In1), cast<ConstantInt>(In2), - IsSigned); -} - -static bool hasSubOverflow(ConstantInt *Result, - ConstantInt *In1, ConstantInt *In2, - bool IsSigned) { - if (!IsSigned) - return Result->getValue().ugt(In1->getValue()); - - if (In2->isNegative()) - return Result->getValue().slt(In1->getValue()); +static bool addWithOverflow(APInt &Result, const APInt &In1, + const APInt &In2, bool IsSigned = false) { + bool Overflow; + if (IsSigned) + Result = In1.sadd_ov(In2, Overflow); + else + Result = In1.uadd_ov(In2, Overflow); - return Result->getValue().sgt(In1->getValue()); + return Overflow; } /// Compute Result = In1-In2, returning true if the result overflowed for this /// type. -static bool subWithOverflow(Constant *&Result, Constant *In1, - Constant *In2, bool IsSigned = false) { - Result = ConstantExpr::getSub(In1, In2); - - if (VectorType *VTy = dyn_cast<VectorType>(In1->getType())) { - for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { - Constant *Idx = ConstantInt::get(Type::getInt32Ty(In1->getContext()), i); - if (hasSubOverflow(extractElement(Result, Idx), - extractElement(In1, Idx), - extractElement(In2, Idx), - IsSigned)) - return true; - } - return false; - } +static bool subWithOverflow(APInt &Result, const APInt &In1, + const APInt &In2, bool IsSigned = false) { + bool Overflow; + if (IsSigned) + Result = In1.ssub_ov(In2, Overflow); + else + Result = In1.usub_ov(In2, Overflow); - return hasSubOverflow(cast<ConstantInt>(Result), - cast<ConstantInt>(In1), cast<ConstantInt>(In2), - IsSigned); + return Overflow; } /// Given an icmp instruction, return true if any use of this comparison is a @@ -473,8 +424,7 @@ Instruction *InstCombiner::foldCmpLoadFromIndexedGlobal(GetElementPtrInst *GEP, // Look for an appropriate type: // - The type of Idx if the magic fits - // - The smallest fitting legal type if we have a DataLayout - // - Default to i32 + // - The smallest fitting legal type if (ArrayElementCount <= Idx->getType()->getIntegerBitWidth()) Ty = Idx->getType(); else @@ -1108,7 +1058,6 @@ Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI, // because we don't allow ptrtoint. Memcpy and memmove are safe because // we don't allow stores, so src cannot point to V. case Intrinsic::lifetime_start: case Intrinsic::lifetime_end: - case Intrinsic::dbg_declare: case Intrinsic::dbg_value: case Intrinsic::memcpy: case Intrinsic::memmove: case Intrinsic::memset: continue; default: @@ -1131,8 +1080,7 @@ Instruction *InstCombiner::foldAllocaCmp(ICmpInst &ICI, } /// Fold "icmp pred (X+CI), X". -Instruction *InstCombiner::foldICmpAddOpConst(Instruction &ICI, - Value *X, ConstantInt *CI, +Instruction *InstCombiner::foldICmpAddOpConst(Value *X, ConstantInt *CI, ICmpInst::Predicate Pred) { // From this point on, we know that (X+C <= X) --> (X+C < X) because C != 0, // so the values can never be equal. 
Similarly for all other "or equals" @@ -1367,6 +1315,24 @@ static Instruction *processUGT_ADDCST_ADD(ICmpInst &I, Value *A, Value *B, return ExtractValueInst::Create(Call, 1, "sadd.overflow"); } +// Handle (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) +Instruction *InstCombiner::foldICmpWithZero(ICmpInst &Cmp) { + CmpInst::Predicate Pred = Cmp.getPredicate(); + Value *X = Cmp.getOperand(0); + + if (match(Cmp.getOperand(1), m_Zero()) && Pred == ICmpInst::ICMP_SGT) { + Value *A, *B; + SelectPatternResult SPR = matchSelectPattern(X, A, B); + if (SPR.Flavor == SPF_SMIN) { + if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT)) + return new ICmpInst(Pred, B, Cmp.getOperand(1)); + if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT)) + return new ICmpInst(Pred, A, Cmp.getOperand(1)); + } + } + return nullptr; +} + // Fold icmp Pred X, C. Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { CmpInst::Predicate Pred = Cmp.getPredicate(); @@ -1398,17 +1364,6 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { return Res; } - // (icmp sgt smin(PosA, B) 0) -> (icmp sgt B 0) - if (C->isNullValue() && Pred == ICmpInst::ICMP_SGT) { - SelectPatternResult SPR = matchSelectPattern(X, A, B); - if (SPR.Flavor == SPF_SMIN) { - if (isKnownPositive(A, DL, 0, &AC, &Cmp, &DT)) - return new ICmpInst(Pred, B, Cmp.getOperand(1)); - if (isKnownPositive(B, DL, 0, &AC, &Cmp, &DT)) - return new ICmpInst(Pred, A, Cmp.getOperand(1)); - } - } - // FIXME: Use m_APInt to allow folds for splat constants. ConstantInt *CI = dyn_cast<ConstantInt>(Cmp.getOperand(1)); if (!CI) @@ -1462,11 +1417,11 @@ Instruction *InstCombiner::foldICmpWithConstant(ICmpInst &Cmp) { /// Fold icmp (trunc X, Y), C. Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, - Instruction *Trunc, - const APInt *C) { + TruncInst *Trunc, + const APInt &C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); Value *X = Trunc->getOperand(0); - if (C->isOneValue() && C->getBitWidth() > 1) { + if (C.isOneValue() && C.getBitWidth() > 1) { // icmp slt trunc(signum(V)) 1 --> icmp slt V, 1 Value *V = nullptr; if (Pred == ICmpInst::ICMP_SLT && match(X, m_Signum(m_Value(V)))) @@ -1484,7 +1439,7 @@ Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, // If all the high bits are known, we can do this xform. if ((Known.Zero | Known.One).countLeadingOnes() >= SrcBits - DstBits) { // Pull in the high bits from known-ones set. - APInt NewRHS = C->zext(SrcBits); + APInt NewRHS = C.zext(SrcBits); NewRHS |= Known.One & APInt::getHighBitsSet(SrcBits, SrcBits - DstBits); return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), NewRHS)); } @@ -1496,7 +1451,7 @@ Instruction *InstCombiner::foldICmpTruncConstant(ICmpInst &Cmp, /// Fold icmp (xor X, Y), C. Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, - const APInt *C) { + const APInt &C) { Value *X = Xor->getOperand(0); Value *Y = Xor->getOperand(1); const APInt *XorC; @@ -1506,8 +1461,8 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, // If this is a comparison that tests the signbit (X < 0) or (x > -1), // fold the xor. ICmpInst::Predicate Pred = Cmp.getPredicate(); - if ((Pred == ICmpInst::ICMP_SLT && C->isNullValue()) || - (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue())) { + bool TrueIfSigned = false; + if (isSignBitCheck(Cmp.getPredicate(), C, TrueIfSigned)) { // If the sign bit of the XorCst is not set, there is no change to // the operation, just stop using the Xor. 
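Illustrative aside (not part of the patch): the rewritten foldICmpXorConstant routes sign-bit tests through isSignBitCheck. The reasoning is that xor with a constant either preserves or flips the sign bit, depending on the constant's own sign bit, so the compare can drop the xor outright or invert its direction. A small standalone C++ check of that property, assuming 32-bit values; the names and constants are hypothetical:

  #include <cassert>
  #include <cstdint>

  // Xor with a constant whose sign bit is clear preserves the sign bit, so a
  // sign-bit compare can simply drop the xor; xor with a constant whose sign
  // bit is set flips it, so the compare inverts instead.
  static void checkSignBitXor(int32_t X) {
    const uint32_t U = static_cast<uint32_t>(X);
    const uint32_t ClearC = 0x00000001u; // sign bit clear
    const uint32_t SetC   = 0x80000001u; // sign bit set
    assert((((U ^ ClearC) >> 31) != 0) == (X < 0));  // same sign-bit test
    assert((((U ^ SetC) >> 31) != 0) == (X > -1));   // inverted sign-bit test
  }

  int main() {
    const int32_t Vals[] = {-9, -1, 0, 1, 123456};
    for (int32_t X : Vals)
      checkSignBitXor(X);
    return 0;
  }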
@@ -1517,17 +1472,13 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, return &Cmp; } - // Was the old condition true if the operand is positive? - bool isTrueIfPositive = Pred == ICmpInst::ICMP_SGT; - - // If so, the new one isn't. - isTrueIfPositive ^= true; - - Constant *CmpConstant = cast<Constant>(Cmp.getOperand(1)); - if (isTrueIfPositive) - return new ICmpInst(ICmpInst::ICMP_SGT, X, SubOne(CmpConstant)); + // Emit the opposite comparison. + if (TrueIfSigned) + return new ICmpInst(ICmpInst::ICMP_SGT, X, + ConstantInt::getAllOnesValue(X->getType())); else - return new ICmpInst(ICmpInst::ICMP_SLT, X, AddOne(CmpConstant)); + return new ICmpInst(ICmpInst::ICMP_SLT, X, + ConstantInt::getNullValue(X->getType())); } if (Xor->hasOneUse()) { @@ -1535,7 +1486,7 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, if (!Cmp.isEquality() && XorC->isSignMask()) { Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() : Cmp.getSignedPredicate(); - return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC)); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC)); } // (icmp u/s (xor X ~SignMask), C) -> (icmp s/u X, (xor C ~SignMask)) @@ -1543,18 +1494,18 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, Pred = Cmp.isSigned() ? Cmp.getUnsignedPredicate() : Cmp.getSignedPredicate(); Pred = Cmp.getSwappedPredicate(Pred); - return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), *C ^ *XorC)); + return new ICmpInst(Pred, X, ConstantInt::get(X->getType(), C ^ *XorC)); } } // (icmp ugt (xor X, C), ~C) -> (icmp ult X, C) // iff -C is a power of 2 - if (Pred == ICmpInst::ICMP_UGT && *XorC == ~(*C) && (*C + 1).isPowerOf2()) + if (Pred == ICmpInst::ICMP_UGT && *XorC == ~C && (C + 1).isPowerOf2()) return new ICmpInst(ICmpInst::ICMP_ULT, X, Y); // (icmp ult (xor X, C), -C) -> (icmp uge X, C) // iff -C is a power of 2 - if (Pred == ICmpInst::ICMP_ULT && *XorC == -(*C) && C->isPowerOf2()) + if (Pred == ICmpInst::ICMP_ULT && *XorC == -C && C.isPowerOf2()) return new ICmpInst(ICmpInst::ICMP_UGE, X, Y); return nullptr; @@ -1562,7 +1513,7 @@ Instruction *InstCombiner::foldICmpXorConstant(ICmpInst &Cmp, /// Fold icmp (and (sh X, Y), C2), C1. Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C1, const APInt *C2) { + const APInt &C1, const APInt &C2) { BinaryOperator *Shift = dyn_cast<BinaryOperator>(And->getOperand(0)); if (!Shift || !Shift->isShift()) return nullptr; @@ -1577,32 +1528,35 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, const APInt *C3; if (match(Shift->getOperand(1), m_APInt(C3))) { bool CanFold = false; - if (ShiftOpcode == Instruction::AShr) { - // There may be some constraints that make this possible, but nothing - // simple has been discovered yet. - CanFold = false; - } else if (ShiftOpcode == Instruction::Shl) { + if (ShiftOpcode == Instruction::Shl) { // For a left shift, we can fold if the comparison is not signed. We can // also fold a signed comparison if the mask value and comparison value // are not negative. These constraints may not be obvious, but we can // prove that they are correct using an SMT solver. - if (!Cmp.isSigned() || (!C2->isNegative() && !C1->isNegative())) + if (!Cmp.isSigned() || (!C2.isNegative() && !C1.isNegative())) CanFold = true; - } else if (ShiftOpcode == Instruction::LShr) { + } else { + bool IsAshr = ShiftOpcode == Instruction::AShr; // For a logical right shift, we can fold if the comparison is not signed. 
// We can also fold a signed comparison if the shifted mask value and the // shifted comparison value are not negative. These constraints may not be // obvious, but we can prove that they are correct using an SMT solver. - if (!Cmp.isSigned() || - (!C2->shl(*C3).isNegative() && !C1->shl(*C3).isNegative())) - CanFold = true; + // For an arithmetic shift right we can do the same, if we ensure + // the And doesn't use any bits being shifted in. Normally these would + // be turned into lshr by SimplifyDemandedBits, but not if there is an + // additional user. + if (!IsAshr || (C2.shl(*C3).lshr(*C3) == C2)) { + if (!Cmp.isSigned() || + (!C2.shl(*C3).isNegative() && !C1.shl(*C3).isNegative())) + CanFold = true; + } } if (CanFold) { - APInt NewCst = IsShl ? C1->lshr(*C3) : C1->shl(*C3); + APInt NewCst = IsShl ? C1.lshr(*C3) : C1.shl(*C3); APInt SameAsC1 = IsShl ? NewCst.shl(*C3) : NewCst.lshr(*C3); // Check to see if we are shifting out any of the bits being compared. - if (SameAsC1 != *C1) { + if (SameAsC1 != C1) { // If we shifted bits out, the fold is not going to work out. As a // special case, check to see if this means that the result is always // true or false now. @@ -1612,7 +1566,7 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, return replaceInstUsesWith(Cmp, ConstantInt::getTrue(Cmp.getType())); } else { Cmp.setOperand(1, ConstantInt::get(And->getType(), NewCst)); - APInt NewAndCst = IsShl ? C2->lshr(*C3) : C2->shl(*C3); + APInt NewAndCst = IsShl ? C2.lshr(*C3) : C2.shl(*C3); And->setOperand(1, ConstantInt::get(And->getType(), NewAndCst)); And->setOperand(0, Shift->getOperand(0)); Worklist.Add(Shift); // Shift is dead. @@ -1624,7 +1578,7 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, // Turn ((X >> Y) & C2) == 0 into (X & (C2 << Y)) == 0. The latter is // preferable because it allows the C2 << Y expression to be hoisted out of a // loop if Y is invariant and X is not. - if (Shift->hasOneUse() && C1->isNullValue() && Cmp.isEquality() && + if (Shift->hasOneUse() && C1.isNullValue() && Cmp.isEquality() && !Shift->isArithmeticShift() && !isa<Constant>(Shift->getOperand(0))) { // Compute C2 << Y. Value *NewShift = @@ -1643,12 +1597,12 @@ Instruction *InstCombiner::foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, /// Fold icmp (and X, C2), C1. Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C1) { + const APInt &C1) { const APInt *C2; if (!match(And->getOperand(1), m_APInt(C2))) return nullptr; - if (!And->hasOneUse() || !And->getOperand(0)->hasOneUse()) + if (!And->hasOneUse()) return nullptr; // If the LHS is an 'and' of a truncate and we can widen the and/compare to @@ -1660,29 +1614,29 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, // set or if it is an equality comparison. Extending a relational comparison // when we're checking the sign bit would not work. Value *W; - if (match(And->getOperand(0), m_Trunc(m_Value(W))) && - (Cmp.isEquality() || (!C1->isNegative() && !C2->isNegative()))) { + if (match(And->getOperand(0), m_OneUse(m_Trunc(m_Value(W)))) && + (Cmp.isEquality() || (!C1.isNegative() && !C2->isNegative()))) { // TODO: Is this a good transform for vectors? Wider types may reduce // throughput. Should this transform be limited (even for scalars) by using // shouldChangeType()? 
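For illustration only: the new AShr branch in foldICmpAndShift above only fires when the mask covers none of the bits shifted in, i.e. (C2 << C3) >> C3 == C2, and then rewrites icmp (and (ashr X, C3), C2), C1 as icmp (and X, C2 << C3), (C1 << C3). A small standalone check with a hand-rolled 8-bit arithmetic shift so the behaviour is fully defined in portable C++:

    #include <cassert>
    #include <cstdint>

    // Arithmetic shift right on an 8-bit value, written without relying on
    // implementation-defined signed shifts.
    static uint8_t ashr8(uint8_t V, unsigned S) {
      uint8_t R = V >> S;
      if (V & 0x80)
        R |= (uint8_t)(0xFF << (8 - S)); // replicate the sign bit
      return R;
    }

    int main() {
      // C3 = 2, C2 = 0x0F, C1 = 0x03; note (0x0F << 2) >> 2 == 0x0F, so the
      // AShr precondition from the hunk above holds.
      for (unsigned X = 0; X <= 0xFF; ++X) {
        bool Before = (ashr8((uint8_t)X, 2) & 0x0F) == 0x03;
        bool After = (X & (0x0F << 2)) == (0x03 << 2);
        assert(Before == After);
      }
      return 0;
    }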
if (!Cmp.getType()->isVectorTy()) { Type *WideType = W->getType(); unsigned WideScalarBits = WideType->getScalarSizeInBits(); - Constant *ZextC1 = ConstantInt::get(WideType, C1->zext(WideScalarBits)); + Constant *ZextC1 = ConstantInt::get(WideType, C1.zext(WideScalarBits)); Constant *ZextC2 = ConstantInt::get(WideType, C2->zext(WideScalarBits)); Value *NewAnd = Builder.CreateAnd(W, ZextC2, And->getName()); return new ICmpInst(Cmp.getPredicate(), NewAnd, ZextC1); } } - if (Instruction *I = foldICmpAndShift(Cmp, And, C1, C2)) + if (Instruction *I = foldICmpAndShift(Cmp, And, C1, *C2)) return I; // (icmp pred (and (or (lshr A, B), A), 1), 0) --> // (icmp pred (and A, (or (shl 1, B), 1), 0)) // // iff pred isn't signed - if (!Cmp.isSigned() && C1->isNullValue() && + if (!Cmp.isSigned() && C1.isNullValue() && And->getOperand(0)->hasOneUse() && match(And->getOperand(1), m_One())) { Constant *One = cast<Constant>(And->getOperand(1)); Value *Or = And->getOperand(0); @@ -1716,22 +1670,13 @@ Instruction *InstCombiner::foldICmpAndConstConst(ICmpInst &Cmp, } } - // (X & C2) > C1 --> (X & C2) != 0, if any bit set in (X & C2) will produce a - // result greater than C1. - unsigned NumTZ = C2->countTrailingZeros(); - if (Cmp.getPredicate() == ICmpInst::ICMP_UGT && NumTZ < C2->getBitWidth() && - APInt::getOneBitSet(C2->getBitWidth(), NumTZ).ugt(*C1)) { - Constant *Zero = Constant::getNullValue(And->getType()); - return new ICmpInst(ICmpInst::ICMP_NE, And, Zero); - } - return nullptr; } /// Fold icmp (and X, Y), C. Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C) { + const APInt &C) { if (Instruction *I = foldICmpAndConstConst(Cmp, And, C)) return I; @@ -1756,7 +1701,7 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, // X & -C == -C -> X > u ~C // X & -C != -C -> X <= u ~C // iff C is a power of 2 - if (Cmp.getOperand(1) == Y && (-(*C)).isPowerOf2()) { + if (Cmp.getOperand(1) == Y && (-C).isPowerOf2()) { auto NewPred = Cmp.getPredicate() == CmpInst::ICMP_EQ ? CmpInst::ICMP_UGT : CmpInst::ICMP_ULE; return new ICmpInst(NewPred, X, SubOne(cast<Constant>(Cmp.getOperand(1)))); @@ -1766,7 +1711,7 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, // (X & C2) != 0 -> (trunc X) < 0 // iff C2 is a power of 2 and it masks the sign bit of a legal integer type. const APInt *C2; - if (And->hasOneUse() && C->isNullValue() && match(Y, m_APInt(C2))) { + if (And->hasOneUse() && C.isNullValue() && match(Y, m_APInt(C2))) { int32_t ExactLogBase2 = C2->exactLogBase2(); if (ExactLogBase2 != -1 && DL.isLegalInteger(ExactLogBase2 + 1)) { Type *NTy = IntegerType::get(Cmp.getContext(), ExactLogBase2 + 1); @@ -1784,9 +1729,9 @@ Instruction *InstCombiner::foldICmpAndConstant(ICmpInst &Cmp, /// Fold icmp (or X, Y), C. Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, - const APInt *C) { + const APInt &C) { ICmpInst::Predicate Pred = Cmp.getPredicate(); - if (C->isOneValue()) { + if (C.isOneValue()) { // icmp slt signum(V) 1 --> icmp slt V, 1 Value *V = nullptr; if (Pred == ICmpInst::ICMP_SLT && match(Or, m_Signum(m_Value(V)))) @@ -1798,12 +1743,12 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, // X | C != C --> X >u C // iff C+1 is a power of 2 (C is a bitmask of the low bits) if (Cmp.isEquality() && Cmp.getOperand(1) == Or->getOperand(1) && - (*C + 1).isPowerOf2()) { + (C + 1).isPowerOf2()) { Pred = (Pred == CmpInst::ICMP_EQ) ? 
CmpInst::ICMP_ULE : CmpInst::ICMP_UGT; return new ICmpInst(Pred, Or->getOperand(0), Or->getOperand(1)); } - if (!Cmp.isEquality() || !C->isNullValue() || !Or->hasOneUse()) + if (!Cmp.isEquality() || !C.isNullValue() || !Or->hasOneUse()) return nullptr; Value *P, *Q; @@ -1837,7 +1782,7 @@ Instruction *InstCombiner::foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, /// Fold icmp (mul X, Y), C. Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, BinaryOperator *Mul, - const APInt *C) { + const APInt &C) { const APInt *MulC; if (!match(Mul->getOperand(1), m_APInt(MulC))) return nullptr; @@ -1845,7 +1790,7 @@ Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, // If this is a test of the sign bit and the multiply is sign-preserving with // a constant operand, use the multiply LHS operand instead. ICmpInst::Predicate Pred = Cmp.getPredicate(); - if (isSignTest(Pred, *C) && Mul->hasNoSignedWrap()) { + if (isSignTest(Pred, C) && Mul->hasNoSignedWrap()) { if (MulC->isNegative()) Pred = ICmpInst::getSwappedPredicate(Pred); return new ICmpInst(Pred, Mul->getOperand(0), @@ -1857,14 +1802,14 @@ Instruction *InstCombiner::foldICmpMulConstant(ICmpInst &Cmp, /// Fold icmp (shl 1, Y), C. static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, - const APInt *C) { + const APInt &C) { Value *Y; if (!match(Shl, m_Shl(m_One(), m_Value(Y)))) return nullptr; Type *ShiftType = Shl->getType(); - uint32_t TypeBits = C->getBitWidth(); - bool CIsPowerOf2 = C->isPowerOf2(); + unsigned TypeBits = C.getBitWidth(); + bool CIsPowerOf2 = C.isPowerOf2(); ICmpInst::Predicate Pred = Cmp.getPredicate(); if (Cmp.isUnsigned()) { // (1 << Y) pred C -> Y pred Log2(C) @@ -1881,7 +1826,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, // (1 << Y) >= 2147483648 -> Y >= 31 -> Y == 31 // (1 << Y) < 2147483648 -> Y < 31 -> Y != 31 - unsigned CLog2 = C->logBase2(); + unsigned CLog2 = C.logBase2(); if (CLog2 == TypeBits - 1) { if (Pred == ICmpInst::ICMP_UGE) Pred = ICmpInst::ICMP_EQ; @@ -1891,7 +1836,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, CLog2)); } else if (Cmp.isSigned()) { Constant *BitWidthMinusOne = ConstantInt::get(ShiftType, TypeBits - 1); - if (C->isAllOnesValue()) { + if (C.isAllOnesValue()) { // (1 << Y) <= -1 -> Y == 31 if (Pred == ICmpInst::ICMP_SLE) return new ICmpInst(ICmpInst::ICMP_EQ, Y, BitWidthMinusOne); @@ -1899,7 +1844,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, // (1 << Y) > -1 -> Y != 31 if (Pred == ICmpInst::ICMP_SGT) return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); - } else if (!(*C)) { + } else if (!C) { // (1 << Y) < 0 -> Y == 31 // (1 << Y) <= 0 -> Y == 31 if (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SLE) @@ -1911,7 +1856,7 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, return new ICmpInst(ICmpInst::ICMP_NE, Y, BitWidthMinusOne); } } else if (Cmp.isEquality() && CIsPowerOf2) { - return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C->logBase2())); + return new ICmpInst(Pred, Y, ConstantInt::get(ShiftType, C.logBase2())); } return nullptr; @@ -1920,10 +1865,10 @@ static Instruction *foldICmpShlOne(ICmpInst &Cmp, Instruction *Shl, /// Fold icmp (shl X, Y), C. 
Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, BinaryOperator *Shl, - const APInt *C) { + const APInt &C) { const APInt *ShiftVal; if (Cmp.isEquality() && match(Shl->getOperand(0), m_APInt(ShiftVal))) - return foldICmpShlConstConst(Cmp, Shl->getOperand(1), *C, *ShiftVal); + return foldICmpShlConstConst(Cmp, Shl->getOperand(1), C, *ShiftVal); const APInt *ShiftAmt; if (!match(Shl->getOperand(1), m_APInt(ShiftAmt))) @@ -1931,7 +1876,7 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, // Check that the shift amount is in range. If not, don't perform undefined // shifts. When the shift is visited, it will be simplified. - unsigned TypeBits = C->getBitWidth(); + unsigned TypeBits = C.getBitWidth(); if (ShiftAmt->uge(TypeBits)) return nullptr; @@ -1945,15 +1890,15 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, if (Shl->hasNoSignedWrap()) { if (Pred == ICmpInst::ICMP_SGT) { // icmp Pred (shl nsw X, ShiftAmt), C --> icmp Pred X, (C >>s ShiftAmt) - APInt ShiftedC = C->ashr(*ShiftAmt); + APInt ShiftedC = C.ashr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { // This is the same code as the SGT case, but assert the pre-condition // that is needed for this to work with equality predicates. - assert(C->ashr(*ShiftAmt).shl(*ShiftAmt) == *C && + assert(C.ashr(*ShiftAmt).shl(*ShiftAmt) == C && "Compare known true or false was not folded"); - APInt ShiftedC = C->ashr(*ShiftAmt); + APInt ShiftedC = C.ashr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } if (Pred == ICmpInst::ICMP_SLT) { @@ -1961,14 +1906,14 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, // (X << S) <=s C is equiv to X <=s (C >> S) for all C // (X << S) <s (C + 1) is equiv to X <s (C >> S) + 1 if C <s SMAX // (X << S) <s C is equiv to X <s ((C - 1) >> S) + 1 if C >s SMIN - assert(!C->isMinSignedValue() && "Unexpected icmp slt"); - APInt ShiftedC = (*C - 1).ashr(*ShiftAmt) + 1; + assert(!C.isMinSignedValue() && "Unexpected icmp slt"); + APInt ShiftedC = (C - 1).ashr(*ShiftAmt) + 1; return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } // If this is a signed comparison to 0 and the shift is sign preserving, // use the shift LHS operand instead; isSignTest may change 'Pred', so only // do that if we're sure to not continue on in this function. - if (isSignTest(Pred, *C)) + if (isSignTest(Pred, C)) return new ICmpInst(Pred, X, Constant::getNullValue(ShType)); } @@ -1978,15 +1923,15 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, if (Shl->hasNoUnsignedWrap()) { if (Pred == ICmpInst::ICMP_UGT) { // icmp Pred (shl nuw X, ShiftAmt), C --> icmp Pred X, (C >>u ShiftAmt) - APInt ShiftedC = C->lshr(*ShiftAmt); + APInt ShiftedC = C.lshr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) { // This is the same code as the UGT case, but assert the pre-condition // that is needed for this to work with equality predicates. 
- assert(C->lshr(*ShiftAmt).shl(*ShiftAmt) == *C && + assert(C.lshr(*ShiftAmt).shl(*ShiftAmt) == C && "Compare known true or false was not folded"); - APInt ShiftedC = C->lshr(*ShiftAmt); + APInt ShiftedC = C.lshr(*ShiftAmt); return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } if (Pred == ICmpInst::ICMP_ULT) { @@ -1994,8 +1939,8 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, // (X << S) <=u C is equiv to X <=u (C >> S) for all C // (X << S) <u (C + 1) is equiv to X <u (C >> S) + 1 if C <u ~0u // (X << S) <u C is equiv to X <u ((C - 1) >> S) + 1 if C >u 0 - assert(C->ugt(0) && "ult 0 should have been eliminated"); - APInt ShiftedC = (*C - 1).lshr(*ShiftAmt) + 1; + assert(C.ugt(0) && "ult 0 should have been eliminated"); + APInt ShiftedC = (C - 1).lshr(*ShiftAmt) + 1; return new ICmpInst(Pred, X, ConstantInt::get(ShType, ShiftedC)); } } @@ -2006,13 +1951,13 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, ShType, APInt::getLowBitsSet(TypeBits, TypeBits - ShiftAmt->getZExtValue())); Value *And = Builder.CreateAnd(X, Mask, Shl->getName() + ".mask"); - Constant *LShrC = ConstantInt::get(ShType, C->lshr(*ShiftAmt)); + Constant *LShrC = ConstantInt::get(ShType, C.lshr(*ShiftAmt)); return new ICmpInst(Pred, And, LShrC); } // Otherwise, if this is a comparison of the sign bit, simplify to and/test. bool TrueIfSigned = false; - if (Shl->hasOneUse() && isSignBitCheck(Pred, *C, TrueIfSigned)) { + if (Shl->hasOneUse() && isSignBitCheck(Pred, C, TrueIfSigned)) { // (X << 31) <s 0 --> (X & 1) != 0 Constant *Mask = ConstantInt::get( ShType, @@ -2029,13 +1974,13 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, // free on the target. It has the additional benefit of comparing to a // smaller constant that may be more target-friendly. unsigned Amt = ShiftAmt->getLimitedValue(TypeBits - 1); - if (Shl->hasOneUse() && Amt != 0 && C->countTrailingZeros() >= Amt && + if (Shl->hasOneUse() && Amt != 0 && C.countTrailingZeros() >= Amt && DL.isLegalInteger(TypeBits - Amt)) { Type *TruncTy = IntegerType::get(Cmp.getContext(), TypeBits - Amt); if (ShType->isVectorTy()) TruncTy = VectorType::get(TruncTy, ShType->getVectorNumElements()); Constant *NewC = - ConstantInt::get(TruncTy, C->ashr(*ShiftAmt).trunc(TypeBits - Amt)); + ConstantInt::get(TruncTy, C.ashr(*ShiftAmt).trunc(TypeBits - Amt)); return new ICmpInst(Pred, Builder.CreateTrunc(X, TruncTy), NewC); } @@ -2045,18 +1990,18 @@ Instruction *InstCombiner::foldICmpShlConstant(ICmpInst &Cmp, /// Fold icmp ({al}shr X, Y), C. Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr, - const APInt *C) { + const APInt &C) { // An exact shr only shifts out zero bits, so: // icmp eq/ne (shr X, Y), 0 --> icmp eq/ne X, 0 Value *X = Shr->getOperand(0); CmpInst::Predicate Pred = Cmp.getPredicate(); if (Cmp.isEquality() && Shr->isExact() && Shr->hasOneUse() && - C->isNullValue()) + C.isNullValue()) return new ICmpInst(Pred, X, Cmp.getOperand(1)); const APInt *ShiftVal; if (Cmp.isEquality() && match(Shr->getOperand(0), m_APInt(ShiftVal))) - return foldICmpShrConstConst(Cmp, Shr->getOperand(1), *C, *ShiftVal); + return foldICmpShrConstConst(Cmp, Shr->getOperand(1), C, *ShiftVal); const APInt *ShiftAmt; if (!match(Shr->getOperand(1), m_APInt(ShiftAmt))) @@ -2064,71 +2009,73 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, // Check that the shift amount is in range. If not, don't perform undefined // shifts. When the shift is visited it will be simplified. 
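As a side note (not part of the change): the nuw block above uses the equivalence (X << S) <u C <=> X <u ((C - 1) >>u S) + 1 for C >u 0, which holds whenever the shift cannot wrap. A standalone exhaustive check for 8-bit values with S = 3, where X is capped so X << 3 never overflows, modelling the nuw flag:

    #include <cassert>

    int main() {
      const unsigned S = 3;
      for (unsigned C = 1; C <= 255; ++C)
        for (unsigned X = 0; X <= 31; ++X) { // X << 3 stays within 8 bits (nuw)
          bool Lhs = (X << S) < C;
          bool Rhs = X < (((C - 1) >> S) + 1);
          assert(Lhs == Rhs);
        }
      return 0;
    }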
- unsigned TypeBits = C->getBitWidth(); + unsigned TypeBits = C.getBitWidth(); unsigned ShAmtVal = ShiftAmt->getLimitedValue(TypeBits); if (ShAmtVal >= TypeBits || ShAmtVal == 0) return nullptr; bool IsAShr = Shr->getOpcode() == Instruction::AShr; - if (!Cmp.isEquality()) { - // If we have an unsigned comparison and an ashr, we can't simplify this. - // Similarly for signed comparisons with lshr. - if (Cmp.isSigned() != IsAShr) - return nullptr; - - // Otherwise, all lshr and most exact ashr's are equivalent to a udiv/sdiv - // by a power of 2. Since we already have logic to simplify these, - // transform to div and then simplify the resultant comparison. - if (IsAShr && (!Shr->isExact() || ShAmtVal == TypeBits - 1)) - return nullptr; - - // Revisit the shift (to delete it). - Worklist.Add(Shr); - - Constant *DivCst = ConstantInt::get( - Shr->getType(), APInt::getOneBitSet(TypeBits, ShAmtVal)); - - Value *Tmp = IsAShr ? Builder.CreateSDiv(X, DivCst, "", Shr->isExact()) - : Builder.CreateUDiv(X, DivCst, "", Shr->isExact()); - - Cmp.setOperand(0, Tmp); - - // If the builder folded the binop, just return it. - BinaryOperator *TheDiv = dyn_cast<BinaryOperator>(Tmp); - if (!TheDiv) - return &Cmp; - - // Otherwise, fold this div/compare. - assert(TheDiv->getOpcode() == Instruction::SDiv || - TheDiv->getOpcode() == Instruction::UDiv); - - Instruction *Res = foldICmpDivConstant(Cmp, TheDiv, C); - assert(Res && "This div/cst should have folded!"); - return Res; + bool IsExact = Shr->isExact(); + Type *ShrTy = Shr->getType(); + // TODO: If we could guarantee that InstSimplify would handle all of the + // constant-value-based preconditions in the folds below, then we could assert + // those conditions rather than checking them. This is difficult because of + // undef/poison (PR34838). + if (IsAShr) { + if (Pred == CmpInst::ICMP_SLT || (Pred == CmpInst::ICMP_SGT && IsExact)) { + // icmp slt (ashr X, ShAmtC), C --> icmp slt X, (C << ShAmtC) + // icmp sgt (ashr exact X, ShAmtC), C --> icmp sgt X, (C << ShAmtC) + APInt ShiftedC = C.shl(ShAmtVal); + if (ShiftedC.ashr(ShAmtVal) == C) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + if (Pred == CmpInst::ICMP_SGT) { + // icmp sgt (ashr X, ShAmtC), C --> icmp sgt X, ((C + 1) << ShAmtC) - 1 + APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1; + if (!C.isMaxSignedValue() && !(C + 1).shl(ShAmtVal).isMinSignedValue() && + (ShiftedC + 1).ashr(ShAmtVal) == (C + 1)) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + } else { + if (Pred == CmpInst::ICMP_ULT || (Pred == CmpInst::ICMP_UGT && IsExact)) { + // icmp ult (lshr X, ShAmtC), C --> icmp ult X, (C << ShAmtC) + // icmp ugt (lshr exact X, ShAmtC), C --> icmp ugt X, (C << ShAmtC) + APInt ShiftedC = C.shl(ShAmtVal); + if (ShiftedC.lshr(ShAmtVal) == C) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } + if (Pred == CmpInst::ICMP_UGT) { + // icmp ugt (lshr X, ShAmtC), C --> icmp ugt X, ((C + 1) << ShAmtC) - 1 + APInt ShiftedC = (C + 1).shl(ShAmtVal) - 1; + if ((ShiftedC + 1).lshr(ShAmtVal) == (C + 1)) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, ShiftedC)); + } } + if (!Cmp.isEquality()) + return nullptr; + // Handle equality comparisons of shift-by-constant. // If the comparison constant changes with the shift, the comparison cannot // succeed (bits of the comparison constant cannot match the shifted value). // This should be known by InstSimplify and already be folded to true/false. 
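Illustration, outside the patch: the new relational handling above turns icmp slt (ashr X, ShAmt), C into icmp slt X, (C << ShAmt) whenever (C << ShAmt) >> ShAmt round-trips back to C. Since ashr is floor division by a power of two, a quick standalone check over 8-bit values, using the ~((~v) >> s) identity to get a well-defined arithmetic shift:

    #include <cassert>

    // Arithmetic shift right expressed without implementation-defined behaviour.
    static int ashr(int V, unsigned S) { return V < 0 ? ~((~V) >> S) : V >> S; }

    int main() {
      const unsigned S = 2;
      const int C = 5;               // (5 << 2) >> 2 == 5, so the fold applies
      for (int X = -128; X <= 127; ++X)
        assert((ashr(X, S) < C) == (X < (C << S)));
      return 0;
    }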
- assert(((IsAShr && C->shl(ShAmtVal).ashr(ShAmtVal) == *C) || - (!IsAShr && C->shl(ShAmtVal).lshr(ShAmtVal) == *C)) && + assert(((IsAShr && C.shl(ShAmtVal).ashr(ShAmtVal) == C) || + (!IsAShr && C.shl(ShAmtVal).lshr(ShAmtVal) == C)) && "Expected icmp+shr simplify did not occur."); - // Check if the bits shifted out are known to be zero. If so, we can compare - // against the unshifted value: + // If the bits shifted out are known zero, compare the unshifted value: // (X & 4) >> 1 == 2 --> (X & 4) == 4. - Constant *ShiftedCmpRHS = ConstantInt::get(Shr->getType(), *C << ShAmtVal); - if (Shr->hasOneUse()) { - if (Shr->isExact()) - return new ICmpInst(Pred, X, ShiftedCmpRHS); + if (Shr->isExact()) + return new ICmpInst(Pred, X, ConstantInt::get(ShrTy, C << ShAmtVal)); - // Otherwise strength reduce the shift into an 'and'. + if (Shr->hasOneUse()) { + // Canonicalize the shift into an 'and': + // icmp eq/ne (shr X, ShAmt), C --> icmp eq/ne (and X, HiMask), (C << ShAmt) APInt Val(APInt::getHighBitsSet(TypeBits, TypeBits - ShAmtVal)); - Constant *Mask = ConstantInt::get(Shr->getType(), Val); + Constant *Mask = ConstantInt::get(ShrTy, Val); Value *And = Builder.CreateAnd(X, Mask, Shr->getName() + ".mask"); - return new ICmpInst(Pred, And, ShiftedCmpRHS); + return new ICmpInst(Pred, And, ConstantInt::get(ShrTy, C << ShAmtVal)); } return nullptr; @@ -2137,7 +2084,7 @@ Instruction *InstCombiner::foldICmpShrConstant(ICmpInst &Cmp, /// Fold icmp (udiv X, Y), C. Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, - const APInt *C) { + const APInt &C) { const APInt *C2; if (!match(UDiv->getOperand(0), m_APInt(C2))) return nullptr; @@ -2147,17 +2094,17 @@ Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, // (icmp ugt (udiv C2, Y), C) -> (icmp ule Y, C2/(C+1)) Value *Y = UDiv->getOperand(1); if (Cmp.getPredicate() == ICmpInst::ICMP_UGT) { - assert(!C->isMaxValue() && + assert(!C.isMaxValue() && "icmp ugt X, UINT_MAX should have been simplified already."); return new ICmpInst(ICmpInst::ICMP_ULE, Y, - ConstantInt::get(Y->getType(), C2->udiv(*C + 1))); + ConstantInt::get(Y->getType(), C2->udiv(C + 1))); } // (icmp ult (udiv C2, Y), C) -> (icmp ugt Y, C2/C) if (Cmp.getPredicate() == ICmpInst::ICMP_ULT) { - assert(*C != 0 && "icmp ult X, 0 should have been simplified already."); + assert(C != 0 && "icmp ult X, 0 should have been simplified already."); return new ICmpInst(ICmpInst::ICMP_UGT, Y, - ConstantInt::get(Y->getType(), C2->udiv(*C))); + ConstantInt::get(Y->getType(), C2->udiv(C))); } return nullptr; @@ -2166,7 +2113,7 @@ Instruction *InstCombiner::foldICmpUDivConstant(ICmpInst &Cmp, /// Fold icmp ({su}div X, Y), C. Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div, - const APInt *C) { + const APInt &C) { // Fold: icmp pred ([us]div X, C2), C -> range test // Fold this div into the comparison, producing a range check. // Determine, based on the divide type, what the range is being @@ -2197,28 +2144,22 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, (DivIsSigned && C2->isAllOnesValue())) return nullptr; - // TODO: We could do all of the computations below using APInt. - Constant *CmpRHS = cast<Constant>(Cmp.getOperand(1)); - Constant *DivRHS = cast<Constant>(Div->getOperand(1)); - - // Compute Prod = CmpRHS * DivRHS. We are essentially solving an equation of - // form X / C2 = C. We solve for X by multiplying C2 (DivRHS) and C (CmpRHS). + // Compute Prod = C * C2. 
We are essentially solving an equation of + // form X / C2 = C. We solve for X by multiplying C2 and C. // By solving for X, we can turn this into a range check instead of computing // a divide. - Constant *Prod = ConstantExpr::getMul(CmpRHS, DivRHS); + APInt Prod = C * *C2; // Determine if the product overflows by seeing if the product is not equal to // the divide. Make sure we do the same kind of divide as in the LHS // instruction that we're folding. - bool ProdOV = (DivIsSigned ? ConstantExpr::getSDiv(Prod, DivRHS) - : ConstantExpr::getUDiv(Prod, DivRHS)) != CmpRHS; + bool ProdOV = (DivIsSigned ? Prod.sdiv(*C2) : Prod.udiv(*C2)) != C; ICmpInst::Predicate Pred = Cmp.getPredicate(); // If the division is known to be exact, then there is no remainder from the // divide, so the covered range size is unit, otherwise it is the divisor. - Constant *RangeSize = - Div->isExact() ? ConstantInt::get(Div->getType(), 1) : DivRHS; + APInt RangeSize = Div->isExact() ? APInt(C2->getBitWidth(), 1) : *C2; // Figure out the interval that is being checked. For example, a comparison // like "X /u 5 == 0" is really checking that X is in the interval [0, 5). @@ -2228,7 +2169,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, // overflow variable is set to 0 if it's corresponding bound variable is valid // -1 if overflowed off the bottom end, or +1 if overflowed off the top end. int LoOverflow = 0, HiOverflow = 0; - Constant *LoBound = nullptr, *HiBound = nullptr; + APInt LoBound, HiBound; if (!DivIsSigned) { // udiv // e.g. X/5 op 3 --> [15, 20) @@ -2240,38 +2181,38 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, HiOverflow = addWithOverflow(HiBound, LoBound, RangeSize, false); } } else if (C2->isStrictlyPositive()) { // Divisor is > 0. - if (C->isNullValue()) { // (X / pos) op 0 + if (C.isNullValue()) { // (X / pos) op 0 // Can't overflow. e.g. X/2 op 0 --> [-1, 2) - LoBound = ConstantExpr::getNeg(SubOne(RangeSize)); + LoBound = -(RangeSize - 1); HiBound = RangeSize; - } else if (C->isStrictlyPositive()) { // (X / pos) op pos + } else if (C.isStrictlyPositive()) { // (X / pos) op pos LoBound = Prod; // e.g. X/5 op 3 --> [15, 20) HiOverflow = LoOverflow = ProdOV; if (!HiOverflow) HiOverflow = addWithOverflow(HiBound, Prod, RangeSize, true); } else { // (X / pos) op neg // e.g. X/5 op -3 --> [-15-4, -15+1) --> [-19, -14) - HiBound = AddOne(Prod); + HiBound = Prod + 1; LoOverflow = HiOverflow = ProdOV ? -1 : 0; if (!LoOverflow) { - Constant *DivNeg = ConstantExpr::getNeg(RangeSize); + APInt DivNeg = -RangeSize; LoOverflow = addWithOverflow(LoBound, HiBound, DivNeg, true) ? -1 : 0; } } } else if (C2->isNegative()) { // Divisor is < 0. if (Div->isExact()) - RangeSize = ConstantExpr::getNeg(RangeSize); - if (C->isNullValue()) { // (X / neg) op 0 + RangeSize.negate(); + if (C.isNullValue()) { // (X / neg) op 0 // e.g. X/-5 op 0 --> [-4, 5) - LoBound = AddOne(RangeSize); - HiBound = ConstantExpr::getNeg(RangeSize); - if (HiBound == DivRHS) { // -INTMIN = INTMIN + LoBound = RangeSize + 1; + HiBound = -RangeSize; + if (HiBound == *C2) { // -INTMIN = INTMIN HiOverflow = 1; // [INTMIN+1, overflow) - HiBound = nullptr; // e.g. X/INTMIN = 0 --> X > INTMIN + HiBound = APInt(); // e.g. X/INTMIN = 0 --> X > INTMIN } - } else if (C->isStrictlyPositive()) { // (X / neg) op pos + } else if (C.isStrictlyPositive()) { // (X / neg) op pos // e.g. X/-5 op 3 --> [-19, -14) - HiBound = AddOne(Prod); + HiBound = Prod + 1; HiOverflow = LoOverflow = ProdOV ? 
-1 : 0; if (!LoOverflow) LoOverflow = addWithOverflow(LoBound, HiBound, RangeSize, true) ? -1:0; @@ -2294,25 +2235,27 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, return replaceInstUsesWith(Cmp, Builder.getFalse()); if (HiOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, LoBound); + ICmpInst::ICMP_UGE, X, + ConstantInt::get(Div->getType(), LoBound)); if (LoOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, HiBound); + ICmpInst::ICMP_ULT, X, + ConstantInt::get(Div->getType(), HiBound)); return replaceInstUsesWith( - Cmp, insertRangeTest(X, LoBound->getUniqueInteger(), - HiBound->getUniqueInteger(), DivIsSigned, true)); + Cmp, insertRangeTest(X, LoBound, HiBound, DivIsSigned, true)); case ICmpInst::ICMP_NE: if (LoOverflow && HiOverflow) return replaceInstUsesWith(Cmp, Builder.getTrue()); if (HiOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SLT : - ICmpInst::ICMP_ULT, X, LoBound); + ICmpInst::ICMP_ULT, X, + ConstantInt::get(Div->getType(), LoBound)); if (LoOverflow) return new ICmpInst(DivIsSigned ? ICmpInst::ICMP_SGE : - ICmpInst::ICMP_UGE, X, HiBound); + ICmpInst::ICMP_UGE, X, + ConstantInt::get(Div->getType(), HiBound)); return replaceInstUsesWith(Cmp, - insertRangeTest(X, LoBound->getUniqueInteger(), - HiBound->getUniqueInteger(), + insertRangeTest(X, LoBound, HiBound, DivIsSigned, false)); case ICmpInst::ICMP_ULT: case ICmpInst::ICMP_SLT: @@ -2320,7 +2263,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, return replaceInstUsesWith(Cmp, Builder.getTrue()); if (LoOverflow == -1) // Low bound is less than input range. return replaceInstUsesWith(Cmp, Builder.getFalse()); - return new ICmpInst(Pred, X, LoBound); + return new ICmpInst(Pred, X, ConstantInt::get(Div->getType(), LoBound)); case ICmpInst::ICMP_UGT: case ICmpInst::ICMP_SGT: if (HiOverflow == +1) // High bound greater than input range. @@ -2328,8 +2271,10 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, if (HiOverflow == -1) // High bound less than input range. return replaceInstUsesWith(Cmp, Builder.getTrue()); if (Pred == ICmpInst::ICMP_UGT) - return new ICmpInst(ICmpInst::ICMP_UGE, X, HiBound); - return new ICmpInst(ICmpInst::ICMP_SGE, X, HiBound); + return new ICmpInst(ICmpInst::ICMP_UGE, X, + ConstantInt::get(Div->getType(), HiBound)); + return new ICmpInst(ICmpInst::ICMP_SGE, X, + ConstantInt::get(Div->getType(), HiBound)); } return nullptr; @@ -2338,7 +2283,7 @@ Instruction *InstCombiner::foldICmpDivConstant(ICmpInst &Cmp, /// Fold icmp (sub X, Y), C. 
Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, BinaryOperator *Sub, - const APInt *C) { + const APInt &C) { Value *X = Sub->getOperand(0), *Y = Sub->getOperand(1); ICmpInst::Predicate Pred = Cmp.getPredicate(); @@ -2349,19 +2294,19 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, if (Sub->hasNoSignedWrap()) { // (icmp sgt (sub nsw X, Y), -1) -> (icmp sge X, Y) - if (Pred == ICmpInst::ICMP_SGT && C->isAllOnesValue()) + if (Pred == ICmpInst::ICMP_SGT && C.isAllOnesValue()) return new ICmpInst(ICmpInst::ICMP_SGE, X, Y); // (icmp sgt (sub nsw X, Y), 0) -> (icmp sgt X, Y) - if (Pred == ICmpInst::ICMP_SGT && C->isNullValue()) + if (Pred == ICmpInst::ICMP_SGT && C.isNullValue()) return new ICmpInst(ICmpInst::ICMP_SGT, X, Y); // (icmp slt (sub nsw X, Y), 0) -> (icmp slt X, Y) - if (Pred == ICmpInst::ICMP_SLT && C->isNullValue()) + if (Pred == ICmpInst::ICMP_SLT && C.isNullValue()) return new ICmpInst(ICmpInst::ICMP_SLT, X, Y); // (icmp slt (sub nsw X, Y), 1) -> (icmp sle X, Y) - if (Pred == ICmpInst::ICMP_SLT && C->isOneValue()) + if (Pred == ICmpInst::ICMP_SLT && C.isOneValue()) return new ICmpInst(ICmpInst::ICMP_SLE, X, Y); } @@ -2371,14 +2316,14 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, // C2 - Y <u C -> (Y | (C - 1)) == C2 // iff (C2 & (C - 1)) == C - 1 and C is a power of 2 - if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && - (*C2 & (*C - 1)) == (*C - 1)) - return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateOr(Y, *C - 1), X); + if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() && + (*C2 & (C - 1)) == (C - 1)) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateOr(Y, C - 1), X); // C2 - Y >u C -> (Y | C) != C2 // iff C2 & C == C and C + 1 is a power of 2 - if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == *C) - return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, *C), X); + if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == C) + return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateOr(Y, C), X); return nullptr; } @@ -2386,7 +2331,7 @@ Instruction *InstCombiner::foldICmpSubConstant(ICmpInst &Cmp, /// Fold icmp (add X, Y), C. Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add, - const APInt *C) { + const APInt &C) { Value *Y = Add->getOperand(1); const APInt *C2; if (Cmp.isEquality() || !match(Y, m_APInt(C2))) @@ -2403,7 +2348,7 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, if (Add->hasNoSignedWrap() && (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_SLT)) { bool Overflow; - APInt NewC = C->ssub_ov(*C2, Overflow); + APInt NewC = C.ssub_ov(*C2, Overflow); // If there is overflow, the result must be true or false. // TODO: Can we assert there is no overflow because InstSimplify always // handles those cases? 
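A brief illustration, not from the commit: the add hunk above rewrites icmp sgt (add nsw X, C2), C as icmp sgt X, (C - C2) when the ssub_ov computation of C - C2 does not overflow. Over the mathematical integers this is just moving C2 to the other side; nsw guarantees the IR add agrees with integer addition. A small 8-bit check with C2 = 5 and C = 20, where X is capped so the add never wraps:

    #include <cassert>
    #include <cstdint>

    int main() {
      const int C2 = 5, C = 20;                 // C - C2 = 15 fits in int8
      for (int X = -128; X <= 122; ++X) {       // X + 5 <= 127, so "nsw" holds
        int8_t Sum = (int8_t)(X + C2);
        assert((Sum > C) == (X > C - C2));
      }
      return 0;
    }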
@@ -2412,7 +2357,7 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, return new ICmpInst(Pred, X, ConstantInt::get(Ty, NewC)); } - auto CR = ConstantRange::makeExactICmpRegion(Pred, *C).subtract(*C2); + auto CR = ConstantRange::makeExactICmpRegion(Pred, C).subtract(*C2); const APInt &Upper = CR.getUpper(); const APInt &Lower = CR.getLower(); if (Cmp.isSigned()) { @@ -2433,15 +2378,15 @@ Instruction *InstCombiner::foldICmpAddConstant(ICmpInst &Cmp, // X+C <u C2 -> (X & -C2) == C // iff C & (C2-1) == 0 // C2 is a power of 2 - if (Pred == ICmpInst::ICMP_ULT && C->isPowerOf2() && (*C2 & (*C - 1)) == 0) - return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateAnd(X, -(*C)), + if (Pred == ICmpInst::ICMP_ULT && C.isPowerOf2() && (*C2 & (C - 1)) == 0) + return new ICmpInst(ICmpInst::ICMP_EQ, Builder.CreateAnd(X, -C), ConstantExpr::getNeg(cast<Constant>(Y))); // X+C >u C2 -> (X & ~C2) != C // iff C & C2 == 0 // C2+1 is a power of 2 - if (Pred == ICmpInst::ICMP_UGT && (*C + 1).isPowerOf2() && (*C2 & *C) == 0) - return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~(*C)), + if (Pred == ICmpInst::ICMP_UGT && (C + 1).isPowerOf2() && (*C2 & C) == 0) + return new ICmpInst(ICmpInst::ICMP_NE, Builder.CreateAnd(X, ~C), ConstantExpr::getNeg(cast<Constant>(Y))); return nullptr; @@ -2471,7 +2416,7 @@ bool InstCombiner::matchThreeWayIntCompare(SelectInst *SI, Value *&LHS, } Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, - Instruction *Select, + SelectInst *Select, ConstantInt *C) { assert(C && "Cmp RHS should be a constant int!"); @@ -2483,8 +2428,8 @@ Instruction *InstCombiner::foldICmpSelectConstant(ICmpInst &Cmp, Value *OrigLHS, *OrigRHS; ConstantInt *C1LessThan, *C2Equal, *C3GreaterThan; if (Cmp.hasOneUse() && - matchThreeWayIntCompare(cast<SelectInst>(Select), OrigLHS, OrigRHS, - C1LessThan, C2Equal, C3GreaterThan)) { + matchThreeWayIntCompare(Select, OrigLHS, OrigRHS, C1LessThan, C2Equal, + C3GreaterThan)) { assert(C1LessThan && C2Equal && C3GreaterThan); bool TrueWhenLessThan = @@ -2525,82 +2470,74 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { if (!match(Cmp.getOperand(1), m_APInt(C))) return nullptr; - BinaryOperator *BO; - if (match(Cmp.getOperand(0), m_BinOp(BO))) { + if (auto *BO = dyn_cast<BinaryOperator>(Cmp.getOperand(0))) { switch (BO->getOpcode()) { case Instruction::Xor: - if (Instruction *I = foldICmpXorConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpXorConstant(Cmp, BO, *C)) return I; break; case Instruction::And: - if (Instruction *I = foldICmpAndConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpAndConstant(Cmp, BO, *C)) return I; break; case Instruction::Or: - if (Instruction *I = foldICmpOrConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpOrConstant(Cmp, BO, *C)) return I; break; case Instruction::Mul: - if (Instruction *I = foldICmpMulConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpMulConstant(Cmp, BO, *C)) return I; break; case Instruction::Shl: - if (Instruction *I = foldICmpShlConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpShlConstant(Cmp, BO, *C)) return I; break; case Instruction::LShr: case Instruction::AShr: - if (Instruction *I = foldICmpShrConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpShrConstant(Cmp, BO, *C)) return I; break; case Instruction::UDiv: - if (Instruction *I = foldICmpUDivConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpUDivConstant(Cmp, BO, *C)) return I; LLVM_FALLTHROUGH; case Instruction::SDiv: - if (Instruction *I = foldICmpDivConstant(Cmp, BO, C)) + if (Instruction 
*I = foldICmpDivConstant(Cmp, BO, *C)) return I; break; case Instruction::Sub: - if (Instruction *I = foldICmpSubConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpSubConstant(Cmp, BO, *C)) return I; break; case Instruction::Add: - if (Instruction *I = foldICmpAddConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpAddConstant(Cmp, BO, *C)) return I; break; default: break; } // TODO: These folds could be refactored to be part of the above calls. - if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, C)) + if (Instruction *I = foldICmpBinOpEqualityWithConstant(Cmp, BO, *C)) return I; } // Match against CmpInst LHS being instructions other than binary operators. - Instruction *LHSI; - if (match(Cmp.getOperand(0), m_Instruction(LHSI))) { - switch (LHSI->getOpcode()) { - case Instruction::Select: - { - // For now, we only support constant integers while folding the - // ICMP(SELECT)) pattern. We can extend this to support vector of integers - // similar to the cases handled by binary ops above. - if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1))) - if (Instruction *I = foldICmpSelectConstant(Cmp, LHSI, ConstRHS)) - return I; - break; - } - case Instruction::Trunc: - if (Instruction *I = foldICmpTruncConstant(Cmp, LHSI, C)) + + if (auto *SI = dyn_cast<SelectInst>(Cmp.getOperand(0))) { + // For now, we only support constant integers while folding the + // ICMP(SELECT)) pattern. We can extend this to support vector of integers + // similar to the cases handled by binary ops above. + if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(Cmp.getOperand(1))) + if (Instruction *I = foldICmpSelectConstant(Cmp, SI, ConstRHS)) return I; - break; - default: - break; - } } - if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, C)) + if (auto *TI = dyn_cast<TruncInst>(Cmp.getOperand(0))) { + if (Instruction *I = foldICmpTruncConstant(Cmp, TI, *C)) + return I; + } + + if (Instruction *I = foldICmpIntrinsicWithConstant(Cmp, *C)) return I; return nullptr; @@ -2610,7 +2547,7 @@ Instruction *InstCombiner::foldICmpInstWithConstant(ICmpInst &Cmp) { /// icmp eq/ne BO, C. Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, BinaryOperator *BO, - const APInt *C) { + const APInt &C) { // TODO: Some of these folds could work with arbitrary constants, but this // function is limited to scalar and vector splat constants. if (!Cmp.isEquality()) @@ -2624,7 +2561,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, switch (BO->getOpcode()) { case Instruction::SRem: // If we have a signed (X % (2^c)) == 0, turn it into an unsigned one. - if (C->isNullValue() && BO->hasOneUse()) { + if (C.isNullValue() && BO->hasOneUse()) { const APInt *BOC; if (match(BOp1, m_APInt(BOC)) && BOC->sgt(1) && BOC->isPowerOf2()) { Value *NewRem = Builder.CreateURem(BOp0, BOp1, BO->getName()); @@ -2641,7 +2578,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, Constant *SubC = ConstantExpr::getSub(RHS, cast<Constant>(BOp1)); return new ICmpInst(Pred, BOp0, SubC); } - } else if (C->isNullValue()) { + } else if (C.isNullValue()) { // Replace ((add A, B) != 0) with (A != -B) if A or B is // efficiently invertible, or if the add has just this one use. if (Value *NegVal = dyn_castNegVal(BOp1)) @@ -2662,7 +2599,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, // For the xor case, we can xor two constants together, eliminating // the explicit xor. 
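For illustration: the SRem case above replaces a signed (X % 2^k) == 0 test with an unsigned remainder, which is safe because a power of two divides X exactly when it divides X reinterpreted as an unsigned value of the same width. A quick exhaustive 8-bit check for k = 2:

    #include <cassert>
    #include <cstdint>

    int main() {
      for (int X = -128; X <= 127; ++X) {
        bool SignedRem = ((int8_t)X % 4) == 0;          // srem-style test
        bool UnsignedRem = ((uint8_t)X % 4) == 0;       // urem-style test
        assert(SignedRem == UnsignedRem);
      }
      return 0;
    }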
return new ICmpInst(Pred, BOp0, ConstantExpr::getXor(RHS, BOC)); - } else if (C->isNullValue()) { + } else if (C.isNullValue()) { // Replace ((xor A, B) != 0) with (A != B) return new ICmpInst(Pred, BOp0, BOp1); } @@ -2675,7 +2612,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, // Replace ((sub BOC, B) != C) with (B != BOC-C). Constant *SubC = ConstantExpr::getSub(cast<Constant>(BOp0), RHS); return new ICmpInst(Pred, BOp1, SubC); - } else if (C->isNullValue()) { + } else if (C.isNullValue()) { // Replace ((sub A, B) != 0) with (A != B). return new ICmpInst(Pred, BOp0, BOp1); } @@ -2697,7 +2634,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, const APInt *BOC; if (match(BOp1, m_APInt(BOC))) { // If we have ((X & C) == C), turn it into ((X & C) != 0). - if (C == BOC && C->isPowerOf2()) + if (C == *BOC && C.isPowerOf2()) return new ICmpInst(isICMP_NE ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, BO, Constant::getNullValue(RHS->getType())); @@ -2713,7 +2650,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, } // ((X & ~7) == 0) --> X < 8 - if (C->isNullValue() && (~(*BOC) + 1).isPowerOf2()) { + if (C.isNullValue() && (~(*BOC) + 1).isPowerOf2()) { Constant *NegBOC = ConstantExpr::getNeg(cast<Constant>(BOp1)); auto NewPred = isICMP_NE ? ICmpInst::ICMP_UGE : ICmpInst::ICMP_ULT; return new ICmpInst(NewPred, BOp0, NegBOC); @@ -2722,7 +2659,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, break; } case Instruction::Mul: - if (C->isNullValue() && BO->hasNoSignedWrap()) { + if (C.isNullValue() && BO->hasNoSignedWrap()) { const APInt *BOC; if (match(BOp1, m_APInt(BOC)) && !BOC->isNullValue()) { // The trivial case (mul X, 0) is handled by InstSimplify. @@ -2733,7 +2670,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, } break; case Instruction::UDiv: - if (C->isNullValue()) { + if (C.isNullValue()) { // (icmp eq/ne (udiv A, B), 0) -> (icmp ugt/ule i32 B, A) auto NewPred = isICMP_NE ? ICmpInst::ICMP_ULE : ICmpInst::ICMP_UGT; return new ICmpInst(NewPred, BOp1, BOp0); @@ -2747,7 +2684,7 @@ Instruction *InstCombiner::foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, /// Fold an icmp with LLVM intrinsic and constant operand: icmp Pred II, C. 
Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, - const APInt *C) { + const APInt &C) { IntrinsicInst *II = dyn_cast<IntrinsicInst>(Cmp.getOperand(0)); if (!II || !Cmp.isEquality()) return nullptr; @@ -2758,13 +2695,13 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, case Intrinsic::bswap: Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); - Cmp.setOperand(1, ConstantInt::get(Ty, C->byteSwap())); + Cmp.setOperand(1, ConstantInt::get(Ty, C.byteSwap())); return &Cmp; case Intrinsic::ctlz: case Intrinsic::cttz: // ctz(A) == bitwidth(A) -> A == 0 and likewise for != - if (*C == C->getBitWidth()) { + if (C == C.getBitWidth()) { Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); Cmp.setOperand(1, ConstantInt::getNullValue(Ty)); @@ -2775,8 +2712,8 @@ Instruction *InstCombiner::foldICmpIntrinsicWithConstant(ICmpInst &Cmp, case Intrinsic::ctpop: { // popcount(A) == 0 -> A == 0 and likewise for != // popcount(A) == bitwidth(A) -> A == -1 and likewise for != - bool IsZero = C->isNullValue(); - if (IsZero || *C == C->getBitWidth()) { + bool IsZero = C.isNullValue(); + if (IsZero || C == C.getBitWidth()) { Worklist.Add(II); Cmp.setOperand(0, II->getArgOperand(0)); auto *NewOp = @@ -3924,31 +3861,29 @@ static Instruction *processUMulZExtIdiom(ICmpInst &I, Value *MulVal, /// When performing a comparison against a constant, it is possible that not all /// the bits in the LHS are demanded. This helper method computes the mask that /// IS demanded. -static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth, - bool isSignCheck) { - if (isSignCheck) - return APInt::getSignMask(BitWidth); +static APInt getDemandedBitsLHSMask(ICmpInst &I, unsigned BitWidth) { + const APInt *RHS; + if (!match(I.getOperand(1), m_APInt(RHS))) + return APInt::getAllOnesValue(BitWidth); - ConstantInt *CI = dyn_cast<ConstantInt>(I.getOperand(1)); - if (!CI) return APInt::getAllOnesValue(BitWidth); - const APInt &RHS = CI->getValue(); + // If this is a normal comparison, it demands all bits. If it is a sign bit + // comparison, it only demands the sign bit. + bool UnusedBit; + if (isSignBitCheck(I.getPredicate(), *RHS, UnusedBit)) + return APInt::getSignMask(BitWidth); switch (I.getPredicate()) { // For a UGT comparison, we don't care about any bits that // correspond to the trailing ones of the comparand. The value of these // bits doesn't impact the outcome of the comparison, because any value // greater than the RHS must differ in a bit higher than these due to carry. - case ICmpInst::ICMP_UGT: { - unsigned trailingOnes = RHS.countTrailingOnes(); - return APInt::getBitsSetFrom(BitWidth, trailingOnes); - } + case ICmpInst::ICMP_UGT: + return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingOnes()); // Similarly, for a ULT comparison, we don't care about the trailing zeros. // Any value less than the RHS must differ in a higher bit because of carries. - case ICmpInst::ICMP_ULT: { - unsigned trailingZeros = RHS.countTrailingZeros(); - return APInt::getBitsSetFrom(BitWidth, trailingZeros); - } + case ICmpInst::ICMP_ULT: + return APInt::getBitsSetFrom(BitWidth, RHS->countTrailingZeros()); default: return APInt::getAllOnesValue(BitWidth); @@ -4122,20 +4057,11 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { if (!BitWidth) return nullptr; - // If this is a normal comparison, it demands all bits. If it is a sign bit - // comparison, it only demands the sign bit. 
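An aside, not part of the change: getDemandedBitsLHSMask drops the bits that line up with the trailing ones of the RHS for a UGT compare, because X >u C only depends on the bits of X above the trailing-ones run of C. A standalone 8-bit check with C = 0x27, which has three trailing ones:

    #include <cassert>

    int main() {
      const unsigned C = 0x27;                  // 0b00100111, 3 trailing ones
      for (unsigned X = 0; X <= 0xFF; ++X)
        for (unsigned Low = 0; Low <= 7; ++Low) {
          unsigned XPerturbed = (X & ~7u) | Low; // rewrite the low 3 bits of X
          assert((XPerturbed > C) == (X > C));   // the result does not change
        }
      return 0;
    }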
- bool IsSignBit = false; - const APInt *CmpC; - if (match(Op1, m_APInt(CmpC))) { - bool UnusedBit; - IsSignBit = isSignBitCheck(Pred, *CmpC, UnusedBit); - } - KnownBits Op0Known(BitWidth); KnownBits Op1Known(BitWidth); if (SimplifyDemandedBits(&I, 0, - getDemandedBitsLHSMask(I, BitWidth, IsSignBit), + getDemandedBitsLHSMask(I, BitWidth), Op0Known, 0)) return &I; @@ -4233,20 +4159,22 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { const APInt *CmpC; if (match(Op1, m_APInt(CmpC))) { // A <u C -> A == C-1 if min(A)+1 == C - if (Op1Max == Op0Min + 1) { - Constant *CMinus1 = ConstantInt::get(Op0->getType(), *CmpC - 1); - return new ICmpInst(ICmpInst::ICMP_EQ, Op0, CMinus1); - } + if (*CmpC == Op0Min + 1) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + ConstantInt::get(Op1->getType(), *CmpC - 1)); + // X <u C --> X == 0, if the number of zero bits in the bottom of X + // exceeds the log2 of C. + if (Op0Known.countMinTrailingZeros() >= CmpC->ceilLogBase2()) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, + Constant::getNullValue(Op1->getType())); } break; } case ICmpInst::ICMP_UGT: { if (Op0Min.ugt(Op1Max)) // A >u B -> true if min(A) > max(B) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); - if (Op0Max.ule(Op1Min)) // A >u B -> false if max(A) <= max(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Max == Op0Min) // A >u B -> A != B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); @@ -4256,42 +4184,52 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { if (*CmpC == Op0Max - 1) return new ICmpInst(ICmpInst::ICMP_EQ, Op0, ConstantInt::get(Op1->getType(), *CmpC + 1)); + // X >u C --> X != 0, if the number of zero bits in the bottom of X + // exceeds the log2 of C. 
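Illustration only: the new known-bits folds in this hunk turn X <u C into X == 0 when X has at least ceil(log2 C) known trailing zero bits, and X >u C into X != 0 when the trailing zeros cover all active bits of C. Modelling "known zeros" by restricting X to multiples of 16, a standalone 8-bit check:

    #include <cassert>

    int main() {
      for (unsigned X = 0; X <= 0xF0; X += 16) {   // low 4 bits of X known zero
        for (unsigned C = 1; C <= 16; ++C)         // ceil(log2 C) <= 4
          assert((X < C) == (X == 0));
        for (unsigned C = 0; C <= 15; ++C)         // active bits of C <= 4
          assert((X > C) == (X != 0));
      }
      return 0;
    }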
+ if (Op0Known.countMinTrailingZeros() >= CmpC->getActiveBits()) + return new ICmpInst(ICmpInst::ICMP_NE, Op0, + Constant::getNullValue(Op1->getType())); } break; } - case ICmpInst::ICMP_SLT: + case ICmpInst::ICMP_SLT: { if (Op0Max.slt(Op1Min)) // A <s B -> true if max(A) < min(C) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.sge(Op1Max)) // A <s B -> false if min(A) >= max(C) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); if (Op1Min == Op0Max) // A <s B -> A != B if max(A) == min(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Max == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + if (*CmpC == Op0Min + 1) // A <s C -> A == C-1 if min(A)+1 == C return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder.getInt(CI->getValue() - 1)); + ConstantInt::get(Op1->getType(), *CmpC - 1)); } break; - case ICmpInst::ICMP_SGT: + } + case ICmpInst::ICMP_SGT: { if (Op0Min.sgt(Op1Max)) // A >s B -> true if min(A) > max(B) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.sle(Op1Min)) // A >s B -> false if max(A) <= min(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); - if (Op1Max == Op0Min) // A >s B -> A != B if min(A) == max(B) return new ICmpInst(ICmpInst::ICMP_NE, Op0, Op1); - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op1)) { - if (Op1Min == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C + const APInt *CmpC; + if (match(Op1, m_APInt(CmpC))) { + if (*CmpC == Op0Max - 1) // A >s C -> A == C+1 if max(A)-1 == C return new ICmpInst(ICmpInst::ICMP_EQ, Op0, - Builder.getInt(CI->getValue() + 1)); + ConstantInt::get(Op1->getType(), *CmpC + 1)); } break; + } case ICmpInst::ICMP_SGE: assert(!isa<ConstantInt>(Op1) && "ICMP_SGE with ConstantInt not folded!"); if (Op0Min.sge(Op1Max)) // A >=s B -> true if min(A) >= max(B) return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.slt(Op1Min)) // A >=s B -> false if max(A) < min(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A >=s B -> A == B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); break; case ICmpInst::ICMP_SLE: assert(!isa<ConstantInt>(Op1) && "ICMP_SLE with ConstantInt not folded!"); @@ -4299,6 +4237,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Min.sgt(Op1Max)) // A <=s B -> false if min(A) > max(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Max == Op0Min) // A <=s B -> A == B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); break; case ICmpInst::ICMP_UGE: assert(!isa<ConstantInt>(Op1) && "ICMP_UGE with ConstantInt not folded!"); @@ -4306,6 +4246,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if (Op0Max.ult(Op1Min)) // A >=u B -> false if max(A) < min(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Min == Op0Max) // A >=u B -> A == B if max(A) == min(B) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); break; case ICmpInst::ICMP_ULE: assert(!isa<ConstantInt>(Op1) && "ICMP_ULE with ConstantInt not folded!"); @@ -4313,6 +4255,8 @@ Instruction *InstCombiner::foldICmpUsingKnownBits(ICmpInst &I) { return replaceInstUsesWith(I, ConstantInt::getTrue(I.getType())); if 
(Op0Min.ugt(Op1Max)) // A <=u B -> false if min(A) > max(B) return replaceInstUsesWith(I, ConstantInt::getFalse(I.getType())); + if (Op1Max == Op0Min) // A <=u B -> A == B if min(A) == max(B) + return new ICmpInst(ICmpInst::ICMP_EQ, Op0, Op1); break; } @@ -4478,7 +4422,7 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - // comparing -val or val with non-zero is the same as just comparing val + // Comparing -val or val with non-zero is the same as just comparing val // ie, abs(val) != 0 -> val != 0 if (I.getPredicate() == ICmpInst::ICMP_NE && match(Op1, m_Zero())) { Value *Cond, *SelectTrue, *SelectFalse; @@ -4515,11 +4459,19 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { // and CodeGen. And in this case, at least one of the comparison // operands has at least one user besides the compare (the select), // which would often largely negate the benefit of folding anyway. + // + // Do the same for the other patterns recognized by matchSelectPattern. if (I.hasOneUse()) - if (SelectInst *SI = dyn_cast<SelectInst>(*I.user_begin())) - if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) || - (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) + if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) { + Value *A, *B; + SelectPatternResult SPR = matchSelectPattern(SI, A, B); + if (SPR.Flavor != SPF_UNKNOWN) return nullptr; + } + + // Do this after checking for min/max to prevent infinite looping. + if (Instruction *Res = foldICmpWithZero(I)) + return Res; // FIXME: We only do this after checking for min/max to prevent infinite // looping caused by a reverse canonicalization of these patterns for min/max. @@ -4684,11 +4636,11 @@ Instruction *InstCombiner::visitICmpInst(ICmpInst &I) { Value *X; ConstantInt *Cst; // icmp X+Cst, X if (match(Op0, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op1 == X) - return foldICmpAddOpConst(I, X, Cst, I.getPredicate()); + return foldICmpAddOpConst(X, Cst, I.getPredicate()); // icmp X, X+Cst if (match(Op1, m_Add(m_Value(X), m_ConstantInt(Cst))) && Op0 == X) - return foldICmpAddOpConst(I, X, Cst, I.getSwappedPredicate()); + return foldICmpAddOpConst(X, Cst, I.getSwappedPredicate()); } return Changed ? &I : nullptr; } @@ -4943,17 +4895,16 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { Changed = true; } + const CmpInst::Predicate Pred = I.getPredicate(); Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); - - if (Value *V = - SimplifyFCmpInst(I.getPredicate(), Op0, Op1, I.getFastMathFlags(), - SQ.getWithInstruction(&I))) + if (Value *V = SimplifyFCmpInst(Pred, Op0, Op1, I.getFastMathFlags(), + SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); // Simplify 'fcmp pred X, X' if (Op0 == Op1) { - switch (I.getPredicate()) { - default: llvm_unreachable("Unknown predicate!"); + switch (Pred) { + default: break; case FCmpInst::FCMP_UNO: // True if unordered: isnan(X) | isnan(Y) case FCmpInst::FCMP_ULT: // True if unordered or less than case FCmpInst::FCMP_UGT: // True if unordered or greater than @@ -4974,6 +4925,19 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { } } + // If we're just checking for a NaN (ORD/UNO) and have a non-NaN operand, + // then canonicalize the operand to 0.0. 
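A short aside, not from the patch: fcmp ord X, Y is true exactly when neither operand is NaN, so once one operand is known never to be NaN the check collapses onto the other operand, and the non-NaN side can be canonicalized to 0.0 as the code added here does. A small standalone sketch of the underlying identity:

    #include <cassert>
    #include <cmath>
    #include <limits>

    // "ord" is !isnan(A) && !isnan(B); "uno" is its negation, so the same
    // argument applies to both predicates.
    static bool ord(double A, double B) { return !std::isnan(A) && !std::isnan(B); }

    int main() {
      const double NaN = std::numeric_limits<double>::quiet_NaN();
      const double Xs[] = {0.0, -1.5, 3.0, NaN};
      for (double X : Xs) {
        // 42.0 is known not to be NaN, so the second operand is irrelevant
        // and may just as well be 0.0.
        assert(ord(X, 42.0) == ord(X, 0.0));
      }
      return 0;
    }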
+ if (Pred == CmpInst::FCMP_ORD || Pred == CmpInst::FCMP_UNO) { + if (!match(Op0, m_Zero()) && isKnownNeverNaN(Op0)) { + I.setOperand(0, ConstantFP::getNullValue(Op0->getType())); + return &I; + } + if (!match(Op1, m_Zero()) && isKnownNeverNaN(Op1)) { + I.setOperand(1, ConstantFP::getNullValue(Op0->getType())); + return &I; + } + } + // Test if the FCmpInst instruction is used exclusively by a select as // part of a minimum or maximum operation. If so, refrain from doing // any other folding. This helps out other analyses which understand @@ -4982,10 +4946,12 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { // operands has at least one user besides the compare (the select), // which would often largely negate the benefit of folding anyway. if (I.hasOneUse()) - if (SelectInst *SI = dyn_cast<SelectInst>(*I.user_begin())) - if ((SI->getOperand(1) == Op0 && SI->getOperand(2) == Op1) || - (SI->getOperand(2) == Op0 && SI->getOperand(1) == Op1)) + if (SelectInst *SI = dyn_cast<SelectInst>(I.user_back())) { + Value *A, *B; + SelectPatternResult SPR = matchSelectPattern(SI, A, B); + if (SPR.Flavor != SPF_UNKNOWN) return nullptr; + } // Handle fcmp with constant RHS if (Constant *RHSC = dyn_cast<Constant>(Op1)) { @@ -5027,7 +4993,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { ((Fabs.compare(APFloat::getSmallestNormalized(*Sem)) != APFloat::cmpLessThan) || Fabs.isZero())) - return new FCmpInst(I.getPredicate(), LHSExt->getOperand(0), + return new FCmpInst(Pred, LHSExt->getOperand(0), ConstantFP::get(RHSC->getContext(), F)); break; } @@ -5072,7 +5038,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { break; // Various optimization for fabs compared with zero. - switch (I.getPredicate()) { + switch (Pred) { default: break; // fabs(x) < 0 --> false @@ -5093,7 +5059,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { case FCmpInst::FCMP_UEQ: case FCmpInst::FCMP_ONE: case FCmpInst::FCMP_UNE: - return new FCmpInst(I.getPredicate(), CI->getArgOperand(0), RHSC); + return new FCmpInst(Pred, CI->getArgOperand(0), RHSC); } } } @@ -5108,8 +5074,7 @@ Instruction *InstCombiner::visitFCmpInst(FCmpInst &I) { if (FPExtInst *LHSExt = dyn_cast<FPExtInst>(Op0)) if (FPExtInst *RHSExt = dyn_cast<FPExtInst>(Op1)) if (LHSExt->getSrcTy() == RHSExt->getSrcTy()) - return new FCmpInst(I.getPredicate(), LHSExt->getOperand(0), - RHSExt->getOperand(0)); + return new FCmpInst(Pred, LHSExt->getOperand(0), RHSExt->getOperand(0)); return Changed ? &I : nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineInternal.h b/lib/Transforms/InstCombine/InstCombineInternal.h index c38a4981bf1d..f1f66d86cb73 100644 --- a/lib/Transforms/InstCombine/InstCombineInternal.h +++ b/lib/Transforms/InstCombine/InstCombineInternal.h @@ -6,42 +6,59 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// /// \file /// /// This file provides internal interfaces used to implement the InstCombine. 
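A quick standalone check of the identity the ORD/UNO canonicalization above relies on (plain C++, not the LLVM API; the value 42.0 merely stands in for an operand that isKnownNeverNaN would prove non-NaN): "fcmp ord" is true exactly when neither operand is NaN, so once one operand is known non-NaN the result depends only on the other operand, and comparing against 0.0 gives the same answer.

#include <cassert>
#include <cmath>
#include <limits>

// ord(a, b): true iff neither operand is NaN; uno is its negation.
static bool fcmpOrd(double a, double b) { return !std::isnan(a) && !std::isnan(b); }

int main() {
  const double NonNaN = 42.0; // stands in for an isKnownNeverNaN operand
  double samples[] = {1.5, -0.0, std::numeric_limits<double>::infinity(),
                      std::numeric_limits<double>::quiet_NaN()};
  for (double v : samples) {
    assert(fcmpOrd(v, NonNaN) == fcmpOrd(v, 0.0)); // ord x, y  ->  ord x, 0.0
    assert(!fcmpOrd(v, NonNaN) == std::isnan(v));  // uno x, y  ->  uno x, 0.0
  }
  return 0;
}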
-/// +// //===----------------------------------------------------------------------===// #ifndef LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H #define LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H +#include "llvm/ADT/ArrayRef.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Dominators.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Operator.h" -#include "llvm/IR/PatternMatch.h" -#include "llvm/Pass.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/Utils/Local.h" +#include <cassert> +#include <cstdint> #define DEBUG_TYPE "instcombine" namespace llvm { + +class APInt; +class AssumptionCache; class CallSite; class DataLayout; class DominatorTree; +class GEPOperator; +class GlobalVariable; +class LoopInfo; +class OptimizationRemarkEmitter; class TargetLibraryInfo; -class DbgDeclareInst; -class MemIntrinsic; -class MemSetInst; +class User; /// Assign a complexity or rank value to LLVM Values. This is used to reduce /// the amount of pattern matching needed for compares and commutative @@ -109,6 +126,7 @@ static inline Value *peekThroughBitcast(Value *V, bool OneUseOnly = false) { static inline Constant *AddOne(Constant *C) { return ConstantExpr::getAdd(C, ConstantInt::get(C->getType(), 1)); } + /// \brief Subtract one from a Constant static inline Constant *SubOne(Constant *C) { return ConstantExpr::getSub(C, ConstantInt::get(C->getType(), 1)); @@ -118,7 +136,6 @@ static inline Constant *SubOne(Constant *C) { /// This happens in cases where the ~ can be eliminated. If WillInvertAllUses /// is true, work under the assumption that the caller intends to remove all /// uses of V and only keep uses of ~V. -/// static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { // ~(~(X)) -> X. if (BinaryOperator::isNot(V)) @@ -161,7 +178,6 @@ static inline bool IsFreeToInvert(Value *V, bool WillInvertAllUses) { return false; } - /// \brief Specific patterns of overflow check idioms that we match. enum OverflowCheckFlavor { OCF_UNSIGNED_ADD, @@ -209,12 +225,13 @@ public: /// \brief An IRBuilder that automatically inserts new instructions into the /// worklist. - typedef IRBuilder<TargetFolder, IRBuilderCallbackInserter> BuilderTy; + using BuilderTy = IRBuilder<TargetFolder, IRBuilderCallbackInserter>; BuilderTy &Builder; private: // Mode in which we are running the combiner. const bool MinimizeSize; + /// Enable combines that trigger rarely but are costly in compiletime. const bool ExpensiveCombines; @@ -226,20 +243,23 @@ private: DominatorTree &DT; const DataLayout &DL; const SimplifyQuery SQ; + OptimizationRemarkEmitter &ORE; + // Optional analyses. When non-null, these can both be used to do better // combining and will be updated to reflect any changes. 
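As context for the IsFreeToInvert helper above: a value is "free to invert" when the logical not can be absorbed instead of materialized, the obvious case being a double negation, and compares are typically free as well because their negation is just the inverse predicate. A minimal standalone illustration of that principle (not LLVM code):

#include <cassert>

int main() {
  for (int a = -3; a <= 3; ++a)
    for (int b = -3; b <= 3; ++b) {
      assert(!(a < b) == (a >= b));  // not(slt) is sge: no extra xor needed
      assert(!(a == b) == (a != b)); // not(eq) is ne
    }
  return 0;
}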
LoopInfo *LI; - bool MadeIRChange; + bool MadeIRChange = false; public: InstCombiner(InstCombineWorklist &Worklist, BuilderTy &Builder, bool MinimizeSize, bool ExpensiveCombines, AliasAnalysis *AA, AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT, - const DataLayout &DL, LoopInfo *LI) + OptimizationRemarkEmitter &ORE, const DataLayout &DL, + LoopInfo *LI) : Worklist(Worklist), Builder(Builder), MinimizeSize(MinimizeSize), ExpensiveCombines(ExpensiveCombines), AA(AA), AC(AC), TLI(TLI), DT(DT), - DL(DL), SQ(DL, &TLI, &DT, &AC), LI(LI), MadeIRChange(false) {} + DL(DL), SQ(DL, &TLI, &DT, &AC), ORE(ORE), LI(LI) {} /// \brief Run the combiner over the entire worklist until it is empty. /// @@ -275,7 +295,7 @@ public: Instruction *visitURem(BinaryOperator &I); Instruction *visitSRem(BinaryOperator &I); Instruction *visitFRem(BinaryOperator &I); - bool SimplifyDivRemOfSelect(BinaryOperator &I); + bool simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I); Instruction *commonRemTransforms(BinaryOperator &I); Instruction *commonIRemTransforms(BinaryOperator &I); Instruction *commonDivTransforms(BinaryOperator &I); @@ -411,32 +431,38 @@ private: bool DoTransform = true); Instruction *transformSExtICmp(ICmpInst *ICI, Instruction &CI); + bool willNotOverflowSignedAdd(const Value *LHS, const Value *RHS, const Instruction &CxtI) const { return computeOverflowForSignedAdd(LHS, RHS, &CxtI) == OverflowResult::NeverOverflows; - }; + } + bool willNotOverflowUnsignedAdd(const Value *LHS, const Value *RHS, const Instruction &CxtI) const { return computeOverflowForUnsignedAdd(LHS, RHS, &CxtI) == OverflowResult::NeverOverflows; - }; + } + bool willNotOverflowSignedSub(const Value *LHS, const Value *RHS, const Instruction &CxtI) const; bool willNotOverflowUnsignedSub(const Value *LHS, const Value *RHS, const Instruction &CxtI) const; bool willNotOverflowSignedMul(const Value *LHS, const Value *RHS, const Instruction &CxtI) const; + bool willNotOverflowUnsignedMul(const Value *LHS, const Value *RHS, const Instruction &CxtI) const { return computeOverflowForUnsignedMul(LHS, RHS, &CxtI) == OverflowResult::NeverOverflows; - }; + } + Value *EmitGEPOffset(User *GEP); Instruction *scalarizePHI(ExtractElementInst &EI, PHINode *PN); Value *EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask); Instruction *foldCastedBitwiseLogic(BinaryOperator &I); - Instruction *shrinkBitwiseLogic(TruncInst &Trunc); + Instruction *narrowBinOp(TruncInst &Trunc); + Instruction *narrowRotate(TruncInst &Trunc); Instruction *optimizeBitCastFromPhi(CastInst &CI, PHINode *PN); /// Determine if a pair of casts can be replaced by a single cast. @@ -453,11 +479,14 @@ private: const CastInst *CI2); Value *foldAndOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI); - Value *foldAndOfFCmps(FCmpInst *LHS, FCmpInst *RHS); Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, Instruction &CxtI); - Value *foldOrOfFCmps(FCmpInst *LHS, FCmpInst *RHS); Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS); + /// Optimize (fcmp)&(fcmp) or (fcmp)|(fcmp). + /// NOTE: Unlike most of instcombine, this returns a Value which should + /// already be inserted into the function. 
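The willNotOverflow* wrappers above simply ask the value-tracking layer whether an operation can be proven non-overflowing. As a standalone sketch of the kind of range reasoning behind such a proof (the Range struct and its bounds are illustrative assumptions, not the in-tree implementation):

#include <cassert>
#include <cstdint>
#include <limits>

struct Range { int32_t Min, Max; }; // hypothetical known range of a value

// Addition is monotone in both operands, so checking the extreme sums suffices.
static bool addNeverOverflows(Range A, Range B) {
  int64_t Lo = (int64_t)A.Min + B.Min;
  int64_t Hi = (int64_t)A.Max + B.Max;
  return Lo >= std::numeric_limits<int32_t>::min() &&
         Hi <= std::numeric_limits<int32_t>::max();
}

int main() {
  assert(addNeverOverflows({0, 100}, {0, 100}));      // provably safe
  assert(!addNeverOverflows({0, INT32_MAX}, {1, 1})); // may overflow
  return 0;
}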
+ Value *foldLogicOfFCmps(FCmpInst *LHS, FCmpInst *RHS, bool IsAnd); + Value *foldAndOrOfICmpsOfAndWithPow2(ICmpInst *LHS, ICmpInst *RHS, bool JoinedByAnd, Instruction &CxtI); public: @@ -542,6 +571,7 @@ public: unsigned Depth, const Instruction *CxtI) const { llvm::computeKnownBits(V, Known, DL, Depth, &AC, CxtI, &DT); } + KnownBits computeKnownBits(const Value *V, unsigned Depth, const Instruction *CxtI) const { return llvm::computeKnownBits(V, DL, Depth, &AC, CxtI, &DT); @@ -557,20 +587,24 @@ public: const Instruction *CxtI = nullptr) const { return llvm::MaskedValueIsZero(V, Mask, DL, Depth, &AC, CxtI, &DT); } + unsigned ComputeNumSignBits(const Value *Op, unsigned Depth = 0, const Instruction *CxtI = nullptr) const { return llvm::ComputeNumSignBits(Op, DL, Depth, &AC, CxtI, &DT); } + OverflowResult computeOverflowForUnsignedMul(const Value *LHS, const Value *RHS, const Instruction *CxtI) const { return llvm::computeOverflowForUnsignedMul(LHS, RHS, DL, &AC, CxtI, &DT); } + OverflowResult computeOverflowForUnsignedAdd(const Value *LHS, const Value *RHS, const Instruction *CxtI) const { return llvm::computeOverflowForUnsignedAdd(LHS, RHS, DL, &AC, CxtI, &DT); } + OverflowResult computeOverflowForSignedAdd(const Value *LHS, const Value *RHS, const Instruction *CxtI) const { @@ -594,6 +628,11 @@ private: /// value, or null if it didn't simplify. Value *SimplifyUsingDistributiveLaws(BinaryOperator &I); + // Binary Op helper for select operations where the expression can be + // efficiently reorganized. + Value *SimplifySelectsFeedingBinaryOp(BinaryOperator &I, Value *LHS, + Value *RHS); + /// This tries to simplify binary operations by factorizing out common terms /// (e. g. "(A*B)+(A*C)" -> "A*(B+C)"). Value *tryFactorization(BinaryOperator &, Instruction::BinaryOps, Value *, @@ -615,6 +654,7 @@ private: bool SimplifyDemandedBits(Instruction *I, unsigned Op, const APInt &DemandedMask, KnownBits &Known, unsigned Depth = 0); + /// Helper routine of SimplifyDemandedUseBits. It computes KnownZero/KnownOne /// bits. It also tries to handle simplifications that can be done based on /// DemandedMask, but without modifying the Instruction. @@ -622,6 +662,7 @@ private: const APInt &DemandedMask, KnownBits &Known, unsigned Depth, Instruction *CxtI); + /// Helper routine of SimplifyDemandedUseBits. It tries to simplify demanded /// bit for "r1 = shr x, c1; r2 = shl r1, c2" instruction sequence. Value *simplifyShrShlDemandedBits( @@ -652,6 +693,8 @@ private: /// This is a convenience wrapper function for the above two functions. Instruction *foldOpWithConstantIntoOperand(BinaryOperator &I); + Instruction *foldAddWithConstant(BinaryOperator &Add); + /// \brief Try to rotate an operation below a PHI node, using PHI nodes for /// its operands. Instruction *FoldPHIArgOpIntoPHI(PHINode &PN); @@ -660,9 +703,14 @@ private: Instruction *FoldPHIArgLoadIntoPHI(PHINode &PN); Instruction *FoldPHIArgZextsIntoPHI(PHINode &PN); - /// Helper function for FoldPHIArgXIntoPHI() to get debug location for the + /// If an integer typed PHI has only one use which is an IntToPtr operation, + /// replace the PHI with an existing pointer typed PHI if it exists. Otherwise + /// insert a new pointer typed PHI and replace the original one. + Instruction *FoldIntegerTypedPHI(PHINode &PN); + + /// Helper function for FoldPHIArgXIntoPHI() to set debug location for the /// folded operation. 
- DebugLoc PHIArgMergedDebugLoc(PHINode &PN); + void PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN); Instruction *foldGEPICmp(GEPOperator *GEPLHS, Value *RHS, ICmpInst::Predicate Cond, Instruction &I); @@ -673,7 +721,7 @@ private: ConstantInt *AndCst = nullptr); Instruction *foldFCmpIntToFPConst(FCmpInst &I, Instruction *LHSI, Constant *RHSC); - Instruction *foldICmpAddOpConst(Instruction &ICI, Value *X, ConstantInt *CI, + Instruction *foldICmpAddOpConst(Value *X, ConstantInt *CI, ICmpInst::Predicate Pred); Instruction *foldICmpWithCastAndCast(ICmpInst &ICI); @@ -683,35 +731,36 @@ private: Instruction *foldICmpInstWithConstantNotInt(ICmpInst &Cmp); Instruction *foldICmpBinOp(ICmpInst &Cmp); Instruction *foldICmpEquality(ICmpInst &Cmp); + Instruction *foldICmpWithZero(ICmpInst &Cmp); - Instruction *foldICmpSelectConstant(ICmpInst &Cmp, Instruction *Select, + Instruction *foldICmpSelectConstant(ICmpInst &Cmp, SelectInst *Select, ConstantInt *C); - Instruction *foldICmpTruncConstant(ICmpInst &Cmp, Instruction *Trunc, - const APInt *C); + Instruction *foldICmpTruncConstant(ICmpInst &Cmp, TruncInst *Trunc, + const APInt &C); Instruction *foldICmpAndConstant(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C); + const APInt &C); Instruction *foldICmpXorConstant(ICmpInst &Cmp, BinaryOperator *Xor, - const APInt *C); + const APInt &C); Instruction *foldICmpOrConstant(ICmpInst &Cmp, BinaryOperator *Or, - const APInt *C); + const APInt &C); Instruction *foldICmpMulConstant(ICmpInst &Cmp, BinaryOperator *Mul, - const APInt *C); + const APInt &C); Instruction *foldICmpShlConstant(ICmpInst &Cmp, BinaryOperator *Shl, - const APInt *C); + const APInt &C); Instruction *foldICmpShrConstant(ICmpInst &Cmp, BinaryOperator *Shr, - const APInt *C); + const APInt &C); Instruction *foldICmpUDivConstant(ICmpInst &Cmp, BinaryOperator *UDiv, - const APInt *C); + const APInt &C); Instruction *foldICmpDivConstant(ICmpInst &Cmp, BinaryOperator *Div, - const APInt *C); + const APInt &C); Instruction *foldICmpSubConstant(ICmpInst &Cmp, BinaryOperator *Sub, - const APInt *C); + const APInt &C); Instruction *foldICmpAddConstant(ICmpInst &Cmp, BinaryOperator *Add, - const APInt *C); + const APInt &C); Instruction *foldICmpAndConstConst(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C1); + const APInt &C1); Instruction *foldICmpAndShift(ICmpInst &Cmp, BinaryOperator *And, - const APInt *C1, const APInt *C2); + const APInt &C1, const APInt &C2); Instruction *foldICmpShrConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, const APInt &C2); Instruction *foldICmpShlConstConst(ICmpInst &I, Value *ShAmt, const APInt &C1, @@ -719,8 +768,8 @@ private: Instruction *foldICmpBinOpEqualityWithConstant(ICmpInst &Cmp, BinaryOperator *BO, - const APInt *C); - Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI, const APInt *C); + const APInt &C); + Instruction *foldICmpIntrinsicWithConstant(ICmpInst &ICI, const APInt &C); // Helpers of visitSelectInst(). Instruction *foldSelectExtConst(SelectInst &Sel); @@ -740,8 +789,7 @@ private: Instruction *MatchBSwap(BinaryOperator &I); bool SimplifyStoreAtEndOfBlock(StoreInst &SI); - Instruction * - SimplifyElementUnorderedAtomicMemCpy(ElementUnorderedAtomicMemCpyInst *AMI); + Instruction *SimplifyElementUnorderedAtomicMemCpy(AtomicMemCpyInst *AMI); Instruction *SimplifyMemTransfer(MemIntrinsic *MI); Instruction *SimplifyMemSet(MemSetInst *MI); @@ -753,8 +801,8 @@ private: Value *Descale(Value *Val, APInt Scale, bool &NoSignedWrap); }; -} // end namespace llvm. 
+} // end namespace llvm #undef DEBUG_TYPE -#endif +#endif // LLVM_LIB_TRANSFORMS_INSTCOMBINE_INSTCOMBINEINTERNAL_H diff --git a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp index 451036545741..d4f06e18b957 100644 --- a/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp +++ b/lib/Transforms/InstCombine/InstCombineLoadStoreAlloca.cpp @@ -18,13 +18,14 @@ #include "llvm/Analysis/Loads.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfo.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; +using namespace PatternMatch; #define DEBUG_TYPE "instcombine" @@ -561,6 +562,28 @@ static StoreInst *combineStoreToNewValue(InstCombiner &IC, StoreInst &SI, Value return NewStore; } +/// Returns true if instruction represent minmax pattern like: +/// select ((cmp load V1, load V2), V1, V2). +static bool isMinMaxWithLoads(Value *V) { + assert(V->getType()->isPointerTy() && "Expected pointer type."); + // Ignore possible ty* to ixx* bitcast. + V = peekThroughBitcast(V); + // Check that select is select ((cmp load V1, load V2), V1, V2) - minmax + // pattern. + CmpInst::Predicate Pred; + Instruction *L1; + Instruction *L2; + Value *LHS; + Value *RHS; + if (!match(V, m_Select(m_Cmp(Pred, m_Instruction(L1), m_Instruction(L2)), + m_Value(LHS), m_Value(RHS)))) + return false; + return (match(L1, m_Load(m_Specific(LHS))) && + match(L2, m_Load(m_Specific(RHS)))) || + (match(L1, m_Load(m_Specific(RHS))) && + match(L2, m_Load(m_Specific(LHS)))); +} + /// \brief Combine loads to match the type of their uses' value after looking /// through intervening bitcasts. /// @@ -598,10 +621,14 @@ static Instruction *combineLoadToOperationType(InstCombiner &IC, LoadInst &LI) { // integers instead of any other type. We only do this when the loaded type // is sized and has a size exactly the same as its store size and the store // size is a legal integer type. + // Do not perform canonicalization if minmax pattern is found (to avoid + // infinite loop). if (!Ty->isIntegerTy() && Ty->isSized() && DL.isLegalInteger(DL.getTypeStoreSizeInBits(Ty)) && DL.getTypeStoreSizeInBits(Ty) == DL.getTypeSizeInBits(Ty) && - !DL.isNonIntegralPointerType(Ty)) { + !DL.isNonIntegralPointerType(Ty) && + !isMinMaxWithLoads( + peekThroughBitcast(LI.getPointerOperand(), /*OneUseOnly=*/true))) { if (all_of(LI.users(), [&LI](User *U) { auto *SI = dyn_cast<StoreInst>(U); return SI && SI->getPointerOperand() != &LI && @@ -931,6 +958,16 @@ static Instruction *replaceGEPIdxWithZero(InstCombiner &IC, Value *Ptr, return nullptr; } +static bool canSimplifyNullStoreOrGEP(StoreInst &SI) { + if (SI.getPointerAddressSpace() != 0) + return false; + + auto *Ptr = SI.getPointerOperand(); + if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Ptr)) + Ptr = GEPI->getOperand(0); + return isa<ConstantPointerNull>(Ptr); +} + static bool canSimplifyNullLoadOrGEP(LoadInst &LI, Value *Op) { if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(Op)) { const Value *GEPI0 = GEPI->getOperand(0); @@ -1298,6 +1335,46 @@ static bool equivalentAddressValues(Value *A, Value *B) { return false; } +/// Converts store (bitcast (load (bitcast (select ...)))) to +/// store (load (select ...)), where select is minmax: +/// select ((cmp load V1, load V2), V1, V2). 
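For readers wondering where the pattern checked by isMinMaxWithLoads comes from: it is the select-of-addresses shape produced by taking a min/max of two loaded values by reference (the std::min-by-reference idiom typically inlines to this). A standalone C++ version of the source-level shape (not LLVM code):

#include <cassert>

static int loadOfSelectedPtr(const int *A, const int *B) {
  const int *Sel = (*A < *B) ? A : B; // select of pointers on a compare of loads
  return *Sel;                        // the load the combine must not re-type
}

int main() {
  int X = 3, Y = 7;
  assert(loadOfSelectedPtr(&X, &Y) == 3);
  assert(loadOfSelectedPtr(&Y, &X) == 3);
  return 0;
}

The guard added to combineLoadToOperationType above skips re-typing the load feeding such a select, so the two canonicalizations cannot ping-pong against each other.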
+static bool removeBitcastsFromLoadStoreOnMinMax(InstCombiner &IC, + StoreInst &SI) { + // bitcast? + if (!match(SI.getPointerOperand(), m_BitCast(m_Value()))) + return false; + // load? integer? + Value *LoadAddr; + if (!match(SI.getValueOperand(), m_Load(m_BitCast(m_Value(LoadAddr))))) + return false; + auto *LI = cast<LoadInst>(SI.getValueOperand()); + if (!LI->getType()->isIntegerTy()) + return false; + if (!isMinMaxWithLoads(LoadAddr)) + return false; + + if (!all_of(LI->users(), [LI, LoadAddr](User *U) { + auto *SI = dyn_cast<StoreInst>(U); + return SI && SI->getPointerOperand() != LI && + peekThroughBitcast(SI->getPointerOperand()) != LoadAddr && + !SI->getPointerOperand()->isSwiftError(); + })) + return false; + + IC.Builder.SetInsertPoint(LI); + LoadInst *NewLI = combineLoadToNewType( + IC, *LI, LoadAddr->getType()->getPointerElementType()); + // Replace all the stores with stores of the newly loaded value. + for (auto *UI : LI->users()) { + auto *USI = cast<StoreInst>(UI); + IC.Builder.SetInsertPoint(USI); + combineStoreToNewValue(IC, *USI, NewLI); + } + IC.replaceInstUsesWith(*LI, UndefValue::get(LI->getType())); + IC.eraseInstFromFunction(*LI); + return true; +} + Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { Value *Val = SI.getOperand(0); Value *Ptr = SI.getOperand(1); @@ -1322,6 +1399,9 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { if (unpackStoreToAggregate(*this, SI)) return eraseInstFromFunction(SI); + if (removeBitcastsFromLoadStoreOnMinMax(*this, SI)) + return eraseInstFromFunction(SI); + // Replace GEP indices if possible. if (Instruction *NewGEPI = replaceGEPIdxWithZero(*this, Ptr, SI)) { Worklist.Add(NewGEPI); @@ -1392,7 +1472,8 @@ Instruction *InstCombiner::visitStoreInst(StoreInst &SI) { } // store X, null -> turns into 'unreachable' in SimplifyCFG - if (isa<ConstantPointerNull>(Ptr) && SI.getPointerAddressSpace() == 0) { + // store X, GEP(null, Y) -> turns into 'unreachable' in SimplifyCFG + if (canSimplifyNullStoreOrGEP(SI)) { if (!isa<UndefValue>(Val)) { SI.setOperand(0, UndefValue::get(Val->getType())); if (Instruction *U = dyn_cast<Instruction>(Val)) @@ -1544,8 +1625,7 @@ bool InstCombiner::SimplifyStoreAtEndOfBlock(StoreInst &SI) { SI.getSyncScopeID()); InsertNewInstBefore(NewSI, *BBI); // The debug locations of the original instructions might differ; merge them. - NewSI->setDebugLoc(DILocation::getMergedLocation(SI.getDebugLoc(), - OtherStore->getDebugLoc())); + NewSI->applyMergedLocation(SI.getDebugLoc(), OtherStore->getDebugLoc()); // If the two stores had AA tags, merge them. 
AAMDNodes AATags; diff --git a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp index e3a50220f94e..87666360c1a0 100644 --- a/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp +++ b/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp @@ -13,15 +13,36 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <utility> + using namespace llvm; using namespace PatternMatch; #define DEBUG_TYPE "instcombine" - /// The specific integer value is used in a context where it is known to be /// non-zero. If this allows us to simplify the computation, do so and return /// the new operand, otherwise return null. @@ -73,7 +94,6 @@ static Value *simplifyValueKnownNonZero(Value *V, InstCombiner &IC, return MadeChange ? V : nullptr; } - /// True if the multiply can not be expressed in an int this size. static bool MultiplyOverflows(const APInt &C1, const APInt &C2, APInt &Product, bool IsSigned) { @@ -467,7 +487,7 @@ static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) { IntrinsicInst *II = dyn_cast<IntrinsicInst>(Op); if (!II) return; - if (II->getIntrinsicID() != Intrinsic::log2 || !II->hasUnsafeAlgebra()) + if (II->getIntrinsicID() != Intrinsic::log2 || !II->isFast()) return; Log2 = II; @@ -478,7 +498,8 @@ static void detectLog2OfHalf(Value *&Op, Value *&Y, IntrinsicInst *&Log2) { Instruction *I = dyn_cast<Instruction>(OpLog2Of); if (!I) return; - if (I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra()) + + if (I->getOpcode() != Instruction::FMul || !I->isFast()) return; if (match(I->getOperand(0), m_SpecificFP(0.5))) @@ -540,7 +561,6 @@ static bool isFMulOrFDivWithConstant(Value *V) { /// This function is to simplify "FMulOrDiv * C" and returns the /// resulting expression. Note that this function could return NULL in /// case the constants cannot be folded into a normal floating-point. -/// Value *InstCombiner::foldFMulConst(Instruction *FMulOrDiv, Constant *C, Instruction *InsertBefore) { assert(isFMulOrFDivWithConstant(FMulOrDiv) && "V is invalid"); @@ -582,7 +602,7 @@ Value *InstCombiner::foldFMulConst(Instruction *FMulOrDiv, Constant *C, } if (R) { - R->setHasUnsafeAlgebra(true); + R->setFast(true); InsertNewInstWith(R, *InsertBefore); } @@ -603,7 +623,7 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { SQ.getWithInstruction(&I))) return replaceInstUsesWith(I, V); - bool AllowReassociate = I.hasUnsafeAlgebra(); + bool AllowReassociate = I.isFast(); // Simplify mul instructions with a constant RHS. 
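detectLog2OfHalf in the hunk above looks for a fast-math log2 of (something * 0.5); the rewrite it enables rests on the reassociation-only identity log2(0.5 * x) = log2(x) - 1 for positive x, which is why the isFast() checks were added. A standalone numeric check of that identity (illustrative, not LLVM code):

#include <cassert>
#include <cmath>

int main() {
  for (double x : {0.25, 1.0, 3.0, 1024.0}) {
    double lhs = std::log2(0.5 * x);
    double rhs = std::log2(x) - 1.0;
    assert(std::fabs(lhs - rhs) < 1e-12); // equal up to rounding
  }
  return 0;
}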
if (isa<Constant>(Op1)) { @@ -736,6 +756,10 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { } } + // Handle specials cases for FMul with selects feeding the operation + if (Value *V = SimplifySelectsFeedingBinaryOp(I, Op0, Op1)) + return replaceInstUsesWith(I, V); + // (X*Y) * X => (X*X) * Y where Y != X // The purpose is two-fold: // 1) to form a power expression (of X). @@ -743,7 +767,6 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { // latency of the instruction Y is amortized by the expression of X*X, // and therefore Y is in a "less critical" position compared to what it // was before the transformation. - // if (AllowReassociate) { Value *Opnd0_0, *Opnd0_1; if (Opnd0->hasOneUse() && @@ -774,24 +797,23 @@ Instruction *InstCombiner::visitFMul(BinaryOperator &I) { return Changed ? &I : nullptr; } -/// Try to fold a divide or remainder of a select instruction. -bool InstCombiner::SimplifyDivRemOfSelect(BinaryOperator &I) { - SelectInst *SI = cast<SelectInst>(I.getOperand(1)); - - // div/rem X, (Cond ? 0 : Y) -> div/rem X, Y - int NonNullOperand = -1; - if (Constant *ST = dyn_cast<Constant>(SI->getOperand(1))) - if (ST->isNullValue()) - NonNullOperand = 2; - // div/rem X, (Cond ? Y : 0) -> div/rem X, Y - if (Constant *ST = dyn_cast<Constant>(SI->getOperand(2))) - if (ST->isNullValue()) - NonNullOperand = 1; - - if (NonNullOperand == -1) +/// Fold a divide or remainder with a select instruction divisor when one of the +/// select operands is zero. In that case, we can use the other select operand +/// because div/rem by zero is undefined. +bool InstCombiner::simplifyDivRemOfSelectWithZeroOp(BinaryOperator &I) { + SelectInst *SI = dyn_cast<SelectInst>(I.getOperand(1)); + if (!SI) return false; - Value *SelectCond = SI->getOperand(0); + int NonNullOperand; + if (match(SI->getTrueValue(), m_Zero())) + // div/rem X, (Cond ? 0 : Y) -> div/rem X, Y + NonNullOperand = 2; + else if (match(SI->getFalseValue(), m_Zero())) + // div/rem X, (Cond ? Y : 0) -> div/rem X, Y + NonNullOperand = 1; + else + return false; // Change the div/rem to use 'Y' instead of the select. I.setOperand(1, SI->getOperand(NonNullOperand)); @@ -804,12 +826,13 @@ bool InstCombiner::SimplifyDivRemOfSelect(BinaryOperator &I) { // If the select and condition only have a single use, don't bother with this, // early exit. + Value *SelectCond = SI->getCondition(); if (SI->use_empty() && SelectCond->hasOneUse()) return true; // Scan the current block backward, looking for other uses of SI. BasicBlock::iterator BBI = I.getIterator(), BBFront = I.getParent()->begin(); - + Type *CondTy = SelectCond->getType(); while (BBI != BBFront) { --BBI; // If we found a call to a function, we can't assume it will return, so @@ -824,7 +847,8 @@ bool InstCombiner::SimplifyDivRemOfSelect(BinaryOperator &I) { *I = SI->getOperand(NonNullOperand); Worklist.Add(&*BBI); } else if (*I == SelectCond) { - *I = Builder.getInt1(NonNullOperand == 1); + *I = NonNullOperand == 1 ? ConstantInt::getTrue(CondTy) + : ConstantInt::getFalse(CondTy); Worklist.Add(&*BBI); } } @@ -843,7 +867,6 @@ bool InstCombiner::SimplifyDivRemOfSelect(BinaryOperator &I) { return true; } - /// This function implements the transforms common to both integer division /// instructions (udiv and sdiv). It is called by the visitors to those integer /// division instructions. @@ -859,7 +882,7 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { // Handle cases involving: [su]div X, (select Cond, Y, Z) // This does not apply for fdiv. 
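The rewritten simplifyDivRemOfSelectWithZeroOp above drops the zero arm of a select used as a divisor: taking that arm would mean dividing by zero, which is undefined, so on every well-defined execution the other arm was chosen. A standalone sketch of that argument (plain C++, exercising only defined inputs):

#include <cassert>

static int before(int x, bool c, int y) { return x / (c ? 0 : y); }
static int after(int x, bool c, int y)  { return x / y; }

int main() {
  // Only inputs where the original expression is well defined: c == false, y != 0.
  for (int x = -4; x <= 4; ++x)
    for (int y : {-3, -1, 1, 2, 5})
      assert(before(x, false, y) == after(x, false, y));
  return 0;
}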
- if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) + if (simplifyDivRemOfSelectWithZeroOp(I)) return &I; if (Instruction *LHS = dyn_cast<Instruction>(Op0)) { @@ -969,38 +992,29 @@ Instruction *InstCombiner::commonIDivTransforms(BinaryOperator &I) { return nullptr; } -/// dyn_castZExtVal - Checks if V is a zext or constant that can -/// be truncated to Ty without losing bits. -static Value *dyn_castZExtVal(Value *V, Type *Ty) { - if (ZExtInst *Z = dyn_cast<ZExtInst>(V)) { - if (Z->getSrcTy() == Ty) - return Z->getOperand(0); - } else if (ConstantInt *C = dyn_cast<ConstantInt>(V)) { - if (C->getValue().getActiveBits() <= cast<IntegerType>(Ty)->getBitWidth()) - return ConstantExpr::getTrunc(C, Ty); - } - return nullptr; -} +static const unsigned MaxDepth = 6; namespace { -const unsigned MaxDepth = 6; -typedef Instruction *(*FoldUDivOperandCb)(Value *Op0, Value *Op1, - const BinaryOperator &I, - InstCombiner &IC); + +using FoldUDivOperandCb = Instruction *(*)(Value *Op0, Value *Op1, + const BinaryOperator &I, + InstCombiner &IC); /// \brief Used to maintain state for visitUDivOperand(). struct UDivFoldAction { - FoldUDivOperandCb FoldAction; ///< Informs visitUDiv() how to fold this - ///< operand. This can be zero if this action - ///< joins two actions together. + /// Informs visitUDiv() how to fold this operand. This can be zero if this + /// action joins two actions together. + FoldUDivOperandCb FoldAction; + + /// Which operand to fold. + Value *OperandToFold; - Value *OperandToFold; ///< Which operand to fold. union { - Instruction *FoldResult; ///< The instruction returned when FoldAction is - ///< invoked. + /// The instruction returned when FoldAction is invoked. + Instruction *FoldResult; - size_t SelectLHSIdx; ///< Stores the LHS action index if this action - ///< joins two actions together. + /// Stores the LHS action index if this action joins two actions together. + size_t SelectLHSIdx; }; UDivFoldAction(FoldUDivOperandCb FA, Value *InputOperand) @@ -1008,7 +1022,8 @@ struct UDivFoldAction { UDivFoldAction(FoldUDivOperandCb FA, Value *InputOperand, size_t SLHS) : FoldAction(FA), OperandToFold(InputOperand), SelectLHSIdx(SLHS) {} }; -} + +} // end anonymous namespace // X udiv 2^C -> X >> C static Instruction *foldUDivPow2Cst(Value *Op0, Value *Op1, @@ -1095,6 +1110,43 @@ static size_t visitUDivOperand(Value *Op0, Value *Op1, const BinaryOperator &I, return 0; } +/// If we have zero-extended operands of an unsigned div or rem, we may be able +/// to narrow the operation (sink the zext below the math). +static Instruction *narrowUDivURem(BinaryOperator &I, + InstCombiner::BuilderTy &Builder) { + Instruction::BinaryOps Opcode = I.getOpcode(); + Value *N = I.getOperand(0); + Value *D = I.getOperand(1); + Type *Ty = I.getType(); + Value *X, *Y; + if (match(N, m_ZExt(m_Value(X))) && match(D, m_ZExt(m_Value(Y))) && + X->getType() == Y->getType() && (N->hasOneUse() || D->hasOneUse())) { + // udiv (zext X), (zext Y) --> zext (udiv X, Y) + // urem (zext X), (zext Y) --> zext (urem X, Y) + Value *NarrowOp = Builder.CreateBinOp(Opcode, X, Y); + return new ZExtInst(NarrowOp, Ty); + } + + Constant *C; + if ((match(N, m_OneUse(m_ZExt(m_Value(X)))) && match(D, m_Constant(C))) || + (match(D, m_OneUse(m_ZExt(m_Value(X)))) && match(N, m_Constant(C)))) { + // If the constant is the same in the smaller type, use the narrow version. 
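The narrowUDivURem helper introduced above sinks a zext below an unsigned div/rem. The identity it relies on, checked exhaustively here for 8-bit operands widened to 32 bits (standalone code, not the LLVM API): zero-extension preserves the unsigned value, and the narrow quotient and remainder always fit back in the narrow type.

#include <cassert>
#include <cstdint>

int main() {
  for (unsigned x = 0; x <= 255; ++x)
    for (unsigned y = 1; y <= 255; ++y) {
      uint8_t nx = (uint8_t)x, ny = (uint8_t)y;
      uint32_t wideDiv = (uint32_t)nx / (uint32_t)ny;    // udiv of zexts
      uint32_t narrowDiv = (uint32_t)(uint8_t)(nx / ny); // zext of udiv
      assert(wideDiv == narrowDiv);
      assert((uint32_t)nx % (uint32_t)ny == (uint32_t)(uint8_t)(nx % ny));
    }
  return 0;
}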
+ Constant *TruncC = ConstantExpr::getTrunc(C, X->getType()); + if (ConstantExpr::getZExt(TruncC, Ty) != C) + return nullptr; + + // udiv (zext X), C --> zext (udiv X, C') + // urem (zext X), C --> zext (urem X, C') + // udiv C, (zext X) --> zext (udiv C', X) + // urem C, (zext X) --> zext (urem C', X) + Value *NarrowOp = isa<Constant>(D) ? Builder.CreateBinOp(Opcode, X, TruncC) + : Builder.CreateBinOp(Opcode, TruncC, X); + return new ZExtInst(NarrowOp, Ty); + } + + return nullptr; +} + Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1); @@ -1127,12 +1179,8 @@ Instruction *InstCombiner::visitUDiv(BinaryOperator &I) { } } - // (zext A) udiv (zext B) --> zext (A udiv B) - if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0)) - if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy())) - return new ZExtInst( - Builder.CreateUDiv(ZOp0->getOperand(0), ZOp1, "div", I.isExact()), - I.getType()); + if (Instruction *NarrowDiv = narrowUDivURem(I, Builder)) + return NarrowDiv; // (LHS udiv (select (select (...)))) -> (LHS >> (select (select (...)))) SmallVector<UDivFoldAction, 6> UDivActions; @@ -1255,8 +1303,7 @@ Instruction *InstCombiner::visitSDiv(BinaryOperator &I) { /// 1) 1/C is exact, or /// 2) reciprocal is allowed. /// If the conversion was successful, the simplified expression "X * 1/C" is -/// returned; otherwise, NULL is returned. -/// +/// returned; otherwise, nullptr is returned. static Instruction *CvtFDivConstToReciprocal(Value *Dividend, Constant *Divisor, bool AllowReciprocal) { if (!isa<ConstantFP>(Divisor)) // TODO: handle vectors. @@ -1295,7 +1342,7 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; - bool AllowReassociate = I.hasUnsafeAlgebra(); + bool AllowReassociate = I.isFast(); bool AllowReciprocal = I.hasAllowReciprocal(); if (Constant *Op1C = dyn_cast<Constant>(Op1)) { @@ -1317,7 +1364,6 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { Res = BinaryOperator::CreateFMul(X, C); } else if (match(Op0, m_FDiv(m_Value(X), m_Constant(C1)))) { // (X/C1)/C2 => X /(C2*C1) [=> X * 1/(C2*C1) if reciprocal is allowed] - // Constant *C = ConstantExpr::getFMul(C1, C2); if (isNormalFp(C)) { Res = CvtFDivConstToReciprocal(X, C, AllowReciprocal); @@ -1375,7 +1421,6 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { if (Op0->hasOneUse() && match(Op0, m_FDiv(m_Value(X), m_Value(Y)))) { // (X/Y) / Z => X / (Y*Z) - // if (!isa<Constant>(Y) || !isa<Constant>(Op1)) { NewInst = Builder.CreateFMul(Y, Op1); if (Instruction *RI = dyn_cast<Instruction>(NewInst)) { @@ -1387,7 +1432,6 @@ Instruction *InstCombiner::visitFDiv(BinaryOperator &I) { } } else if (Op1->hasOneUse() && match(Op1, m_FDiv(m_Value(X), m_Value(Y)))) { // Z / (X/Y) => Z*Y / X - // if (!isa<Constant>(Y) || !isa<Constant>(Op0)) { NewInst = Builder.CreateFMul(Op0, Y); if (Instruction *RI = dyn_cast<Instruction>(NewInst)) { @@ -1434,7 +1478,7 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { } // Handle cases involving: rem X, (select Cond, Y, Z) - if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) + if (simplifyDivRemOfSelectWithZeroOp(I)) return &I; if (isa<Constant>(Op1)) { @@ -1443,7 +1487,6 @@ Instruction *InstCombiner::commonIRemTransforms(BinaryOperator &I) { if (Instruction *R = FoldOpIntoSelect(I, SI)) return R; } else if (auto *PN = dyn_cast<PHINode>(Op0I)) { - using namespace llvm::PatternMatch; const APInt *Op1Int; if (match(Op1, m_APInt(Op1Int)) && 
!Op1Int->isMinValue() && (I.getOpcode() == Instruction::URem || @@ -1477,11 +1520,8 @@ Instruction *InstCombiner::visitURem(BinaryOperator &I) { if (Instruction *common = commonIRemTransforms(I)) return common; - // (zext A) urem (zext B) --> zext (A urem B) - if (ZExtInst *ZOp0 = dyn_cast<ZExtInst>(Op0)) - if (Value *ZOp1 = dyn_castZExtVal(Op1, ZOp0->getSrcTy())) - return new ZExtInst(Builder.CreateURem(ZOp0->getOperand(0), ZOp1), - I.getType()); + if (Instruction *NarrowRem = narrowUDivURem(I, Builder)) + return NarrowRem; // X urem Y -> X and Y-1, where Y is a power of 2, if (isKnownToBeAPowerOfTwo(Op1, /*OrZero*/ true, 0, &I)) { @@ -1592,7 +1632,7 @@ Instruction *InstCombiner::visitFRem(BinaryOperator &I) { return replaceInstUsesWith(I, V); // Handle cases involving: rem X, (select Cond, Y, Z) - if (isa<SelectInst>(Op1) && SimplifyDivRemOfSelect(I)) + if (simplifyDivRemOfSelectWithZeroOp(I)) return &I; return nullptr; diff --git a/lib/Transforms/InstCombine/InstCombinePHI.cpp b/lib/Transforms/InstCombine/InstCombinePHI.cpp index 0011412c2bf4..7ee018dbc49b 100644 --- a/lib/Transforms/InstCombine/InstCombinePHI.cpp +++ b/lib/Transforms/InstCombine/InstCombinePHI.cpp @@ -16,7 +16,6 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/DebugInfo.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Transforms/Utils/Local.h" using namespace llvm; @@ -27,16 +26,249 @@ using namespace llvm::PatternMatch; /// The PHI arguments will be folded into a single operation with a PHI node /// as input. The debug location of the single operation will be the merged /// locations of the original PHI node arguments. -DebugLoc InstCombiner::PHIArgMergedDebugLoc(PHINode &PN) { +void InstCombiner::PHIArgMergedDebugLoc(Instruction *Inst, PHINode &PN) { auto *FirstInst = cast<Instruction>(PN.getIncomingValue(0)); - const DILocation *Loc = FirstInst->getDebugLoc(); + Inst->setDebugLoc(FirstInst->getDebugLoc()); + // We do not expect a CallInst here, otherwise, N-way merging of DebugLoc + // will be inefficient. + assert(!isa<CallInst>(Inst)); for (unsigned i = 1; i != PN.getNumIncomingValues(); ++i) { auto *I = cast<Instruction>(PN.getIncomingValue(i)); - Loc = DILocation::getMergedLocation(Loc, I->getDebugLoc()); + Inst->applyMergedLocation(Inst->getDebugLoc(), I->getDebugLoc()); + } +} + +// Replace Integer typed PHI PN if the PHI's value is used as a pointer value. +// If there is an existing pointer typed PHI that produces the same value as PN, +// replace PN and the IntToPtr operation with it. Otherwise, synthesize a new +// PHI node: +// +// Case-1: +// bb1: +// int_init = PtrToInt(ptr_init) +// br label %bb2 +// bb2: +// int_val = PHI([int_init, %bb1], [int_val_inc, %bb2] +// ptr_val = PHI([ptr_init, %bb1], [ptr_val_inc, %bb2] +// ptr_val2 = IntToPtr(int_val) +// ... +// use(ptr_val2) +// ptr_val_inc = ... +// inc_val_inc = PtrToInt(ptr_val_inc) +// +// ==> +// bb1: +// br label %bb2 +// bb2: +// ptr_val = PHI([ptr_init, %bb1], [ptr_val_inc, %bb2] +// ... +// use(ptr_val) +// ptr_val_inc = ... +// +// Case-2: +// bb1: +// int_ptr = BitCast(ptr_ptr) +// int_init = Load(int_ptr) +// br label %bb2 +// bb2: +// int_val = PHI([int_init, %bb1], [int_val_inc, %bb2] +// ptr_val2 = IntToPtr(int_val) +// ... +// use(ptr_val2) +// ptr_val_inc = ... +// inc_val_inc = PtrToInt(ptr_val_inc) +// ==> +// bb1: +// ptr_init = Load(ptr_ptr) +// br label %bb2 +// bb2: +// ptr_val = PHI([ptr_init, %bb1], [ptr_val_inc, %bb2] +// ... 
+// use(ptr_val) +// ptr_val_inc = ... +// ... +// +Instruction *InstCombiner::FoldIntegerTypedPHI(PHINode &PN) { + if (!PN.getType()->isIntegerTy()) + return nullptr; + if (!PN.hasOneUse()) + return nullptr; + + auto *IntToPtr = dyn_cast<IntToPtrInst>(PN.user_back()); + if (!IntToPtr) + return nullptr; + + // Check if the pointer is actually used as pointer: + auto HasPointerUse = [](Instruction *IIP) { + for (User *U : IIP->users()) { + Value *Ptr = nullptr; + if (LoadInst *LoadI = dyn_cast<LoadInst>(U)) { + Ptr = LoadI->getPointerOperand(); + } else if (StoreInst *SI = dyn_cast<StoreInst>(U)) { + Ptr = SI->getPointerOperand(); + } else if (GetElementPtrInst *GI = dyn_cast<GetElementPtrInst>(U)) { + Ptr = GI->getPointerOperand(); + } + + if (Ptr && Ptr == IIP) + return true; + } + return false; + }; + + if (!HasPointerUse(IntToPtr)) + return nullptr; + + if (DL.getPointerSizeInBits(IntToPtr->getAddressSpace()) != + DL.getTypeSizeInBits(IntToPtr->getOperand(0)->getType())) + return nullptr; + + SmallVector<Value *, 4> AvailablePtrVals; + for (unsigned i = 0; i != PN.getNumIncomingValues(); ++i) { + Value *Arg = PN.getIncomingValue(i); + + // First look backward: + if (auto *PI = dyn_cast<PtrToIntInst>(Arg)) { + AvailablePtrVals.emplace_back(PI->getOperand(0)); + continue; + } + + // Next look forward: + Value *ArgIntToPtr = nullptr; + for (User *U : Arg->users()) { + if (isa<IntToPtrInst>(U) && U->getType() == IntToPtr->getType() && + (DT.dominates(cast<Instruction>(U), PN.getIncomingBlock(i)) || + cast<Instruction>(U)->getParent() == PN.getIncomingBlock(i))) { + ArgIntToPtr = U; + break; + } + } + + if (ArgIntToPtr) { + AvailablePtrVals.emplace_back(ArgIntToPtr); + continue; + } + + // If Arg is defined by a PHI, allow it. This will also create + // more opportunities iteratively. + if (isa<PHINode>(Arg)) { + AvailablePtrVals.emplace_back(Arg); + continue; + } + + // For a single use integer load: + auto *LoadI = dyn_cast<LoadInst>(Arg); + if (!LoadI) + return nullptr; + + if (!LoadI->hasOneUse()) + return nullptr; + + // Push the integer typed Load instruction into the available + // value set, and fix it up later when the pointer typed PHI + // is synthesized. + AvailablePtrVals.emplace_back(LoadI); + } + + // Now search for a matching PHI + auto *BB = PN.getParent(); + assert(AvailablePtrVals.size() == PN.getNumIncomingValues() && + "Not enough available ptr typed incoming values"); + PHINode *MatchingPtrPHI = nullptr; + for (auto II = BB->begin(), EI = BasicBlock::iterator(BB->getFirstNonPHI()); + II != EI; II++) { + PHINode *PtrPHI = dyn_cast<PHINode>(II); + if (!PtrPHI || PtrPHI == &PN || PtrPHI->getType() != IntToPtr->getType()) + continue; + MatchingPtrPHI = PtrPHI; + for (unsigned i = 0; i != PtrPHI->getNumIncomingValues(); ++i) { + if (AvailablePtrVals[i] != + PtrPHI->getIncomingValueForBlock(PN.getIncomingBlock(i))) { + MatchingPtrPHI = nullptr; + break; + } + } + + if (MatchingPtrPHI) + break; + } + + if (MatchingPtrPHI) { + assert(MatchingPtrPHI->getType() == IntToPtr->getType() && + "Phi's Type does not match with IntToPtr"); + // The PtrToCast + IntToPtr will be simplified later + return CastInst::CreateBitOrPointerCast(MatchingPtrPHI, + IntToPtr->getOperand(0)->getType()); } - return Loc; + // If it requires a conversion for every PHI operand, do not do it. 
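As a source-level companion to the Case-1 comment above: the integer PHI that FoldIntegerTypedPHI targets typically arises from carrying a pointer through an integer cursor across loop iterations. A rough standalone C++ equivalent of that shape (how closely the emitted IR matches the Case-1 diagram depends on the frontend and earlier passes):

#include <cassert>
#include <cstdint>

static int sumViaIntCursor(const int *Begin, const int *End) {
  int Sum = 0;
  uintptr_t Cur = (uintptr_t)Begin;   // int_init = ptrtoint ptr_init
  while (Cur != (uintptr_t)End) {     // int_val = phi(int_init, int_val_inc)
    const int *P = (const int *)Cur;  // ptr_val2 = inttoptr int_val
    Sum += *P;                        // use(ptr_val2)
    Cur = (uintptr_t)(P + 1);         // int_val_inc = ptrtoint ptr_val_inc
  }
  return Sum;
}

int main() {
  int A[] = {1, 2, 3, 4};
  assert(sumViaIntCursor(A, A + 4) == 10);
  return 0;
}

After the fold, the ptrtoint/inttoptr round trip disappears and the loop-carried value becomes a pointer-typed PHI, as sketched in the "==>" half of the comment.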
+ if (std::all_of(AvailablePtrVals.begin(), AvailablePtrVals.end(), + [&](Value *V) { + return (V->getType() != IntToPtr->getType()) || + isa<IntToPtrInst>(V); + })) + return nullptr; + + // If any of the operand that requires casting is a terminator + // instruction, do not do it. + if (std::any_of(AvailablePtrVals.begin(), AvailablePtrVals.end(), + [&](Value *V) { + return (V->getType() != IntToPtr->getType()) && + isa<TerminatorInst>(V); + })) + return nullptr; + + PHINode *NewPtrPHI = PHINode::Create( + IntToPtr->getType(), PN.getNumIncomingValues(), PN.getName() + ".ptr"); + + InsertNewInstBefore(NewPtrPHI, PN); + SmallDenseMap<Value *, Instruction *> Casts; + for (unsigned i = 0; i != PN.getNumIncomingValues(); ++i) { + auto *IncomingBB = PN.getIncomingBlock(i); + auto *IncomingVal = AvailablePtrVals[i]; + + if (IncomingVal->getType() == IntToPtr->getType()) { + NewPtrPHI->addIncoming(IncomingVal, IncomingBB); + continue; + } + +#ifndef NDEBUG + LoadInst *LoadI = dyn_cast<LoadInst>(IncomingVal); + assert((isa<PHINode>(IncomingVal) || + IncomingVal->getType()->isPointerTy() || + (LoadI && LoadI->hasOneUse())) && + "Can not replace LoadInst with multiple uses"); +#endif + // Need to insert a BitCast. + // For an integer Load instruction with a single use, the load + IntToPtr + // cast will be simplified into a pointer load: + // %v = load i64, i64* %a.ip, align 8 + // %v.cast = inttoptr i64 %v to float ** + // ==> + // %v.ptrp = bitcast i64 * %a.ip to float ** + // %v.cast = load float *, float ** %v.ptrp, align 8 + Instruction *&CI = Casts[IncomingVal]; + if (!CI) { + CI = CastInst::CreateBitOrPointerCast(IncomingVal, IntToPtr->getType(), + IncomingVal->getName() + ".ptr"); + if (auto *IncomingI = dyn_cast<Instruction>(IncomingVal)) { + BasicBlock::iterator InsertPos(IncomingI); + InsertPos++; + if (isa<PHINode>(IncomingI)) + InsertPos = IncomingI->getParent()->getFirstInsertionPt(); + InsertNewInstBefore(CI, *InsertPos); + } else { + auto *InsertBB = &IncomingBB->getParent()->getEntryBlock(); + InsertNewInstBefore(CI, *InsertBB->getFirstInsertionPt()); + } + } + NewPtrPHI->addIncoming(CI, IncomingBB); + } + + // The PtrToCast + IntToPtr will be simplified later + return CastInst::CreateBitOrPointerCast(NewPtrPHI, + IntToPtr->getOperand(0)->getType()); } /// If we have something like phi [add (a,b), add(a,c)] and if a/b/c and the @@ -117,7 +349,7 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { if (CmpInst *CIOp = dyn_cast<CmpInst>(FirstInst)) { CmpInst *NewCI = CmpInst::Create(CIOp->getOpcode(), CIOp->getPredicate(), LHSVal, RHSVal); - NewCI->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(NewCI, PN); return NewCI; } @@ -130,7 +362,7 @@ Instruction *InstCombiner::FoldPHIArgBinOpIntoPHI(PHINode &PN) { for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) NewBinOp->andIRFlags(PN.getIncomingValue(i)); - NewBinOp->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(NewBinOp, PN); return NewBinOp; } @@ -239,7 +471,7 @@ Instruction *InstCombiner::FoldPHIArgGEPIntoPHI(PHINode &PN) { GetElementPtrInst::Create(FirstInst->getSourceElementType(), Base, makeArrayRef(FixedOperands).slice(1)); if (AllInBounds) NewGEP->setIsInBounds(); - NewGEP->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(NewGEP, PN); return NewGEP; } @@ -399,7 +631,7 @@ Instruction *InstCombiner::FoldPHIArgLoadIntoPHI(PHINode &PN) { for (Value *IncValue : PN.incoming_values()) cast<LoadInst>(IncValue)->setVolatile(false); - 
NewLI->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(NewLI, PN); return NewLI; } @@ -565,7 +797,7 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { if (CastInst *FirstCI = dyn_cast<CastInst>(FirstInst)) { CastInst *NewCI = CastInst::Create(FirstCI->getOpcode(), PhiVal, PN.getType()); - NewCI->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(NewCI, PN); return NewCI; } @@ -576,14 +808,14 @@ Instruction *InstCombiner::FoldPHIArgOpIntoPHI(PHINode &PN) { for (unsigned i = 1, e = PN.getNumIncomingValues(); i != e; ++i) BinOp->andIRFlags(PN.getIncomingValue(i)); - BinOp->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(BinOp, PN); return BinOp; } CmpInst *CIOp = cast<CmpInst>(FirstInst); CmpInst *NewCI = CmpInst::Create(CIOp->getOpcode(), CIOp->getPredicate(), PhiVal, ConstantOp); - NewCI->setDebugLoc(PHIArgMergedDebugLoc(PN)); + PHIArgMergedDebugLoc(NewCI, PN); return NewCI; } @@ -902,6 +1134,9 @@ Instruction *InstCombiner::visitPHINode(PHINode &PN) { // this PHI only has a single use (a PHI), and if that PHI only has one use (a // PHI)... break the cycle. if (PN.hasOneUse()) { + if (Instruction *Result = FoldIntegerTypedPHI(PN)) + return Result; + Instruction *PHIUser = cast<Instruction>(PN.user_back()); if (PHINode *PU = dyn_cast<PHINode>(PHIUser)) { SmallPtrSet<PHINode*, 16> PotentiallyDeadPHIs; diff --git a/lib/Transforms/InstCombine/InstCombineSelect.cpp b/lib/Transforms/InstCombine/InstCombineSelect.cpp index 4eebe8255998..6f26f7f5cd19 100644 --- a/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -12,12 +12,36 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" -#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CmpInstAnalysis.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include <cassert> +#include <utility> + using namespace llvm; using namespace PatternMatch; @@ -69,6 +93,111 @@ static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder, return Builder.CreateSelect(Builder.CreateICmp(Pred, A, B), A, B); } +/// If one of the constants is zero (we know they can't both be) and we have an +/// icmp instruction with zero, and we have an 'and' with the non-constant value +/// and a power of two we can turn the select into a shift on the result of the +/// 'and'. +/// This folds: +/// select (icmp eq (and X, C1)), C2, C3 +/// iff C1 is a power 2 and the difference between C2 and C3 is a power of 2. 
+/// To something like: +/// (shr (and (X, C1)), (log2(C1) - log2(C2-C3))) + C3 +/// Or: +/// (shl (and (X, C1)), (log2(C2-C3) - log2(C1))) + C3 +/// With some variations depending if C3 is larger than C2, or the shift +/// isn't needed, or the bit widths don't match. +static Value *foldSelectICmpAnd(Type *SelType, const ICmpInst *IC, + APInt TrueVal, APInt FalseVal, + InstCombiner::BuilderTy &Builder) { + assert(SelType->isIntOrIntVectorTy() && "Not an integer select?"); + + // If this is a vector select, we need a vector compare. + if (SelType->isVectorTy() != IC->getType()->isVectorTy()) + return nullptr; + + Value *V; + APInt AndMask; + bool CreateAnd = false; + ICmpInst::Predicate Pred = IC->getPredicate(); + if (ICmpInst::isEquality(Pred)) { + if (!match(IC->getOperand(1), m_Zero())) + return nullptr; + + V = IC->getOperand(0); + + const APInt *AndRHS; + if (!match(V, m_And(m_Value(), m_Power2(AndRHS)))) + return nullptr; + + AndMask = *AndRHS; + } else if (decomposeBitTestICmp(IC->getOperand(0), IC->getOperand(1), + Pred, V, AndMask)) { + assert(ICmpInst::isEquality(Pred) && "Not equality test?"); + + if (!AndMask.isPowerOf2()) + return nullptr; + + CreateAnd = true; + } else { + return nullptr; + } + + // If both select arms are non-zero see if we have a select of the form + // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic + // for 'x ? 2^n : 0' and fix the thing up at the end. + APInt Offset(TrueVal.getBitWidth(), 0); + if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) { + if ((TrueVal - FalseVal).isPowerOf2()) + Offset = FalseVal; + else if ((FalseVal - TrueVal).isPowerOf2()) + Offset = TrueVal; + else + return nullptr; + + // Adjust TrueVal and FalseVal to the offset. + TrueVal -= Offset; + FalseVal -= Offset; + } + + // Make sure one of the select arms is a power of 2. + if (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2()) + return nullptr; + + // Determine which shift is needed to transform result of the 'and' into the + // desired result. + const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; + unsigned ValZeros = ValC.logBase2(); + unsigned AndZeros = AndMask.logBase2(); + + if (CreateAnd) { + // Insert the AND instruction on the input to the truncate. + V = Builder.CreateAnd(V, ConstantInt::get(V->getType(), AndMask)); + } + + // If types don't match we can still convert the select by introducing a zext + // or a trunc of the 'and'. + if (ValZeros > AndZeros) { + V = Builder.CreateZExtOrTrunc(V, SelType); + V = Builder.CreateShl(V, ValZeros - AndZeros); + } else if (ValZeros < AndZeros) { + V = Builder.CreateLShr(V, AndZeros - ValZeros); + V = Builder.CreateZExtOrTrunc(V, SelType); + } else + V = Builder.CreateZExtOrTrunc(V, SelType); + + // Okay, now we know that everything is set up, we just don't know whether we + // have a icmp_ne or icmp_eq and whether the true or false val is the zero. + bool ShouldNotVal = !TrueVal.isNullValue(); + ShouldNotVal ^= Pred == ICmpInst::ICMP_NE; + if (ShouldNotVal) + V = Builder.CreateXor(V, ValC); + + // Apply an offset if needed. 
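A concrete instance of the fold implemented above, with C1 = 4, C2 = 0, C3 = 8 (so C1 and C3 - C2 are both powers of two, and no xor or offset is needed). The brute-force check below is standalone C++, not LLVM code:

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t x = 0; x < 1024; ++x) {
    uint32_t sel = ((x & 4) == 0) ? 0u : 8u; // select (icmp eq (and X, 4), 0), 0, 8
    uint32_t shifted = (x & 4) << 1;         // (and X, 4) moved from bit 2 to bit 3
    assert(sel == shifted);
  }
  return 0;
}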
+ if (!Offset.isNullValue()) + V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset)); + return V; +} + /// We want to turn code that looks like this: /// %C = or %A, %B /// %D = select %cond, %C, %A @@ -79,8 +208,7 @@ static Value *generateMinMaxSelectPattern(InstCombiner::BuilderTy &Builder, /// Assuming that the specified instruction is an operand to the select, return /// a bitmask indicating which operands of this instruction are foldable if they /// equal the other incoming value of the select. -/// -static unsigned getSelectFoldableOperands(Instruction *I) { +static unsigned getSelectFoldableOperands(BinaryOperator *I) { switch (I->getOpcode()) { case Instruction::Add: case Instruction::Mul: @@ -100,7 +228,7 @@ static unsigned getSelectFoldableOperands(Instruction *I) { /// For the same transformation as the previous function, return the identity /// constant that goes into the select. -static Constant *getSelectFoldableConstant(Instruction *I) { +static APInt getSelectFoldableConstant(BinaryOperator *I) { switch (I->getOpcode()) { default: llvm_unreachable("This cannot happen!"); case Instruction::Add: @@ -110,11 +238,11 @@ static Constant *getSelectFoldableConstant(Instruction *I) { case Instruction::Shl: case Instruction::LShr: case Instruction::AShr: - return Constant::getNullValue(I->getType()); + return APInt::getNullValue(I->getType()->getScalarSizeInBits()); case Instruction::And: - return Constant::getAllOnesValue(I->getType()); + return APInt::getAllOnesValue(I->getType()->getScalarSizeInBits()); case Instruction::Mul: - return ConstantInt::get(I->getType(), 1); + return APInt(I->getType()->getScalarSizeInBits(), 1); } } @@ -157,7 +285,6 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, if (TI->getOpcode() != Instruction::BitCast && (!TI->hasOneUse() || !FI->hasOneUse())) return nullptr; - } else if (!TI->hasOneUse() || !FI->hasOneUse()) { // TODO: The one-use restrictions for a scalar select could be eased if // the fold of a select in visitLoadInst() was enhanced to match a pattern @@ -218,17 +345,11 @@ Instruction *InstCombiner::foldSelectOpOp(SelectInst &SI, Instruction *TI, return BinaryOperator::Create(BO->getOpcode(), Op0, Op1); } -static bool isSelect01(Constant *C1, Constant *C2) { - ConstantInt *C1I = dyn_cast<ConstantInt>(C1); - if (!C1I) - return false; - ConstantInt *C2I = dyn_cast<ConstantInt>(C2); - if (!C2I) - return false; - if (!C1I->isZero() && !C2I->isZero()) // One side must be zero. +static bool isSelect01(const APInt &C1I, const APInt &C2I) { + if (!C1I.isNullValue() && !C2I.isNullValue()) // One side must be zero. return false; - return C1I->isOne() || C1I->isMinusOne() || - C2I->isOne() || C2I->isMinusOne(); + return C1I.isOneValue() || C1I.isAllOnesValue() || + C2I.isOneValue() || C2I.isAllOnesValue(); } /// Try to fold the select into one of the operands to allow further @@ -237,9 +358,8 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, Value *FalseVal) { // See the comment above GetSelectFoldableOperands for a description of the // transformation we are doing here. 
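The getSelectFoldableConstant identities above (0 for add/or/xor/shifts, all-ones for and, 1 for mul) are what let foldSelectIntoOp push a select into one operand of the binary op: selecting the identity element reproduces the untouched arm. A standalone check of three of those identities:

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t x : {0u, 1u, 7u, 0xDEADBEEFu})
    for (uint32_t y : {0u, 3u, 0xFF00u})
      for (bool c : {false, true}) {
        assert((c ? x + y : x) == x + (c ? y : 0u));      // add, identity 0
        assert((c ? (x & y) : x) == (x & (c ? y : ~0u))); // and, identity ~0
        assert((c ? (x * y) : x) == x * (c ? y : 1u));    // mul, identity 1
      }
  return 0;
}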
- if (Instruction *TVI = dyn_cast<Instruction>(TrueVal)) { - if (TVI->hasOneUse() && TVI->getNumOperands() == 2 && - !isa<Constant>(FalseVal)) { + if (auto *TVI = dyn_cast<BinaryOperator>(TrueVal)) { + if (TVI->hasOneUse() && !isa<Constant>(FalseVal)) { if (unsigned SFO = getSelectFoldableOperands(TVI)) { unsigned OpToFold = 0; if ((SFO & 1) && FalseVal == TVI->getOperand(0)) { @@ -249,17 +369,19 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - Constant *C = getSelectFoldableConstant(TVI); + APInt CI = getSelectFoldableConstant(TVI); Value *OOp = TVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. - if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) { + const APInt *OOpC; + bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); + if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) { + Value *C = ConstantInt::get(OOp->getType(), CI); Value *NewSel = Builder.CreateSelect(SI.getCondition(), OOp, C); NewSel->takeName(TVI); - BinaryOperator *TVI_BO = cast<BinaryOperator>(TVI); - BinaryOperator *BO = BinaryOperator::Create(TVI_BO->getOpcode(), + BinaryOperator *BO = BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); - BO->copyIRFlags(TVI_BO); + BO->copyIRFlags(TVI); return BO; } } @@ -267,9 +389,8 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } } - if (Instruction *FVI = dyn_cast<Instruction>(FalseVal)) { - if (FVI->hasOneUse() && FVI->getNumOperands() == 2 && - !isa<Constant>(TrueVal)) { + if (auto *FVI = dyn_cast<BinaryOperator>(FalseVal)) { + if (FVI->hasOneUse() && !isa<Constant>(TrueVal)) { if (unsigned SFO = getSelectFoldableOperands(FVI)) { unsigned OpToFold = 0; if ((SFO & 1) && TrueVal == FVI->getOperand(0)) { @@ -279,17 +400,19 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, } if (OpToFold) { - Constant *C = getSelectFoldableConstant(FVI); + APInt CI = getSelectFoldableConstant(FVI); Value *OOp = FVI->getOperand(2-OpToFold); // Avoid creating select between 2 constants unless it's selecting // between 0, 1 and -1. - if (!isa<Constant>(OOp) || isSelect01(C, cast<Constant>(OOp))) { + const APInt *OOpC; + bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); + if (!isa<Constant>(OOp) || (OOpIsAPInt && isSelect01(CI, *OOpC))) { + Value *C = ConstantInt::get(OOp->getType(), CI); Value *NewSel = Builder.CreateSelect(SI.getCondition(), C, OOp); NewSel->takeName(FVI); - BinaryOperator *FVI_BO = cast<BinaryOperator>(FVI); - BinaryOperator *BO = BinaryOperator::Create(FVI_BO->getOpcode(), + BinaryOperator *BO = BinaryOperator::Create(FVI->getOpcode(), TrueVal, NewSel); - BO->copyIRFlags(FVI_BO); + BO->copyIRFlags(FVI); return BO; } } @@ -313,11 +436,13 @@ Instruction *InstCombiner::foldSelectIntoOp(SelectInst &SI, Value *TrueVal, /// 1. The icmp predicate is inverted /// 2. The select operands are reversed /// 3. The magnitude of C2 and C1 are flipped -static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, +static Value *foldSelectICmpAndOr(const ICmpInst *IC, Value *TrueVal, Value *FalseVal, InstCombiner::BuilderTy &Builder) { - const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition()); - if (!IC || !SI.getType()->isIntegerTy()) + // Only handle integer compares. Also, if this is a vector select, we need a + // vector compare. 
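foldSelectICmpAndOr, whose updated signature appears above, targets selects that conditionally OR in a power-of-two bit based on a masked bit test. One representative instance (the masks 1 and 4 are arbitrary illustrative choices), verified exhaustively over small inputs in standalone C++:

#include <cassert>
#include <cstdint>

int main() {
  for (uint32_t x = 0; x < 64; ++x)
    for (uint32_t y = 0; y < 64; ++y) {
      uint32_t sel = ((x & 1) == 0) ? y : (y | 4); // select (icmp eq (and X,1), 0), Y, (or Y, 4)
      uint32_t rewritten = y | ((x & 1) << 2);     // or (shl (and X, 1), 2), Y
      assert(sel == rewritten);
    }
  return 0;
}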
+ if (!TrueVal->getType()->isIntOrIntVectorTy() || + TrueVal->getType()->isVectorTy() != IC->getType()->isVectorTy()) return nullptr; Value *CmpLHS = IC->getOperand(0); @@ -371,8 +496,8 @@ static Value *foldSelectICmpAndOr(const SelectInst &SI, Value *TrueVal, bool NeedXor = (!IsEqualZero && OrOnFalseVal) || (IsEqualZero && OrOnTrueVal); bool NeedShift = C1Log != C2Log; - bool NeedZExtTrunc = Y->getType()->getIntegerBitWidth() != - V->getType()->getIntegerBitWidth(); + bool NeedZExtTrunc = Y->getType()->getScalarSizeInBits() != + V->getType()->getScalarSizeInBits(); // Make sure we don't create more instructions than we save. Value *Or = OrOnFalseVal ? FalseVal : TrueVal; @@ -447,8 +572,7 @@ static Value *foldSelectCttzCtlz(ICmpInst *ICI, Value *TrueVal, Value *FalseVal, IntrinsicInst *II = cast<IntrinsicInst>(Count); // Explicitly clear the 'undef_on_zero' flag. IntrinsicInst *NewI = cast<IntrinsicInst>(II->clone()); - Type *Ty = NewI->getArgOperand(1)->getType(); - NewI->setArgOperand(1, Constant::getNullValue(Ty)); + NewI->setArgOperand(1, ConstantInt::getFalse(NewI->getContext())); Builder.Insert(NewI); return Builder.CreateZExtOrTrunc(NewI, ValueOnZero->getType()); } @@ -597,6 +721,9 @@ canonicalizeMinMaxWithConstant(SelectInst &Sel, ICmpInst &Cmp, /// Visit a SelectInst that has an ICmpInst as its first operand. Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, ICmpInst *ICI) { + Value *TrueVal = SI.getTrueValue(); + Value *FalseVal = SI.getFalseValue(); + if (Instruction *NewSel = canonicalizeMinMaxWithConstant(SI, *ICI, Builder)) return NewSel; @@ -605,40 +732,52 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, ICmpInst::Predicate Pred = ICI->getPredicate(); Value *CmpLHS = ICI->getOperand(0); Value *CmpRHS = ICI->getOperand(1); - Value *TrueVal = SI.getTrueValue(); - Value *FalseVal = SI.getFalseValue(); // Transform (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1 // and (X <s 0) ? C2 : C1 --> ((X >>s 31) & (C2 - C1)) + C1 // FIXME: Type and constness constraints could be lifted, but we have to // watch code size carefully. We should consider xor instead of // sub/add when we decide to do that. - if (IntegerType *Ty = dyn_cast<IntegerType>(CmpLHS->getType())) { - if (TrueVal->getType() == Ty) { - if (ConstantInt *Cmp = dyn_cast<ConstantInt>(CmpRHS)) { - ConstantInt *C1 = nullptr, *C2 = nullptr; - if (Pred == ICmpInst::ICMP_SGT && Cmp->isMinusOne()) { - C1 = dyn_cast<ConstantInt>(TrueVal); - C2 = dyn_cast<ConstantInt>(FalseVal); - } else if (Pred == ICmpInst::ICMP_SLT && Cmp->isZero()) { - C1 = dyn_cast<ConstantInt>(FalseVal); - C2 = dyn_cast<ConstantInt>(TrueVal); - } - if (C1 && C2) { + // TODO: Merge this with foldSelectICmpAnd somehow. + if (CmpLHS->getType()->isIntOrIntVectorTy() && + CmpLHS->getType() == TrueVal->getType()) { + const APInt *C1, *C2; + if (match(TrueVal, m_APInt(C1)) && match(FalseVal, m_APInt(C2))) { + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *X; + APInt Mask; + if (decomposeBitTestICmp(CmpLHS, CmpRHS, Pred, X, Mask, false)) { + if (Mask.isSignMask()) { + assert(X == CmpLHS && "Expected to use the compare input directly"); + assert(ICmpInst::isEquality(Pred) && "Expected equality predicate"); + + if (Pred == ICmpInst::ICMP_NE) + std::swap(C1, C2); + // This shift results in either -1 or 0. - Value *AShr = Builder.CreateAShr(CmpLHS, Ty->getBitWidth() - 1); + Value *AShr = Builder.CreateAShr(X, Mask.getBitWidth() - 1); // Check if we can express the operation with a single or. 
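The surrounding code implements the (X >s -1) ? C1 : C2 --> ((X >>s 31) & (C2 - C1)) + C1 rewrite spelled out in the comment above: the arithmetic shift of the sign bit produces 0 or all-ones, which selects between adding 0 and adding (C2 - C1) to C1. A standalone check of the underlying identity (plain C++, not LLVM API code; helper names and test values are made up, and a two's-complement arithmetic right shift is assumed, which is what IR's ashr provides):

// Not LLVM code: verify the sign-test select rewrite on 32-bit values.
#include <cassert>
#include <cstdint>

static uint32_t viaSelect(int32_t x, uint32_t c1, uint32_t c2) {
  return x > -1 ? c1 : c2;
}

static uint32_t viaShiftMask(int32_t x, uint32_t c1, uint32_t c2) {
  uint32_t allOnesOrZero = static_cast<uint32_t>(x >> 31); // 0 or 0xFFFFFFFF
  return (allOnesOrZero & (c2 - c1)) + c1;                 // wraps like IR add
}

int main() {
  const int32_t xs[] = {-2147483647 - 1, -5, -1, 0, 1, 42, 2147483647};
  const uint32_t cs[] = {0u, 1u, 7u, 0x80000000u, 0xFFFFFFFFu};
  for (int32_t x : xs)
    for (uint32_t c1 : cs)
      for (uint32_t c2 : cs)
        assert(viaSelect(x, c1, c2) == viaShiftMask(x, c1, c2));
  return 0;
}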
- if (C2->isMinusOne()) - return replaceInstUsesWith(SI, Builder.CreateOr(AShr, C1)); + if (C2->isAllOnesValue()) + return replaceInstUsesWith(SI, Builder.CreateOr(AShr, *C1)); - Value *And = Builder.CreateAnd(AShr, C2->getValue() - C1->getValue()); - return replaceInstUsesWith(SI, Builder.CreateAdd(And, C1)); + Value *And = Builder.CreateAnd(AShr, *C2 - *C1); + return replaceInstUsesWith(SI, Builder.CreateAdd(And, + ConstantInt::get(And->getType(), *C1))); } } } } + { + const APInt *TrueValC, *FalseValC; + if (match(TrueVal, m_APInt(TrueValC)) && + match(FalseVal, m_APInt(FalseValC))) + if (Value *V = foldSelectICmpAnd(SI.getType(), ICI, *TrueValC, + *FalseValC, Builder)) + return replaceInstUsesWith(SI, V); + } + // NOTE: if we wanted to, this is where to detect integer MIN/MAX if (CmpRHS != CmpLHS && isa<Constant>(CmpRHS)) { @@ -703,7 +842,7 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, } } - if (Value *V = foldSelectICmpAndOr(SI, TrueVal, FalseVal, Builder)) + if (Value *V = foldSelectICmpAndOr(ICI, TrueVal, FalseVal, Builder)) return replaceInstUsesWith(SI, V); if (Value *V = foldSelectCttzCtlz(ICI, TrueVal, FalseVal, Builder)) @@ -722,7 +861,6 @@ Instruction *InstCombiner::foldSelectInstWithICmp(SelectInst &SI, /// Z = select X, Y, 0 /// /// because Y is not live in BB1/BB2. -/// static bool canSelectOperandBeMappingIntoPredBlock(const Value *V, const SelectInst &SI) { // If the value is a non-instruction value like a constant or argument, it @@ -864,78 +1002,6 @@ Instruction *InstCombiner::foldSPFofSPF(Instruction *Inner, return nullptr; } -/// If one of the constants is zero (we know they can't both be) and we have an -/// icmp instruction with zero, and we have an 'and' with the non-constant value -/// and a power of two we can turn the select into a shift on the result of the -/// 'and'. -static Value *foldSelectICmpAnd(const SelectInst &SI, APInt TrueVal, - APInt FalseVal, - InstCombiner::BuilderTy &Builder) { - const ICmpInst *IC = dyn_cast<ICmpInst>(SI.getCondition()); - if (!IC || !IC->isEquality() || !SI.getType()->isIntegerTy()) - return nullptr; - - if (!match(IC->getOperand(1), m_Zero())) - return nullptr; - - ConstantInt *AndRHS; - Value *LHS = IC->getOperand(0); - if (!match(LHS, m_And(m_Value(), m_ConstantInt(AndRHS)))) - return nullptr; - - // If both select arms are non-zero see if we have a select of the form - // 'x ? 2^n + C : C'. Then we can offset both arms by C, use the logic - // for 'x ? 2^n : 0' and fix the thing up at the end. - APInt Offset(TrueVal.getBitWidth(), 0); - if (!TrueVal.isNullValue() && !FalseVal.isNullValue()) { - if ((TrueVal - FalseVal).isPowerOf2()) - Offset = FalseVal; - else if ((FalseVal - TrueVal).isPowerOf2()) - Offset = TrueVal; - else - return nullptr; - - // Adjust TrueVal and FalseVal to the offset. - TrueVal -= Offset; - FalseVal -= Offset; - } - - // Make sure the mask in the 'and' and one of the select arms is a power of 2. - if (!AndRHS->getValue().isPowerOf2() || - (!TrueVal.isPowerOf2() && !FalseVal.isPowerOf2())) - return nullptr; - - // Determine which shift is needed to transform result of the 'and' into the - // desired result. - const APInt &ValC = !TrueVal.isNullValue() ? TrueVal : FalseVal; - unsigned ValZeros = ValC.logBase2(); - unsigned AndZeros = AndRHS->getValue().logBase2(); - - // If types don't match we can still convert the select by introducing a zext - // or a trunc of the 'and'. 
The trunc case requires that all of the truncated - // bits are zero, we can figure that out by looking at the 'and' mask. - if (AndZeros >= ValC.getBitWidth()) - return nullptr; - - Value *V = Builder.CreateZExtOrTrunc(LHS, SI.getType()); - if (ValZeros > AndZeros) - V = Builder.CreateShl(V, ValZeros - AndZeros); - else if (ValZeros < AndZeros) - V = Builder.CreateLShr(V, AndZeros - ValZeros); - - // Okay, now we know that everything is set up, we just don't know whether we - // have a icmp_ne or icmp_eq and whether the true or false val is the zero. - bool ShouldNotVal = !TrueVal.isNullValue(); - ShouldNotVal ^= IC->getPredicate() == ICmpInst::ICMP_NE; - if (ShouldNotVal) - V = Builder.CreateXor(V, ValC); - - // Apply an offset if needed. - if (!Offset.isNullValue()) - V = Builder.CreateAdd(V, ConstantInt::get(V->getType(), Offset)); - return V; -} - /// Turn select C, (X + Y), (X - Y) --> (X + (select C, Y, (-Y))). /// This is even legal for FP. static Instruction *foldAddSubSelect(SelectInst &SI, @@ -1151,12 +1217,100 @@ static Instruction *foldSelectCmpBitcasts(SelectInst &Sel, return CastInst::CreateBitOrPointerCast(NewSel, Sel.getType()); } +/// Try to eliminate select instructions that test the returned flag of cmpxchg +/// instructions. +/// +/// If a select instruction tests the returned flag of a cmpxchg instruction and +/// selects between the returned value of the cmpxchg instruction its compare +/// operand, the result of the select will always be equal to its false value. +/// For example: +/// +/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +/// %1 = extractvalue { i64, i1 } %0, 1 +/// %2 = extractvalue { i64, i1 } %0, 0 +/// %3 = select i1 %1, i64 %compare, i64 %2 +/// ret i64 %3 +/// +/// The returned value of the cmpxchg instruction (%2) is the original value +/// located at %ptr prior to any update. If the cmpxchg operation succeeds, %2 +/// must have been equal to %compare. Thus, the result of the select is always +/// equal to %2, and the code can be simplified to: +/// +/// %0 = cmpxchg i64* %ptr, i64 %compare, i64 %new_value seq_cst seq_cst +/// %1 = extractvalue { i64, i1 } %0, 0 +/// ret i64 %1 +/// +static Instruction *foldSelectCmpXchg(SelectInst &SI) { + // A helper that determines if V is an extractvalue instruction whose + // aggregate operand is a cmpxchg instruction and whose single index is equal + // to I. If such conditions are true, the helper returns the cmpxchg + // instruction; otherwise, a nullptr is returned. + auto isExtractFromCmpXchg = [](Value *V, unsigned I) -> AtomicCmpXchgInst * { + auto *Extract = dyn_cast<ExtractValueInst>(V); + if (!Extract) + return nullptr; + if (Extract->getIndices()[0] != I) + return nullptr; + return dyn_cast<AtomicCmpXchgInst>(Extract->getAggregateOperand()); + }; + + // If the select has a single user, and this user is a select instruction that + // we can simplify, skip the cmpxchg simplification for now. + if (SI.hasOneUse()) + if (auto *Select = dyn_cast<SelectInst>(SI.user_back())) + if (Select->getCondition() == SI.getCondition()) + if (Select->getFalseValue() == SI.getTrueValue() || + Select->getTrueValue() == SI.getFalseValue()) + return nullptr; + + // Ensure the select condition is the returned flag of a cmpxchg instruction. 
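To make the equivalence described above concrete, here is a rough C++-level analogue, a sketch that uses std::atomic as a stand-in for the IR-level cmpxchg; the function and variable names are invented for illustration:

// Not LLVM code: when the success flag is true, the loaded old value must
// equal the compare operand, so select(flag, compare, old) == old either way.
#include <atomic>
#include <cassert>
#include <cstdint>

static int64_t cmpxchgThenSelect(std::atomic<int64_t> &ptr, int64_t compare,
                                 int64_t newValue) {
  int64_t old = compare; // compare_exchange writes the observed value here
  bool success = ptr.compare_exchange_strong(old, newValue);
  // Mirrors: %3 = select i1 %success, i64 %compare, i64 %old
  int64_t selected = success ? compare : old;
  assert(selected == old); // success implies old == compare
  return selected;         // returning 'old' directly is therefore equivalent
}

int main() {
  std::atomic<int64_t> v{10};
  assert(cmpxchgThenSelect(v, 10, 20) == 10); // success path
  assert(cmpxchgThenSelect(v, 10, 30) == 20); // failure path sees current value
  return 0;
}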
+ auto *CmpXchg = isExtractFromCmpXchg(SI.getCondition(), 1); + if (!CmpXchg) + return nullptr; + + // Check the true value case: The true value of the select is the returned + // value of the same cmpxchg used by the condition, and the false value is the + // cmpxchg instruction's compare operand. + if (auto *X = isExtractFromCmpXchg(SI.getTrueValue(), 0)) + if (X == CmpXchg && X->getCompareOperand() == SI.getFalseValue()) { + SI.setTrueValue(SI.getFalseValue()); + return &SI; + } + + // Check the false value case: The false value of the select is the returned + // value of the same cmpxchg used by the condition, and the true value is the + // cmpxchg instruction's compare operand. + if (auto *X = isExtractFromCmpXchg(SI.getFalseValue(), 0)) + if (X == CmpXchg && X->getCompareOperand() == SI.getTrueValue()) { + SI.setTrueValue(SI.getFalseValue()); + return &SI; + } + + return nullptr; +} + Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Value *CondVal = SI.getCondition(); Value *TrueVal = SI.getTrueValue(); Value *FalseVal = SI.getFalseValue(); Type *SelType = SI.getType(); + // FIXME: Remove this workaround when freeze related patches are done. + // For select with undef operand which feeds into an equality comparison, + // don't simplify it so loop unswitch can know the equality comparison + // may have an undef operand. This is a workaround for PR31652 caused by + // descrepancy about branch on undef between LoopUnswitch and GVN. + if (isa<UndefValue>(TrueVal) || isa<UndefValue>(FalseVal)) { + if (llvm::any_of(SI.users(), [&](User *U) { + ICmpInst *CI = dyn_cast<ICmpInst>(U); + if (CI && CI->isEquality()) + return true; + return false; + })) { + return nullptr; + } + } + if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, SQ.getWithInstruction(&SI))) return replaceInstUsesWith(SI, V); @@ -1246,12 +1400,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } - if (ConstantInt *TrueValC = dyn_cast<ConstantInt>(TrueVal)) - if (ConstantInt *FalseValC = dyn_cast<ConstantInt>(FalseVal)) - if (Value *V = foldSelectICmpAnd(SI, TrueValC->getValue(), - FalseValC->getValue(), Builder)) - return replaceInstUsesWith(SI, V); - // See if we are selecting two values based on a comparison of the two values. if (FCmpInst *FCI = dyn_cast<FCmpInst>(CondVal)) { if (FCI->getOperand(0) == TrueVal && FCI->getOperand(1) == FalseVal) { @@ -1373,9 +1521,17 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { auto SPF = SPR.Flavor; if (SelectPatternResult::isMinOrMax(SPF)) { - // Canonicalize so that type casts are outside select patterns. - if (LHS->getType()->getPrimitiveSizeInBits() != - SelType->getPrimitiveSizeInBits()) { + // Canonicalize so that + // - type casts are outside select patterns. 
+ // - float clamp is transformed to min/max pattern + + bool IsCastNeeded = LHS->getType() != SelType; + Value *CmpLHS = cast<CmpInst>(CondVal)->getOperand(0); + Value *CmpRHS = cast<CmpInst>(CondVal)->getOperand(1); + if (IsCastNeeded || + (LHS->getType()->isFPOrFPVectorTy() && + ((CmpLHS != LHS && CmpLHS != RHS) || + (CmpRHS != LHS && CmpRHS != RHS)))) { CmpInst::Predicate Pred = getCmpPredicateForMinMax(SPF, SPR.Ordered); Value *Cmp; @@ -1388,10 +1544,12 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { Cmp = Builder.CreateFCmp(Pred, LHS, RHS); } - Value *NewSI = Builder.CreateCast( - CastOp, Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI), - SelType); - return replaceInstUsesWith(SI, NewSI); + Value *NewSI = Builder.CreateSelect(Cmp, LHS, RHS, SI.getName(), &SI); + if (!IsCastNeeded) + return replaceInstUsesWith(SI, NewSI); + + Value *NewCast = Builder.CreateCast(CastOp, NewSI, SelType); + return replaceInstUsesWith(SI, NewCast); } } @@ -1485,6 +1643,46 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { } } + // Try to simplify a binop sandwiched between 2 selects with the same + // condition. + // select(C, binop(select(C, X, Y), W), Z) -> select(C, binop(X, W), Z) + BinaryOperator *TrueBO; + if (match(TrueVal, m_OneUse(m_BinOp(TrueBO)))) { + if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(0))) { + if (TrueBOSI->getCondition() == CondVal) { + TrueBO->setOperand(0, TrueBOSI->getTrueValue()); + Worklist.Add(TrueBO); + return &SI; + } + } + if (auto *TrueBOSI = dyn_cast<SelectInst>(TrueBO->getOperand(1))) { + if (TrueBOSI->getCondition() == CondVal) { + TrueBO->setOperand(1, TrueBOSI->getTrueValue()); + Worklist.Add(TrueBO); + return &SI; + } + } + } + + // select(C, Z, binop(select(C, X, Y), W)) -> select(C, Z, binop(Y, W)) + BinaryOperator *FalseBO; + if (match(FalseVal, m_OneUse(m_BinOp(FalseBO)))) { + if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(0))) { + if (FalseBOSI->getCondition() == CondVal) { + FalseBO->setOperand(0, FalseBOSI->getFalseValue()); + Worklist.Add(FalseBO); + return &SI; + } + } + if (auto *FalseBOSI = dyn_cast<SelectInst>(FalseBO->getOperand(1))) { + if (FalseBOSI->getCondition() == CondVal) { + FalseBO->setOperand(1, FalseBOSI->getFalseValue()); + Worklist.Add(FalseBO); + return &SI; + } + } + } + if (BinaryOperator::isNot(CondVal)) { SI.setOperand(0, BinaryOperator::getNotArgument(CondVal)); SI.setOperand(1, FalseVal); @@ -1501,10 +1699,6 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { return replaceInstUsesWith(SI, V); return &SI; } - - if (isa<ConstantAggregateZero>(CondVal)) { - return replaceInstUsesWith(SI, FalseVal); - } } // See if we can determine the result of this select based on a dominating @@ -1515,9 +1709,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (PBI && PBI->isConditional() && PBI->getSuccessor(0) != PBI->getSuccessor(1) && (PBI->getSuccessor(0) == Parent || PBI->getSuccessor(1) == Parent)) { - bool CondIsFalse = PBI->getSuccessor(1) == Parent; + bool CondIsTrue = PBI->getSuccessor(0) == Parent; Optional<bool> Implication = isImpliedCondition( - PBI->getCondition(), SI.getCondition(), DL, CondIsFalse); + PBI->getCondition(), SI.getCondition(), DL, CondIsTrue); if (Implication) { Value *V = *Implication ? 
TrueVal : FalseVal; return replaceInstUsesWith(SI, V); @@ -1542,5 +1736,9 @@ Instruction *InstCombiner::visitSelectInst(SelectInst &SI) { if (Instruction *BitCastSel = foldSelectCmpBitcasts(SI, Builder)) return BitCastSel; + // Simplify selects that test the returned flag of cmpxchg instructions. + if (Instruction *Select = foldSelectCmpXchg(SI)) + return Select; + return nullptr; } diff --git a/lib/Transforms/InstCombine/InstCombineShifts.cpp b/lib/Transforms/InstCombine/InstCombineShifts.cpp index 7ed141c7fd79..44bbb84686ab 100644 --- a/lib/Transforms/InstCombine/InstCombineShifts.cpp +++ b/lib/Transforms/InstCombine/InstCombineShifts.cpp @@ -310,6 +310,40 @@ static Value *getShiftedValue(Value *V, unsigned NumBits, bool isLeftShift, } } +// If this is a bitwise operator or add with a constant RHS we might be able +// to pull it through a shift. +static bool canShiftBinOpWithConstantRHS(BinaryOperator &Shift, + BinaryOperator *BO, + const APInt &C) { + bool IsValid = true; // Valid only for And, Or Xor, + bool HighBitSet = false; // Transform ifhigh bit of constant set? + + switch (BO->getOpcode()) { + default: IsValid = false; break; // Do not perform transform! + case Instruction::Add: + IsValid = Shift.getOpcode() == Instruction::Shl; + break; + case Instruction::Or: + case Instruction::Xor: + HighBitSet = false; + break; + case Instruction::And: + HighBitSet = true; + break; + } + + // If this is a signed shift right, and the high bit is modified + // by the logical operation, do not perform the transformation. + // The HighBitSet boolean indicates the value of the high bit of + // the constant which would cause it to be modified for this + // operation. + // + if (IsValid && Shift.getOpcode() == Instruction::AShr) + IsValid = C.isNegative() == HighBitSet; + + return IsValid; +} + Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, BinaryOperator &I) { bool isLeftShift = I.getOpcode() == Instruction::Shl; @@ -470,35 +504,11 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, // If the operand is a bitwise operator with a constant RHS, and the // shift is the only use, we can pull it out of the shift. - if (ConstantInt *Op0C = dyn_cast<ConstantInt>(Op0BO->getOperand(1))) { - bool isValid = true; // Valid only for And, Or, Xor - bool highBitSet = false; // Transform if high bit of constant set? - - switch (Op0BO->getOpcode()) { - default: isValid = false; break; // Do not perform transform! - case Instruction::Add: - isValid = isLeftShift; - break; - case Instruction::Or: - case Instruction::Xor: - highBitSet = false; - break; - case Instruction::And: - highBitSet = true; - break; - } - - // If this is a signed shift right, and the high bit is modified - // by the logical operation, do not perform the transformation. - // The highBitSet boolean indicates the value of the high bit of - // the constant which would cause it to be modified for this - // operation. 
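The new canShiftBinOpWithConstantRHS helper above guards rewrites of the form shift (binop X, C1), C2 --> binop (shift X, C2), (shift C1, C2), alongside the sub-with-constant-LHS variant added a little further down in this file. A standalone check of the always-valid shl/lshr cases and the sub case (plain C++, not LLVM API code; the ashr-with-and sign-bit restriction is not exercised here):

// Not LLVM code: pulling a binop with a constant operand through a shift by
// a constant, using arbitrary values of the constants.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t c1 = 0x3Du, shAmt = 4;
  for (uint32_t x = 0; x < 1024; ++x) {
    // shl (add X, C1), C2  ==  add (shl X, C2), (shl C1, C2)   (wrapping)
    assert(((x + c1) << shAmt) == ((x << shAmt) + (c1 << shAmt)));
    // lshr (or X, C1), C2  ==  or (lshr X, C2), (lshr C1, C2)
    assert(((x | c1) >> shAmt) == ((x >> shAmt) | (c1 >> shAmt)));
    // shl (sub C1, X), C2  ==  sub (shl C1, C2), (shl X, C2)   (wrapping)
    assert(((c1 - x) << shAmt) == ((c1 << shAmt) - (x << shAmt)));
  }
  return 0;
}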
- // - if (isValid && I.getOpcode() == Instruction::AShr) - isValid = Op0C->getValue()[TypeBits-1] == highBitSet; - - if (isValid) { - Constant *NewRHS = ConstantExpr::get(I.getOpcode(), Op0C, Op1); + const APInt *Op0C; + if (match(Op0BO->getOperand(1), m_APInt(Op0C))) { + if (canShiftBinOpWithConstantRHS(I, Op0BO, *Op0C)) { + Constant *NewRHS = ConstantExpr::get(I.getOpcode(), + cast<Constant>(Op0BO->getOperand(1)), Op1); Value *NewShift = Builder.CreateBinOp(I.getOpcode(), Op0BO->getOperand(0), Op1); @@ -508,6 +518,67 @@ Instruction *InstCombiner::FoldShiftByConstant(Value *Op0, Constant *Op1, NewRHS); } } + + // If the operand is a subtract with a constant LHS, and the shift + // is the only use, we can pull it out of the shift. + // This folds (shl (sub C1, X), C2) -> (sub (C1 << C2), (shl X, C2)) + if (isLeftShift && Op0BO->getOpcode() == Instruction::Sub && + match(Op0BO->getOperand(0), m_APInt(Op0C))) { + Constant *NewRHS = ConstantExpr::get(I.getOpcode(), + cast<Constant>(Op0BO->getOperand(0)), Op1); + + Value *NewShift = Builder.CreateShl(Op0BO->getOperand(1), Op1); + NewShift->takeName(Op0BO); + + return BinaryOperator::CreateSub(NewRHS, NewShift); + } + } + + // If we have a select that conditionally executes some binary operator, + // see if we can pull it the select and operator through the shift. + // + // For example, turning: + // shl (select C, (add X, C1), X), C2 + // Into: + // Y = shl X, C2 + // select C, (add Y, C1 << C2), Y + Value *Cond; + BinaryOperator *TBO; + Value *FalseVal; + if (match(Op0, m_Select(m_Value(Cond), m_OneUse(m_BinOp(TBO)), + m_Value(FalseVal)))) { + const APInt *C; + if (!isa<Constant>(FalseVal) && TBO->getOperand(0) == FalseVal && + match(TBO->getOperand(1), m_APInt(C)) && + canShiftBinOpWithConstantRHS(I, TBO, *C)) { + Constant *NewRHS = ConstantExpr::get(I.getOpcode(), + cast<Constant>(TBO->getOperand(1)), Op1); + + Value *NewShift = + Builder.CreateBinOp(I.getOpcode(), FalseVal, Op1); + Value *NewOp = Builder.CreateBinOp(TBO->getOpcode(), NewShift, + NewRHS); + return SelectInst::Create(Cond, NewOp, NewShift); + } + } + + BinaryOperator *FBO; + Value *TrueVal; + if (match(Op0, m_Select(m_Value(Cond), m_Value(TrueVal), + m_OneUse(m_BinOp(FBO))))) { + const APInt *C; + if (!isa<Constant>(TrueVal) && FBO->getOperand(0) == TrueVal && + match(FBO->getOperand(1), m_APInt(C)) && + canShiftBinOpWithConstantRHS(I, FBO, *C)) { + Constant *NewRHS = ConstantExpr::get(I.getOpcode(), + cast<Constant>(FBO->getOperand(1)), Op1); + + Value *NewShift = + Builder.CreateBinOp(I.getOpcode(), TrueVal, Op1); + Value *NewOp = Builder.CreateBinOp(FBO->getOpcode(), NewShift, + NewRHS); + return SelectInst::Create(Cond, NewShift, NewOp); + } } } @@ -543,8 +614,8 @@ Instruction *InstCombiner::visitShl(BinaryOperator &I) { return new ZExtInst(Builder.CreateShl(X, ShAmt), Ty); } - // (X >>u C) << C --> X & (-1 << C) - if (match(Op0, m_LShr(m_Value(X), m_Specific(Op1)))) { + // (X >> C) << C --> X & (-1 << C) + if (match(Op0, m_Shr(m_Value(X), m_Specific(Op1)))) { APInt Mask(APInt::getHighBitsSet(BitWidth, BitWidth - ShAmt)); return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } @@ -680,6 +751,15 @@ Instruction *InstCombiner::visitLShr(BinaryOperator &I) { return BinaryOperator::CreateAnd(X, ConstantInt::get(Ty, Mask)); } + if (match(Op0, m_OneUse(m_ZExt(m_Value(X)))) && + (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { + assert(ShAmt < X->getType()->getScalarSizeInBits() && + "Big shift not simplified to zero?"); + // lshr (zext iM X to iN), C 
--> zext (lshr X, C) to iN + Value *NewLShr = Builder.CreateLShr(X, ShAmt); + return new ZExtInst(NewLShr, Ty); + } + if (match(Op0, m_SExt(m_Value(X))) && (!Ty->isIntegerTy() || shouldChangeType(Ty, X->getType()))) { // Are we moving the sign bit to the low bit and widening with high zeros? @@ -778,6 +858,15 @@ Instruction *InstCombiner::visitAShr(BinaryOperator &I) { return BinaryOperator::CreateAShr(X, ConstantInt::get(Ty, AmtSum)); } + if (match(Op0, m_OneUse(m_SExt(m_Value(X)))) && + (Ty->isVectorTy() || shouldChangeType(Ty, X->getType()))) { + // ashr (sext X), C --> sext (ashr X, C') + Type *SrcTy = X->getType(); + ShAmt = std::min(ShAmt, SrcTy->getScalarSizeInBits() - 1); + Value *NewSh = Builder.CreateAShr(X, ConstantInt::get(SrcTy, ShAmt)); + return new SExtInst(NewSh, Ty); + } + // If the shifted-out value is known-zero, then this is an exact shift. if (!I.isExact() && MaskedValueIsZero(Op0, APInt::getLowBitsSet(BitWidth, ShAmt), 0, &I)) { diff --git a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp index a20f474cbf40..a2e757cb4273 100644 --- a/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp +++ b/lib/Transforms/InstCombine/InstCombineSimplifyDemanded.cpp @@ -396,50 +396,50 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, /// If the high-bits of an ADD/SUB are not demanded, then we do not care /// about the high bits of the operands. unsigned NLZ = DemandedMask.countLeadingZeros(); - if (NLZ > 0) { - // Right fill the mask of bits for this ADD/SUB to demand the most - // significant bit and all those below it. - APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ)); - if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || - SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) || - ShrinkDemandedConstant(I, 1, DemandedFromOps) || - SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) { + // Right fill the mask of bits for this ADD/SUB to demand the most + // significant bit and all those below it. + APInt DemandedFromOps(APInt::getLowBitsSet(BitWidth, BitWidth-NLZ)); + if (ShrinkDemandedConstant(I, 0, DemandedFromOps) || + SimplifyDemandedBits(I, 0, DemandedFromOps, LHSKnown, Depth + 1) || + ShrinkDemandedConstant(I, 1, DemandedFromOps) || + SimplifyDemandedBits(I, 1, DemandedFromOps, RHSKnown, Depth + 1)) { + if (NLZ > 0) { // Disable the nsw and nuw flags here: We can no longer guarantee that // we won't wrap after simplification. Removing the nsw/nuw flags is // legal here because the top bit is not demanded. BinaryOperator &BinOP = *cast<BinaryOperator>(I); BinOP.setHasNoSignedWrap(false); BinOP.setHasNoUnsignedWrap(false); - return I; } - - // If we are known to be adding/subtracting zeros to every bit below - // the highest demanded bit, we just return the other side. - if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) - return I->getOperand(0); - // We can't do this with the LHS for subtraction, unless we are only - // demanding the LSB. - if ((I->getOpcode() == Instruction::Add || - DemandedFromOps.isOneValue()) && - DemandedFromOps.isSubsetOf(LHSKnown.Zero)) - return I->getOperand(1); + return I; } - // Otherwise just hand the add/sub off to computeKnownBits to fill in - // the known zeros and ones. - computeKnownBits(V, Known, Depth, CxtI); + // If we are known to be adding/subtracting zeros to every bit below + // the highest demanded bit, we just return the other side. 
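The demanded-bits argument above for add/sub comes down to the fact that carries only propagate upward: if one operand is known zero in every bit at or below the highest demanded bit, the addition cannot change any demanded bit, so the other operand can stand in for the whole expression. A small standalone illustration (plain C++, not LLVM API code; the masks are arbitrary):

// Not LLVM code: an addend that is zero in all bits up to the highest
// demanded bit leaves the demanded bits of the sum untouched.
#include <cassert>
#include <cstdint>

int main() {
  const uint32_t demanded = 0x0000FFFFu;  // bits 0..15 are demanded
  const uint32_t rhs = 0xABCD0000u;       // known zero in bits 0..15
  for (uint32_t x = 0; x < (1u << 16); ++x)
    assert(((x + rhs) & demanded) == (x & demanded));
  return 0;
}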
+ if (DemandedFromOps.isSubsetOf(RHSKnown.Zero)) + return I->getOperand(0); + // We can't do this with the LHS for subtraction, unless we are only + // demanding the LSB. + if ((I->getOpcode() == Instruction::Add || + DemandedFromOps.isOneValue()) && + DemandedFromOps.isSubsetOf(LHSKnown.Zero)) + return I->getOperand(1); + + // Otherwise just compute the known bits of the result. + bool NSW = cast<OverflowingBinaryOperator>(I)->hasNoSignedWrap(); + Known = KnownBits::computeForAddSub(I->getOpcode() == Instruction::Add, + NSW, LHSKnown, RHSKnown); break; } case Instruction::Shl: { const APInt *SA; if (match(I->getOperand(1), m_APInt(SA))) { const APInt *ShrAmt; - if (match(I->getOperand(0), m_Shr(m_Value(), m_APInt(ShrAmt)))) { - Instruction *Shr = cast<Instruction>(I->getOperand(0)); - if (Value *R = simplifyShrShlDemandedBits( - Shr, *ShrAmt, I, *SA, DemandedMask, Known)) - return R; - } + if (match(I->getOperand(0), m_Shr(m_Value(), m_APInt(ShrAmt)))) + if (Instruction *Shr = dyn_cast<Instruction>(I->getOperand(0))) + if (Value *R = simplifyShrShlDemandedBits(Shr, *ShrAmt, I, *SA, + DemandedMask, Known)) + return R; uint64_t ShiftAmt = SA->getLimitedValue(BitWidth-1); APInt DemandedMaskIn(DemandedMask.lshr(ShiftAmt)); @@ -521,26 +521,25 @@ Value *InstCombiner::SimplifyDemandedUseBits(Value *V, APInt DemandedMask, if (SimplifyDemandedBits(I, 0, DemandedMaskIn, Known, Depth + 1)) return I; + unsigned SignBits = ComputeNumSignBits(I->getOperand(0), Depth + 1, CxtI); + assert(!Known.hasConflict() && "Bits known to be one AND zero?"); - // Compute the new bits that are at the top now. - APInt HighBits(APInt::getHighBitsSet(BitWidth, ShiftAmt)); + // Compute the new bits that are at the top now plus sign bits. + APInt HighBits(APInt::getHighBitsSet( + BitWidth, std::min(SignBits + ShiftAmt - 1, BitWidth))); Known.Zero.lshrInPlace(ShiftAmt); Known.One.lshrInPlace(ShiftAmt); - // Handle the sign bits. - APInt SignMask(APInt::getSignMask(BitWidth)); - // Adjust to where it is now in the mask. - SignMask.lshrInPlace(ShiftAmt); - // If the input sign bit is known to be zero, or if none of the top bits // are demanded, turn this into an unsigned shift right. - if (BitWidth <= ShiftAmt || Known.Zero[BitWidth-ShiftAmt-1] || + assert(BitWidth > ShiftAmt && "Shift amount not saturated?"); + if (Known.Zero[BitWidth-ShiftAmt-1] || !DemandedMask.intersects(HighBits)) { BinaryOperator *LShr = BinaryOperator::CreateLShr(I->getOperand(0), I->getOperand(1)); LShr->setIsExact(cast<BinaryOperator>(I)->isExact()); return InsertNewInstWith(LShr, *I); - } else if (Known.One.intersects(SignMask)) { // New bits are known one. + } else if (Known.One[BitWidth-ShiftAmt-1]) { // New bits are known one. Known.One |= HighBits; } } @@ -993,22 +992,23 @@ Value *InstCombiner::SimplifyDemandedVectorElts(Value *V, APInt DemandedElts, break; } + // The element inserted overwrites whatever was there, so the input demanded + // set is simpler than the output set. + unsigned IdxNo = Idx->getZExtValue(); + APInt PreInsertDemandedElts = DemandedElts; + if (IdxNo < VWidth) + PreInsertDemandedElts.clearBit(IdxNo); + TmpV = SimplifyDemandedVectorElts(I->getOperand(0), PreInsertDemandedElts, + UndefElts, Depth + 1); + if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } + // If this is inserting an element that isn't demanded, remove this // insertelement. 
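The reordered insertelement handling above rests on a simple observation: once a lane is overwritten by the insert, whatever the source vector held in that lane is irrelevant, so that lane need not be demanded from the operand. A plain std::array stand-in for the vector (not LLVM code; the values are arbitrary):

// Not LLVM code: two sources that differ only in the overwritten lane give
// identical results after the insert, so that lane is not demanded.
#include <array>
#include <cassert>

int main() {
  const unsigned idx = 2;
  std::array<int, 4> a{1, 2, 3, 4};
  std::array<int, 4> b = a;
  b[idx] = 999;   // sources disagree only in the lane about to be overwritten
  a[idx] = 7;     // the "insertelement"
  b[idx] = 7;
  assert(a == b); // the overwrite hides the difference
  return 0;
}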
- unsigned IdxNo = Idx->getZExtValue(); if (IdxNo >= VWidth || !DemandedElts[IdxNo]) { Worklist.Add(I); return I->getOperand(0); } - // Otherwise, the element inserted overwrites whatever was there, so the - // input demanded set is simpler than the output set. - APInt DemandedElts2 = DemandedElts; - DemandedElts2.clearBit(IdxNo); - TmpV = SimplifyDemandedVectorElts(I->getOperand(0), DemandedElts2, - UndefElts, Depth + 1); - if (TmpV) { I->setOperand(0, TmpV); MadeChange = true; } - // The inserted element is defined. UndefElts.clearBit(IdxNo); break; diff --git a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp index dd71a31b644b..65a96b965227 100644 --- a/lib/Transforms/InstCombine/InstCombineVectorOps.cpp +++ b/lib/Transforms/InstCombine/InstCombineVectorOps.cpp @@ -13,10 +13,33 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" +#include <cassert> +#include <cstdint> +#include <iterator> +#include <utility> + using namespace llvm; using namespace PatternMatch; @@ -90,7 +113,7 @@ Instruction *InstCombiner::scalarizePHI(ExtractElementInst &EI, PHINode *PN) { // Verify that this PHI user has one use, which is the PHI itself, // and that it is a binary operation which is cheap to scalarize. - // otherwise return NULL. + // otherwise return nullptr. if (!PHIUser->hasOneUse() || !(PHIUser->user_back() == PN) || !(isa<BinaryOperator>(PHIUser)) || !cheapToScalarize(PHIUser, true)) return nullptr; @@ -255,39 +278,6 @@ Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) { Worklist.AddValue(EE); return CastInst::Create(CI->getOpcode(), EE, EI.getType()); } - } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) { - if (SI->hasOneUse()) { - // TODO: For a select on vectors, it might be useful to do this if it - // has multiple extractelement uses. For vector select, that seems to - // fight the vectorizer. - - // If we are extracting an element from a vector select or a select on - // vectors, create a select on the scalars extracted from the vector - // arguments. 
- Value *TrueVal = SI->getTrueValue(); - Value *FalseVal = SI->getFalseValue(); - - Value *Cond = SI->getCondition(); - if (Cond->getType()->isVectorTy()) { - Cond = Builder.CreateExtractElement(Cond, - EI.getIndexOperand(), - Cond->getName() + ".elt"); - } - - Value *V1Elem - = Builder.CreateExtractElement(TrueVal, - EI.getIndexOperand(), - TrueVal->getName() + ".elt"); - - Value *V2Elem - = Builder.CreateExtractElement(FalseVal, - EI.getIndexOperand(), - FalseVal->getName() + ".elt"); - return SelectInst::Create(Cond, - V1Elem, - V2Elem, - SI->getName() + ".elt"); - } } } return nullptr; @@ -454,7 +444,7 @@ static void replaceExtractElements(InsertElementInst *InsElt, /// /// Note: we intentionally don't try to fold earlier shuffles since they have /// often been chosen carefully to be efficiently implementable on the target. -typedef std::pair<Value *, Value *> ShuffleOps; +using ShuffleOps = std::pair<Value *, Value *>; static ShuffleOps collectShuffleElements(Value *V, SmallVectorImpl<Constant *> &Mask, @@ -615,20 +605,26 @@ static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) { Value *SplatVal = InsElt.getOperand(1); InsertElementInst *CurrIE = &InsElt; SmallVector<bool, 16> ElementPresent(NumElements, false); + InsertElementInst *FirstIE = nullptr; // Walk the chain backwards, keeping track of which indices we inserted into, // until we hit something that isn't an insert of the splatted value. while (CurrIE) { - ConstantInt *Idx = dyn_cast<ConstantInt>(CurrIE->getOperand(2)); + auto *Idx = dyn_cast<ConstantInt>(CurrIE->getOperand(2)); if (!Idx || CurrIE->getOperand(1) != SplatVal) return nullptr; - // Check none of the intermediate steps have any additional uses. - if ((CurrIE != &InsElt) && !CurrIE->hasOneUse()) + auto *NextIE = dyn_cast<InsertElementInst>(CurrIE->getOperand(0)); + // Check none of the intermediate steps have any additional uses, except + // for the root insertelement instruction, which can be re-used, if it + // inserts at position 0. + if (CurrIE != &InsElt && + (!CurrIE->hasOneUse() && (NextIE != nullptr || !Idx->isZero()))) return nullptr; ElementPresent[Idx->getZExtValue()] = true; - CurrIE = dyn_cast<InsertElementInst>(CurrIE->getOperand(0)); + FirstIE = CurrIE; + CurrIE = NextIE; } // Make sure we've seen an insert into every element. @@ -636,9 +632,14 @@ static Instruction *foldInsSequenceIntoBroadcast(InsertElementInst &InsElt) { return nullptr; // All right, create the insert + shuffle. - Instruction *InsertFirst = InsertElementInst::Create( - UndefValue::get(VT), SplatVal, - ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), 0), "", &InsElt); + Instruction *InsertFirst; + if (cast<ConstantInt>(FirstIE->getOperand(2))->isZero()) + InsertFirst = FirstIE; + else + InsertFirst = InsertElementInst::Create( + UndefValue::get(VT), SplatVal, + ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), 0), + "", &InsElt); Constant *ZeroMask = ConstantAggregateZero::get( VectorType::get(Type::getInt32Ty(InsElt.getContext()), NumElements)); @@ -780,6 +781,10 @@ Instruction *InstCombiner::visitInsertElementInst(InsertElementInst &IE) { Value *ScalarOp = IE.getOperand(1); Value *IdxOp = IE.getOperand(2); + if (auto *V = SimplifyInsertElementInst( + VecOp, ScalarOp, IdxOp, SQ.getWithInstruction(&IE))) + return replaceInstUsesWith(IE, V); + // Inserting an undef or into an undefined place, remove this. 
if (isa<UndefValue>(ScalarOp) || isa<UndefValue>(IdxOp)) replaceInstUsesWith(IE, VecOp); @@ -1007,15 +1012,13 @@ InstCombiner::EvaluateInDifferentElementOrder(Value *V, ArrayRef<int> Mask) { // Mask.size() does not need to be equal to the number of vector elements. assert(V->getType()->isVectorTy() && "can't reorder non-vector elements"); - if (isa<UndefValue>(V)) { - return UndefValue::get(VectorType::get(V->getType()->getScalarType(), - Mask.size())); - } - if (isa<ConstantAggregateZero>(V)) { - return ConstantAggregateZero::get( - VectorType::get(V->getType()->getScalarType(), - Mask.size())); - } + Type *EltTy = V->getType()->getScalarType(); + if (isa<UndefValue>(V)) + return UndefValue::get(VectorType::get(EltTy, Mask.size())); + + if (isa<ConstantAggregateZero>(V)) + return ConstantAggregateZero::get(VectorType::get(EltTy, Mask.size())); + if (Constant *C = dyn_cast<Constant>(V)) { SmallVector<Constant *, 16> MaskValues; for (int i = 0, e = Mask.size(); i != e; ++i) { @@ -1153,9 +1156,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { if (Value *V = SimplifyDemandedVectorElts(&SVI, AllOnesEltMask, UndefElts)) { if (V != &SVI) return replaceInstUsesWith(SVI, V); - LHS = SVI.getOperand(0); - RHS = SVI.getOperand(1); - MadeChange = true; + return &SVI; } unsigned LHSWidth = LHS->getType()->getVectorNumElements(); @@ -1446,7 +1447,7 @@ Instruction *InstCombiner::visitShuffleVectorInst(ShuffleVectorInst &SVI) { eltMask = Mask[i]-LHSWidth; // If LHS's width is changed, shift the mask value accordingly. - // If newRHS == NULL, i.e. LHSOp0 == RHSOp0, we want to remap any + // If newRHS == nullptr, i.e. LHSOp0 == RHSOp0, we want to remap any // references from RHSOp0 to LHSOp0, so we don't need to shift the mask. // If newRHS == newLHS, we want to remap any references from newRHS to // newLHS so that we can properly identify splats that may occur due to diff --git a/lib/Transforms/InstCombine/InstructionCombining.cpp b/lib/Transforms/InstCombine/InstructionCombining.cpp index c7766568fd9d..b332e75c7feb 100644 --- a/lib/Transforms/InstCombine/InstructionCombining.cpp +++ b/lib/Transforms/InstCombine/InstructionCombining.cpp @@ -34,10 +34,14 @@ //===----------------------------------------------------------------------===// #include "InstCombineInternal.h" -#include "llvm-c/Initialization.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringSwitch.h" +#include "llvm/ADT/TinyPtrVector.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BasicAliasAnalysis.h" @@ -48,24 +52,56 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DIBuilder.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include 
"llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/CBindingWrapping.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/DebugCounter.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/InstCombine/InstCombine.h" +#include "llvm/Transforms/InstCombine/InstCombineWorklist.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> -#include <climits> +#include <cassert> +#include <cstdint> +#include <memory> +#include <string> +#include <utility> + using namespace llvm; using namespace llvm::PatternMatch; @@ -78,6 +114,8 @@ STATISTIC(NumSunkInst , "Number of instructions sunk"); STATISTIC(NumExpand, "Number of expansions"); STATISTIC(NumFactor , "Number of factorizations"); STATISTIC(NumReassoc , "Number of reassociations"); +DEBUG_COUNTER(VisitCounter, "instcombine-visit", + "Controls which instructions are visited"); static cl::opt<bool> EnableExpensiveCombines("expensive-combines", @@ -87,6 +125,16 @@ static cl::opt<unsigned> MaxArraySize("instcombine-maxarray-size", cl::init(1024), cl::desc("Maximum array size considered when doing a combine")); +// FIXME: Remove this flag when it is no longer necessary to convert +// llvm.dbg.declare to avoid inaccurate debug info. Setting this to false +// increases variable availability at the cost of accuracy. Variables that +// cannot be promoted by mem2reg or SROA will be described as living in memory +// for their entire lifetime. However, passes like DSE and instcombine can +// delete stores to the alloca, leading to misleading and inaccurate debug +// information. This flag can be removed when those passes are fixed. +static cl::opt<unsigned> ShouldLowerDbgDeclare("instcombine-lower-dbg-declare", + cl::Hidden, cl::init(true)); + Value *InstCombiner::EmitGEPOffset(User *GEP) { return llvm::EmitGEPOffset(&Builder, DL, GEP); } @@ -381,7 +429,7 @@ bool InstCombiner::SimplifyAssociativeOrCommutative(BinaryOperator &I) { // No further simplifications. 
return Changed; - } while (1); + } while (true); } /// Return whether "X LOp (Y ROp Z)" is always equal to @@ -704,36 +752,37 @@ Value *InstCombiner::SimplifyUsingDistributiveLaws(BinaryOperator &I) { } } - // (op (select (a, c, b)), (select (a, d, b))) -> (select (a, (op c, d), 0)) - // (op (select (a, b, c)), (select (a, b, d))) -> (select (a, 0, (op c, d))) - if (auto *SI0 = dyn_cast<SelectInst>(LHS)) { - if (auto *SI1 = dyn_cast<SelectInst>(RHS)) { - if (SI0->getCondition() == SI1->getCondition()) { - Value *SI = nullptr; - if (Value *V = - SimplifyBinOp(TopLevelOpcode, SI0->getFalseValue(), - SI1->getFalseValue(), SQ.getWithInstruction(&I))) - SI = Builder.CreateSelect(SI0->getCondition(), - Builder.CreateBinOp(TopLevelOpcode, - SI0->getTrueValue(), - SI1->getTrueValue()), - V); - if (Value *V = - SimplifyBinOp(TopLevelOpcode, SI0->getTrueValue(), - SI1->getTrueValue(), SQ.getWithInstruction(&I))) - SI = Builder.CreateSelect( - SI0->getCondition(), V, - Builder.CreateBinOp(TopLevelOpcode, SI0->getFalseValue(), - SI1->getFalseValue())); - if (SI) { - SI->takeName(&I); - return SI; - } - } - } + return SimplifySelectsFeedingBinaryOp(I, LHS, RHS); +} + +Value *InstCombiner::SimplifySelectsFeedingBinaryOp(BinaryOperator &I, + Value *LHS, Value *RHS) { + Instruction::BinaryOps Opcode = I.getOpcode(); + // (op (select (a, b, c)), (select (a, d, e))) -> (select (a, (op b, d), (op + // c, e))) + Value *A, *B, *C, *D, *E; + Value *SI = nullptr; + if (match(LHS, m_Select(m_Value(A), m_Value(B), m_Value(C))) && + match(RHS, m_Select(m_Specific(A), m_Value(D), m_Value(E)))) { + bool SelectsHaveOneUse = LHS->hasOneUse() && RHS->hasOneUse(); + BuilderTy::FastMathFlagGuard Guard(Builder); + if (isa<FPMathOperator>(&I)) + Builder.setFastMathFlags(I.getFastMathFlags()); + + Value *V1 = SimplifyBinOp(Opcode, C, E, SQ.getWithInstruction(&I)); + Value *V2 = SimplifyBinOp(Opcode, B, D, SQ.getWithInstruction(&I)); + if (V1 && V2) + SI = Builder.CreateSelect(A, V2, V1); + else if (V2 && SelectsHaveOneUse) + SI = Builder.CreateSelect(A, V2, Builder.CreateBinOp(Opcode, C, E)); + else if (V1 && SelectsHaveOneUse) + SI = Builder.CreateSelect(A, Builder.CreateBinOp(Opcode, B, D), V1); + + if (SI) + SI->takeName(&I); } - return nullptr; + return SI; } /// Given a 'sub' instruction, return the RHS of the instruction if the LHS is a @@ -1158,7 +1207,7 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { // Parent - initially null, but after drilling down notes where Op came from. // In the example above, Parent is (Val, 0) when Op is M1, because M1 is the // 0'th operand of Val. - std::pair<Instruction*, unsigned> Parent; + std::pair<Instruction *, unsigned> Parent; // Set if the transform requires a descaling at deeper levels that doesn't // overflow. @@ -1168,7 +1217,6 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { int32_t logScale = Scale.exactLogBase2(); for (;; Op = Parent.first->getOperand(Parent.second)) { // Drill down - if (ConstantInt *CI = dyn_cast<ConstantInt>(Op)) { // If Op is a constant divisible by Scale then descale to the quotient. APInt Quotient(Scale), Remainder(Scale); // Init ensures right bitwidth. @@ -1183,7 +1231,6 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { } if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op)) { - if (BO->getOpcode() == Instruction::Mul) { // Multiplication. 
NoSignedWrap = BO->hasNoSignedWrap(); @@ -1358,7 +1405,7 @@ Value *InstCombiner::Descale(Value *Val, APInt Scale, bool &NoSignedWrap) { // Move up one level in the expression. assert(Ancestor->hasOneUse() && "Drilled down when more than one use!"); Ancestor = Ancestor->user_back(); - } while (1); + } while (true); } /// \brief Creates node of binary operation with the same attributes as the @@ -1605,7 +1652,6 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { // Combine Indices - If the source pointer to this getelementptr instruction // is a getelementptr instruction, combine the indices of the two // getelementptr instructions into a single instruction. - // if (GEPOperator *Src = dyn_cast<GEPOperator>(PtrOp)) { if (!shouldMergeGEPs(*cast<GEPOperator>(&GEP), *Src)) return nullptr; @@ -1630,7 +1676,6 @@ Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) { if (EndsWithSequential) { // Replace: gep (gep %P, long B), long A, ... // With: T = long A+B; gep %P, T, ... - // Value *SO1 = Src->getOperand(Src->getNumOperands()-1); Value *GO1 = GEP.getOperand(1); @@ -2053,8 +2098,6 @@ static bool isAllocSiteRemovable(Instruction *AI, return false; LLVM_FALLTHROUGH; } - case Intrinsic::dbg_declare: - case Intrinsic::dbg_value: case Intrinsic::invariant_start: case Intrinsic::invariant_end: case Intrinsic::lifetime_start: @@ -2090,6 +2133,16 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { // to null and free calls, delete the calls and replace the comparisons with // true or false as appropriate. SmallVector<WeakTrackingVH, 64> Users; + + // If we are removing an alloca with a dbg.declare, insert dbg.value calls + // before each store. + TinyPtrVector<DbgInfoIntrinsic *> DIIs; + std::unique_ptr<DIBuilder> DIB; + if (isa<AllocaInst>(MI)) { + DIIs = FindDbgAddrUses(&MI); + DIB.reset(new DIBuilder(*MI.getModule(), /*AllowUnresolved=*/false)); + } + if (isAllocSiteRemovable(&MI, Users, &TLI)) { for (unsigned i = 0, e = Users.size(); i != e; ++i) { // Lowering all @llvm.objectsize calls first because they may @@ -2122,6 +2175,9 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { } else if (isa<BitCastInst>(I) || isa<GetElementPtrInst>(I) || isa<AddrSpaceCastInst>(I)) { replaceInstUsesWith(*I, UndefValue::get(I->getType())); + } else if (auto *SI = dyn_cast<StoreInst>(I)) { + for (auto *DII : DIIs) + ConvertDebugDeclareToDebugValue(DII, SI, *DIB); } eraseInstFromFunction(*I); } @@ -2133,6 +2189,10 @@ Instruction *InstCombiner::visitAllocSite(Instruction &MI) { InvokeInst::Create(F, II->getNormalDest(), II->getUnwindDest(), None, "", II->getParent()); } + + for (auto *DII : DIIs) + eraseInstFromFunction(*DII); + return eraseInstFromFunction(MI); } return nullptr; @@ -2195,7 +2255,6 @@ tryToMoveFreeBeforeNullTest(CallInst &FI) { return &FI; } - Instruction *InstCombiner::visitFree(CallInst &FI) { Value *Op = FI.getArgOperand(0); @@ -2258,10 +2317,9 @@ Instruction *InstCombiner::visitBranchInst(BranchInst &BI) { // If the condition is irrelevant, remove the use so that other // transforms on the condition become more effective. 
- if (BI.isConditional() && - BI.getSuccessor(0) == BI.getSuccessor(1) && - !isa<UndefValue>(BI.getCondition())) { - BI.setCondition(UndefValue::get(BI.getCondition()->getType())); + if (BI.isConditional() && !isa<ConstantInt>(BI.getCondition()) && + BI.getSuccessor(0) == BI.getSuccessor(1)) { + BI.setCondition(ConstantInt::getFalse(BI.getCondition()->getType())); return &BI; } @@ -2881,6 +2939,9 @@ bool InstCombiner::run() { continue; } + if (!DebugCounter::shouldExecute(VisitCounter)) + continue; + // Instruction isn't dead, see if we can constant propagate it. if (!I->use_empty() && (I->getNumOperands() == 0 || isa<Constant>(I->getOperand(0)))) { @@ -3027,7 +3088,6 @@ bool InstCombiner::run() { /// them to the worklist (this significantly speeds up instcombine on code where /// many instructions are dead or constant). Additionally, if we find a branch /// whose condition is a known constant, we only visit the reachable successors. -/// static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, SmallPtrSetImpl<BasicBlock *> &Visited, InstCombineWorklist &ICWorklist, @@ -3053,6 +3113,7 @@ static bool AddReachableCodeToWorklist(BasicBlock *BB, const DataLayout &DL, if (isInstructionTriviallyDead(Inst, TLI)) { ++NumDeadInst; DEBUG(dbgs() << "IC: DCE: " << *Inst << '\n'); + salvageDebugInfo(*Inst); Inst->eraseFromParent(); MadeIRChange = true; continue; @@ -3162,12 +3223,11 @@ static bool prepareICWorklistFromFunction(Function &F, const DataLayout &DL, return MadeIRChange; } -static bool -combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, - AliasAnalysis *AA, AssumptionCache &AC, - TargetLibraryInfo &TLI, DominatorTree &DT, - bool ExpensiveCombines = true, - LoopInfo *LI = nullptr) { +static bool combineInstructionsOverFunction( + Function &F, InstCombineWorklist &Worklist, AliasAnalysis *AA, + AssumptionCache &AC, TargetLibraryInfo &TLI, DominatorTree &DT, + OptimizationRemarkEmitter &ORE, bool ExpensiveCombines = true, + LoopInfo *LI = nullptr) { auto &DL = F.getParent()->getDataLayout(); ExpensiveCombines |= EnableExpensiveCombines; @@ -3177,27 +3237,27 @@ combineInstructionsOverFunction(Function &F, InstCombineWorklist &Worklist, F.getContext(), TargetFolder(DL), IRBuilderCallbackInserter([&Worklist, &AC](Instruction *I) { Worklist.Add(I); - - using namespace llvm::PatternMatch; if (match(I, m_Intrinsic<Intrinsic::assume>())) AC.registerAssumption(cast<CallInst>(I)); })); // Lower dbg.declare intrinsics otherwise their value may be clobbered // by instcombiner. - bool MadeIRChange = LowerDbgDeclare(F); + bool MadeIRChange = false; + if (ShouldLowerDbgDeclare) + MadeIRChange = LowerDbgDeclare(F); // Iterate while there is work to do. 
int Iteration = 0; - for (;;) { + while (true) { ++Iteration; DEBUG(dbgs() << "\n\nINSTCOMBINE ITERATION #" << Iteration << " on " << F.getName() << "\n"); MadeIRChange |= prepareICWorklistFromFunction(F, DL, &TLI, Worklist); - InstCombiner IC(Worklist, Builder, F.optForMinSize(), ExpensiveCombines, - AA, AC, TLI, DT, DL, LI); + InstCombiner IC(Worklist, Builder, F.optForMinSize(), ExpensiveCombines, AA, + AC, TLI, DT, ORE, DL, LI); IC.MaxArraySizeForCombine = MaxArraySize; if (!IC.run()) @@ -3212,11 +3272,12 @@ PreservedAnalyses InstCombinePass::run(Function &F, auto &AC = AM.getResult<AssumptionAnalysis>(F); auto &DT = AM.getResult<DominatorTreeAnalysis>(F); auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); auto *LI = AM.getCachedResult<LoopAnalysis>(F); - // FIXME: The AliasAnalysis is not yet supported in the new pass manager - if (!combineInstructionsOverFunction(F, Worklist, nullptr, AC, TLI, DT, + auto *AA = &AM.getResult<AAManager>(F); + if (!combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, ExpensiveCombines, LI)) // No changes, all analyses are preserved. return PreservedAnalyses::all(); @@ -3225,6 +3286,7 @@ PreservedAnalyses InstCombinePass::run(Function &F, PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); PA.preserve<AAManager>(); + PA.preserve<BasicAA>(); PA.preserve<GlobalsAA>(); return PA; } @@ -3235,6 +3297,7 @@ void InstructionCombiningPass::getAnalysisUsage(AnalysisUsage &AU) const { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<AAResultsWrapperPass>(); AU.addPreserved<BasicAAWrapperPass>(); @@ -3250,16 +3313,18 @@ bool InstructionCombiningPass::runOnFunction(Function &F) { auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); // Optional analyses. auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); auto *LI = LIWP ? 
&LIWP->getLoopInfo() : nullptr; - return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, + return combineInstructionsOverFunction(F, Worklist, AA, AC, TLI, DT, ORE, ExpensiveCombines, LI); } char InstructionCombiningPass::ID = 0; + INITIALIZE_PASS_BEGIN(InstructionCombiningPass, "instcombine", "Combine redundant instructions", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) @@ -3267,6 +3332,7 @@ INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(InstructionCombiningPass, "instcombine", "Combine redundant instructions", false, false) diff --git a/lib/Transforms/Instrumentation/AddressSanitizer.cpp b/lib/Transforms/Instrumentation/AddressSanitizer.cpp index f8d255273b2a..8328d4031941 100644 --- a/lib/Transforms/Instrumentation/AddressSanitizer.cpp +++ b/lib/Transforms/Instrumentation/AddressSanitizer.cpp @@ -1,4 +1,4 @@ -//===-- AddressSanitizer.cpp - memory error detector ------------*- C++ -*-===// +//===- AddressSanitizer.cpp - memory error detector -----------------------===// // // The LLVM Compiler Infrastructure // @@ -9,59 +9,81 @@ // // This file is a part of AddressSanitizer, an address sanity checker. // Details of the algorithm: -// http://code.google.com/p/address-sanitizer/wiki/AddressSanitizerAlgorithm +// https://github.com/google/sanitizers/wiki/AddressSanitizerAlgorithm // //===----------------------------------------------------------------------===// #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" -#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/BinaryFormat/MachO.h" #include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/Comdat.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DIBuilder.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" #include "llvm/MC/MCSectionMachO.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/DataTypes.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/Endian.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" #include 
"llvm/Support/ScopedPrinter.h" -#include "llvm/Support/SwapByteOrder.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/ASanStackFrameLayout.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> #include <iomanip> #include <limits> +#include <memory> #include <sstream> #include <string> -#include <system_error> +#include <tuple> using namespace llvm; @@ -70,21 +92,25 @@ using namespace llvm; static const uint64_t kDefaultShadowScale = 3; static const uint64_t kDefaultShadowOffset32 = 1ULL << 29; static const uint64_t kDefaultShadowOffset64 = 1ULL << 44; -static const uint64_t kDynamicShadowSentinel = ~(uint64_t)0; +static const uint64_t kDynamicShadowSentinel = + std::numeric_limits<uint64_t>::max(); static const uint64_t kIOSShadowOffset32 = 1ULL << 30; static const uint64_t kIOSSimShadowOffset32 = 1ULL << 30; static const uint64_t kIOSSimShadowOffset64 = kDefaultShadowOffset64; -static const uint64_t kSmallX86_64ShadowOffset = 0x7FFF8000; // < 2G. +static const uint64_t kSmallX86_64ShadowOffsetBase = 0x7FFFFFFF; // < 2G. +static const uint64_t kSmallX86_64ShadowOffsetAlignMask = ~0xFFFULL; static const uint64_t kLinuxKasan_ShadowOffset64 = 0xdffffc0000000000; -static const uint64_t kPPC64_ShadowOffset64 = 1ULL << 41; +static const uint64_t kPPC64_ShadowOffset64 = 1ULL << 44; static const uint64_t kSystemZ_ShadowOffset64 = 1ULL << 52; static const uint64_t kMIPS32_ShadowOffset32 = 0x0aaa0000; static const uint64_t kMIPS64_ShadowOffset64 = 1ULL << 37; static const uint64_t kAArch64_ShadowOffset64 = 1ULL << 36; static const uint64_t kFreeBSD_ShadowOffset32 = 1ULL << 30; static const uint64_t kFreeBSD_ShadowOffset64 = 1ULL << 46; +static const uint64_t kNetBSD_ShadowOffset64 = 1ULL << 46; static const uint64_t kPS4CPU_ShadowOffset64 = 1ULL << 40; static const uint64_t kWindowsShadowOffset32 = 3ULL << 28; + // The shadow memory space is dynamically allocated. static const uint64_t kWindowsShadowOffset64 = kDynamicShadowSentinel; @@ -111,8 +137,8 @@ static const char *const kAsanUnregisterElfGlobalsName = static const char *const kAsanPoisonGlobalsName = "__asan_before_dynamic_init"; static const char *const kAsanUnpoisonGlobalsName = "__asan_after_dynamic_init"; static const char *const kAsanInitName = "__asan_init"; -static const char *const kAsanVersionCheckName = - "__asan_version_mismatch_check_v8"; +static const char *const kAsanVersionCheckNamePrefix = + "__asan_version_mismatch_check_v"; static const char *const kAsanPtrCmp = "__sanitizer_ptr_cmp"; static const char *const kAsanPtrSub = "__sanitizer_ptr_sub"; static const char *const kAsanHandleNoReturnName = "__asan_handle_no_return"; @@ -148,9 +174,11 @@ static const size_t kNumberOfAccessSizes = 5; static const unsigned kAllocaRzSize = 32; // Command-line flags. 
+ static cl::opt<bool> ClEnableKasan( "asan-kernel", cl::desc("Enable KernelAddressSanitizer instrumentation"), cl::Hidden, cl::init(false)); + static cl::opt<bool> ClRecover( "asan-recover", cl::desc("Enable recovery mode (continue-after-error)."), @@ -160,22 +188,38 @@ static cl::opt<bool> ClRecover( static cl::opt<bool> ClInstrumentReads("asan-instrument-reads", cl::desc("instrument read instructions"), cl::Hidden, cl::init(true)); + static cl::opt<bool> ClInstrumentWrites( "asan-instrument-writes", cl::desc("instrument write instructions"), cl::Hidden, cl::init(true)); + static cl::opt<bool> ClInstrumentAtomics( "asan-instrument-atomics", cl::desc("instrument atomic instructions (rmw, cmpxchg)"), cl::Hidden, cl::init(true)); + static cl::opt<bool> ClAlwaysSlowPath( "asan-always-slow-path", cl::desc("use instrumentation with slow path for all accesses"), cl::Hidden, cl::init(false)); + static cl::opt<bool> ClForceDynamicShadow( "asan-force-dynamic-shadow", cl::desc("Load shadow address into a local variable for each function"), cl::Hidden, cl::init(false)); +static cl::opt<bool> + ClWithIfunc("asan-with-ifunc", + cl::desc("Access dynamic shadow through an ifunc global on " + "platforms that support this"), + cl::Hidden, cl::init(true)); + +static cl::opt<bool> ClWithIfuncSuppressRemat( + "asan-with-ifunc-suppress-remat", + cl::desc("Suppress rematerialization of dynamic shadow address by passing " + "it through inline asm in prologue."), + cl::Hidden, cl::init(true)); + // This flag limits the number of instructions to be instrumented // in any given BB. Normally, this should be set to unlimited (INT_MAX), // but due to http://llvm.org/bugs/show_bug.cgi?id=12652 we temporary @@ -184,6 +228,7 @@ static cl::opt<int> ClMaxInsnsToInstrumentPerBB( "asan-max-ins-per-bb", cl::init(10000), cl::desc("maximal number of instructions to instrument in any given BB"), cl::Hidden); + // This flag may need to be replaced with -f[no]asan-stack. static cl::opt<bool> ClStack("asan-stack", cl::desc("Handle stack memory"), cl::Hidden, cl::init(true)); @@ -192,32 +237,40 @@ static cl::opt<uint32_t> ClMaxInlinePoisoningSize( cl::desc( "Inline shadow poisoning for blocks up to the given size in bytes."), cl::Hidden, cl::init(64)); + static cl::opt<bool> ClUseAfterReturn("asan-use-after-return", cl::desc("Check stack-use-after-return"), cl::Hidden, cl::init(true)); + static cl::opt<bool> ClRedzoneByvalArgs("asan-redzone-byval-args", cl::desc("Create redzones for byval " "arguments (extra copy " "required)"), cl::Hidden, cl::init(true)); + static cl::opt<bool> ClUseAfterScope("asan-use-after-scope", cl::desc("Check stack-use-after-scope"), cl::Hidden, cl::init(false)); + // This flag may need to be replaced with -f[no]asan-globals. 
static cl::opt<bool> ClGlobals("asan-globals", cl::desc("Handle global objects"), cl::Hidden, cl::init(true)); + static cl::opt<bool> ClInitializers("asan-initialization-order", cl::desc("Handle C++ initializer order"), cl::Hidden, cl::init(true)); + static cl::opt<bool> ClInvalidPointerPairs( "asan-detect-invalid-pointer-pair", cl::desc("Instrument <, <=, >, >=, - with pointer operands"), cl::Hidden, cl::init(false)); + static cl::opt<unsigned> ClRealignStack( "asan-realign-stack", cl::desc("Realign stack to the value of this flag (power of two)"), cl::Hidden, cl::init(32)); + static cl::opt<int> ClInstrumentationWithCallsThreshold( "asan-instrumentation-with-call-threshold", cl::desc( @@ -225,14 +278,17 @@ static cl::opt<int> ClInstrumentationWithCallsThreshold( "this number of memory accesses, use callbacks instead of " "inline checks (-1 means never use callbacks)."), cl::Hidden, cl::init(7000)); + static cl::opt<std::string> ClMemoryAccessCallbackPrefix( "asan-memory-access-callback-prefix", cl::desc("Prefix for memory access callbacks"), cl::Hidden, cl::init("__asan_")); + static cl::opt<bool> ClInstrumentDynamicAllocas("asan-instrument-dynamic-allocas", cl::desc("instrument dynamic allocas"), cl::Hidden, cl::init(true)); + static cl::opt<bool> ClSkipPromotableAllocas( "asan-skip-promotable-allocas", cl::desc("Do not instrument promotable allocas"), cl::Hidden, @@ -241,9 +297,11 @@ static cl::opt<bool> ClSkipPromotableAllocas( // These flags allow to change the shadow mapping. // The shadow mapping looks like // Shadow = (Mem >> scale) + offset + static cl::opt<int> ClMappingScale("asan-mapping-scale", cl::desc("scale of asan shadow mapping"), cl::Hidden, cl::init(0)); + static cl::opt<unsigned long long> ClMappingOffset( "asan-mapping-offset", cl::desc("offset of asan shadow mapping [EXPERIMENTAL]"), cl::Hidden, @@ -251,14 +309,18 @@ static cl::opt<unsigned long long> ClMappingOffset( // Optimization flags. Not user visible, used mostly for testing // and benchmarking the tool. + static cl::opt<bool> ClOpt("asan-opt", cl::desc("Optimize instrumentation"), cl::Hidden, cl::init(true)); + static cl::opt<bool> ClOptSameTemp( "asan-opt-same-temp", cl::desc("Instrument the same temp just once"), cl::Hidden, cl::init(true)); + static cl::opt<bool> ClOptGlobals("asan-opt-globals", cl::desc("Don't instrument scalar globals"), cl::Hidden, cl::init(true)); + static cl::opt<bool> ClOptStack( "asan-opt-stack", cl::desc("Don't instrument scalar stack variables"), cl::Hidden, cl::init(false)); @@ -293,14 +355,19 @@ static cl::opt<bool> cl::Hidden, cl::init(true)); // Debug flags. + static cl::opt<int> ClDebug("asan-debug", cl::desc("debug"), cl::Hidden, cl::init(0)); + static cl::opt<int> ClDebugStack("asan-debug-stack", cl::desc("debug stack"), cl::Hidden, cl::init(0)); + static cl::opt<std::string> ClDebugFunc("asan-debug-func", cl::Hidden, cl::desc("Debug func")); + static cl::opt<int> ClDebugMin("asan-debug-min", cl::desc("Debug min inst"), cl::Hidden, cl::init(-1)); + static cl::opt<int> ClDebugMax("asan-debug-max", cl::desc("Debug max inst"), cl::Hidden, cl::init(-1)); @@ -312,13 +379,14 @@ STATISTIC(NumOptimizedAccessesToStackVar, "Number of optimized accesses to stack vars"); namespace { + /// Frontend-provided metadata for source location. 
struct LocationMetadata { StringRef Filename; - int LineNo; - int ColumnNo; + int LineNo = 0; + int ColumnNo = 0; - LocationMetadata() : Filename(), LineNo(0), ColumnNo(0) {} + LocationMetadata() = default; bool empty() const { return Filename.empty(); } @@ -335,16 +403,17 @@ struct LocationMetadata { /// Frontend-provided metadata for global variables. class GlobalsMetadata { - public: +public: struct Entry { - Entry() : SourceLoc(), Name(), IsDynInit(false), IsBlacklisted(false) {} LocationMetadata SourceLoc; StringRef Name; - bool IsDynInit; - bool IsBlacklisted; + bool IsDynInit = false; + bool IsBlacklisted = false; + + Entry() = default; }; - GlobalsMetadata() : inited_(false) {} + GlobalsMetadata() = default; void reset() { inited_ = false; @@ -384,46 +453,57 @@ class GlobalsMetadata { return (Pos != Entries.end()) ? Pos->second : Entry(); } - private: - bool inited_; +private: + bool inited_ = false; DenseMap<GlobalVariable *, Entry> Entries; }; /// This struct defines the shadow mapping using the rule: /// shadow = (mem >> Scale) ADD-or-OR Offset. +/// If InGlobal is true, then +/// extern char __asan_shadow[]; +/// shadow = (mem >> Scale) + &__asan_shadow struct ShadowMapping { int Scale; uint64_t Offset; bool OrShadowOffset; + bool InGlobal; }; +} // end anonymous namespace + static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, bool IsKasan) { bool IsAndroid = TargetTriple.isAndroid(); bool IsIOS = TargetTriple.isiOS() || TargetTriple.isWatchOS(); bool IsFreeBSD = TargetTriple.isOSFreeBSD(); + bool IsNetBSD = TargetTriple.isOSNetBSD(); bool IsPS4CPU = TargetTriple.isPS4CPU(); bool IsLinux = TargetTriple.isOSLinux(); - bool IsPPC64 = TargetTriple.getArch() == llvm::Triple::ppc64 || - TargetTriple.getArch() == llvm::Triple::ppc64le; - bool IsSystemZ = TargetTriple.getArch() == llvm::Triple::systemz; - bool IsX86 = TargetTriple.getArch() == llvm::Triple::x86; - bool IsX86_64 = TargetTriple.getArch() == llvm::Triple::x86_64; - bool IsMIPS32 = TargetTriple.getArch() == llvm::Triple::mips || - TargetTriple.getArch() == llvm::Triple::mipsel; - bool IsMIPS64 = TargetTriple.getArch() == llvm::Triple::mips64 || - TargetTriple.getArch() == llvm::Triple::mips64el; - bool IsAArch64 = TargetTriple.getArch() == llvm::Triple::aarch64; + bool IsPPC64 = TargetTriple.getArch() == Triple::ppc64 || + TargetTriple.getArch() == Triple::ppc64le; + bool IsSystemZ = TargetTriple.getArch() == Triple::systemz; + bool IsX86 = TargetTriple.getArch() == Triple::x86; + bool IsX86_64 = TargetTriple.getArch() == Triple::x86_64; + bool IsMIPS32 = TargetTriple.getArch() == Triple::mips || + TargetTriple.getArch() == Triple::mipsel; + bool IsMIPS64 = TargetTriple.getArch() == Triple::mips64 || + TargetTriple.getArch() == Triple::mips64el; + bool IsArmOrThumb = TargetTriple.isARM() || TargetTriple.isThumb(); + bool IsAArch64 = TargetTriple.getArch() == Triple::aarch64; bool IsWindows = TargetTriple.isOSWindows(); bool IsFuchsia = TargetTriple.isOSFuchsia(); ShadowMapping Mapping; + Mapping.Scale = kDefaultShadowScale; + if (ClMappingScale.getNumOccurrences() > 0) { + Mapping.Scale = ClMappingScale; + } + if (LongSize == 32) { - // Android is always PIE, which means that the beginning of the address - // space is always available. 
if (IsAndroid) - Mapping.Offset = 0; + Mapping.Offset = kDynamicShadowSentinel; else if (IsMIPS32) Mapping.Offset = kMIPS32_ShadowOffset32; else if (IsFreeBSD) @@ -446,13 +526,16 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, Mapping.Offset = kSystemZ_ShadowOffset64; else if (IsFreeBSD) Mapping.Offset = kFreeBSD_ShadowOffset64; + else if (IsNetBSD) + Mapping.Offset = kNetBSD_ShadowOffset64; else if (IsPS4CPU) Mapping.Offset = kPS4CPU_ShadowOffset64; else if (IsLinux && IsX86_64) { if (IsKasan) Mapping.Offset = kLinuxKasan_ShadowOffset64; else - Mapping.Offset = kSmallX86_64ShadowOffset; + Mapping.Offset = (kSmallX86_64ShadowOffsetBase & + (kSmallX86_64ShadowOffsetAlignMask << Mapping.Scale)); } else if (IsWindows && IsX86_64) { Mapping.Offset = kWindowsShadowOffset64; } else if (IsMIPS64) @@ -472,11 +555,6 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, Mapping.Offset = kDynamicShadowSentinel; } - Mapping.Scale = kDefaultShadowScale; - if (ClMappingScale.getNumOccurrences() > 0) { - Mapping.Scale = ClMappingScale; - } - if (ClMappingOffset.getNumOccurrences() > 0) { Mapping.Offset = ClMappingOffset; } @@ -489,6 +567,9 @@ static ShadowMapping getShadowMapping(Triple &TargetTriple, int LongSize, Mapping.OrShadowOffset = !IsAArch64 && !IsPPC64 && !IsSystemZ && !IsPS4CPU && !(Mapping.Offset & (Mapping.Offset - 1)) && Mapping.Offset != kDynamicShadowSentinel; + bool IsAndroidWithIfuncSupport = + IsAndroid && !TargetTriple.isAndroidVersionLT(21); + Mapping.InGlobal = ClWithIfunc && IsAndroidWithIfuncSupport && IsArmOrThumb; return Mapping; } @@ -499,23 +580,30 @@ static size_t RedzoneSizeForScale(int MappingScale) { return std::max(32U, 1U << MappingScale); } +namespace { + /// AddressSanitizer: instrument the code in module to find memory bugs. struct AddressSanitizer : public FunctionPass { + // Pass identification, replacement for typeid + static char ID; + explicit AddressSanitizer(bool CompileKernel = false, bool Recover = false, bool UseAfterScope = false) : FunctionPass(ID), CompileKernel(CompileKernel || ClEnableKasan), Recover(Recover || ClRecover), - UseAfterScope(UseAfterScope || ClUseAfterScope), - LocalDynamicShadow(nullptr) { + UseAfterScope(UseAfterScope || ClUseAfterScope) { initializeAddressSanitizerPass(*PassRegistry::getPassRegistry()); } + StringRef getPassName() const override { return "AddressSanitizerFunctionPass"; } + void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetLibraryInfoWrapperPass>(); } + uint64_t getAllocaSizeInBytes(const AllocaInst &AI) const { uint64_t ArraySize = 1; if (AI.isArrayAllocation()) { @@ -528,6 +616,7 @@ struct AddressSanitizer : public FunctionPass { AI.getModule()->getDataLayout().getTypeAllocSize(Ty); return SizeInBytes * ArraySize; } + /// Check if we want (and can) handle this alloca. 
bool isInterestingAlloca(const AllocaInst &AI);
@@ -538,6 +627,7 @@ struct AddressSanitizer : public FunctionPass {
  Value *isInterestingMemoryAccess(Instruction *I, bool *IsWrite,
                                   uint64_t *TypeSize, unsigned *Alignment,
                                   Value **MaybeMask = nullptr);
+
  void instrumentMop(ObjectSizeOffsetVisitor &ObjSizeVis, Instruction *I,
                     bool UseCalls, const DataLayout &DL);
  void instrumentPointerComparisonOrSubtraction(Instruction *I);
@@ -562,11 +652,12 @@ struct AddressSanitizer : public FunctionPass {
  void markEscapedLocalAllocas(Function &F);
  bool doInitialization(Module &M) override;
  bool doFinalization(Module &M) override;
-  static char ID; // Pass identification, replacement for typeid
  DominatorTree &getDominatorTree() const { return *DT; }
- private:
+private:
+  friend struct FunctionStackPoisoner;
+
  void initializeCallbacks(Module &M);
  bool LooksLikeCodeInBug11395(Instruction *I);
@@ -577,11 +668,13 @@ struct AddressSanitizer : public FunctionPass {
  /// Helper to cleanup per-function state.
  struct FunctionStateRAII {
    AddressSanitizer *Pass;
+
    FunctionStateRAII(AddressSanitizer *Pass) : Pass(Pass) {
      assert(Pass->ProcessedAllocas.empty() &&
             "last pass forgot to clear cache");
      assert(!Pass->LocalDynamicShadow);
    }
+
    ~FunctionStateRAII() {
      Pass->LocalDynamicShadow = nullptr;
      Pass->ProcessedAllocas.clear();
@@ -599,23 +692,28 @@ struct AddressSanitizer : public FunctionPass {
  DominatorTree *DT;
  Function *AsanHandleNoReturnFunc;
  Function *AsanPtrCmpFunction, *AsanPtrSubFunction;
-  // This array is indexed by AccessIsWrite, Experiment and log2(AccessSize).
+  Constant *AsanShadowGlobal;
+
+  // These arrays are indexed by AccessIsWrite, Experiment and log2(AccessSize).
  Function *AsanErrorCallback[2][2][kNumberOfAccessSizes];
  Function *AsanMemoryAccessCallback[2][2][kNumberOfAccessSizes];
-  // This array is indexed by AccessIsWrite and Experiment.
+
+  // These arrays are indexed by AccessIsWrite and Experiment.
  Function *AsanErrorCallbackSized[2][2];
  Function *AsanMemoryAccessCallbackSized[2][2];
+
  Function *AsanMemmove, *AsanMemcpy, *AsanMemset;
  InlineAsm *EmptyAsm;
-  Value *LocalDynamicShadow;
+  Value *LocalDynamicShadow = nullptr;
  GlobalsMetadata GlobalsMD;
  DenseMap<const AllocaInst *, bool> ProcessedAllocas;
-
-  friend struct FunctionStackPoisoner;
};
class AddressSanitizerModule : public ModulePass {
public:
+  // Pass identification, replacement for typeid
+  static char ID;
+
  explicit AddressSanitizerModule(bool CompileKernel = false,
                                  bool Recover = false,
                                  bool UseGlobalsGC = true)
@@ -630,8 +728,8 @@ public:
        // ClWithComdat and ClUseGlobalsGC unless the frontend says it's ok to
        // do globals-gc.
        UseCtorComdat(UseGlobalsGC && ClWithComdat) {}
+
  bool runOnModule(Module &M) override;
-  static char ID; // Pass identification, replacement for typeid
  StringRef getPassName() const override { return "AddressSanitizerModule"; }
private:
@@ -667,6 +765,7 @@ private:
  size_t MinRedzoneSizeForGlobal() const {
    return RedzoneSizeForScale(Mapping.Scale);
  }
+  int GetAsanVersion(const Module &M) const;
  GlobalsMetadata GlobalsMD;
  bool CompileKernel;
@@ -735,7 +834,7 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> {
  IntrinsicInst *LocalEscapeCall = nullptr;
  // Maps Value to an AllocaInst from which the Value is originated.
- typedef DenseMap<Value *, AllocaInst *> AllocaForValueMapTy; + using AllocaForValueMapTy = DenseMap<Value *, AllocaInst *>; AllocaForValueMapTy AllocaForValue; bool HasNonEmptyInlineAsm = false; @@ -756,7 +855,7 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { bool runOnFunction() { if (!ClStack) return false; - if (ClRedzoneByvalArgs && Mapping.Offset != kDynamicShadowSentinel) + if (ClRedzoneByvalArgs) copyArgsPassedByValToAllocas(); // Collect alloca, ret, lifetime instructions etc. @@ -899,8 +998,9 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { void visitCallSite(CallSite CS) { Instruction *I = CS.getInstruction(); if (CallInst *CI = dyn_cast<CallInst>(I)) { - HasNonEmptyInlineAsm |= - CI->isInlineAsm() && !CI->isIdenticalTo(EmptyInlineAsm.get()); + HasNonEmptyInlineAsm |= CI->isInlineAsm() && + !CI->isIdenticalTo(EmptyInlineAsm.get()) && + I != ASan.LocalDynamicShadow; HasReturnsTwiceCall |= CI->canReturnTwice(); } } @@ -938,9 +1038,10 @@ struct FunctionStackPoisoner : public InstVisitor<FunctionStackPoisoner> { Instruction *ThenTerm, Value *ValueIfFalse); }; -} // anonymous namespace +} // end anonymous namespace char AddressSanitizer::ID = 0; + INITIALIZE_PASS_BEGIN( AddressSanitizer, "asan", "AddressSanitizer: detects use-after-free and out-of-bounds bugs.", false, @@ -951,6 +1052,7 @@ INITIALIZE_PASS_END( AddressSanitizer, "asan", "AddressSanitizer: detects use-after-free and out-of-bounds bugs.", false, false) + FunctionPass *llvm::createAddressSanitizerFunctionPass(bool CompileKernel, bool Recover, bool UseAfterScope) { @@ -959,11 +1061,13 @@ FunctionPass *llvm::createAddressSanitizerFunctionPass(bool CompileKernel, } char AddressSanitizerModule::ID = 0; + INITIALIZE_PASS( AddressSanitizerModule, "asan-module", "AddressSanitizer: detects use-after-free and out-of-bounds bugs." "ModulePass", false, false) + ModulePass *llvm::createAddressSanitizerModulePass(bool CompileKernel, bool Recover, bool UseGlobalsGC) { @@ -1544,7 +1648,7 @@ bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { // Callbacks put into the CRT initializer/terminator sections // should not be instrumented. - // See https://code.google.com/p/address-sanitizer/issues/detail?id=305 + // See https://github.com/google/sanitizers/issues/305 // and http://msdn.microsoft.com/en-US/en-en/library/bb918180(v=vs.120).aspx if (Section.startswith(".CRT")) { DEBUG(dbgs() << "Ignoring a global initializer callback: " << *G << "\n"); @@ -1567,7 +1671,7 @@ bool AddressSanitizerModule::ShouldInstrumentGlobal(GlobalVariable *G) { DEBUG(dbgs() << "Ignoring ObjC runtime global: " << *G << "\n"); return false; } - // See http://code.google.com/p/address-sanitizer/issues/detail?id=32 + // See https://github.com/google/sanitizers/issues/32 // Constant CFString instances are compiled in the following way: // -- the string buffer is emitted into // __TEXT,__cstring,cstring_literals @@ -1683,9 +1787,14 @@ void AddressSanitizerModule::SetComdatForGlobalMetadata( C = M.getOrInsertComdat(G->getName()); } - // Make this IMAGE_COMDAT_SELECT_NODUPLICATES on COFF. - if (TargetTriple.isOSBinFormatCOFF()) + // Make this IMAGE_COMDAT_SELECT_NODUPLICATES on COFF. Also upgrade private + // linkage to internal linkage so that a symbol table entry is emitted. This + // is necessary in order to create the comdat group. 
+ if (TargetTriple.isOSBinFormatCOFF()) { C->setSelectionKind(Comdat::NoDuplicates); + if (G->hasPrivateLinkage()) + G->setLinkage(GlobalValue::InternalLinkage); + } G->setComdat(C); } @@ -1871,6 +1980,8 @@ void AddressSanitizerModule::InstrumentGlobalsWithMetadataArray( auto AllGlobals = new GlobalVariable( M, ArrayOfGlobalStructTy, false, GlobalVariable::InternalLinkage, ConstantArray::get(ArrayOfGlobalStructTy, MetadataInitializers), ""); + if (Mapping.Scale > 3) + AllGlobals->setAlignment(1ULL << Mapping.Scale); IRB.CreateCall(AsanRegisterGlobals, {IRB.CreatePointerCast(AllGlobals, IntptrTy), @@ -2070,6 +2181,16 @@ bool AddressSanitizerModule::InstrumentGlobals(IRBuilder<> &IRB, Module &M, bool return true; } +int AddressSanitizerModule::GetAsanVersion(const Module &M) const { + int LongSize = M.getDataLayout().getPointerSizeInBits(); + bool isAndroid = Triple(M.getTargetTriple()).isAndroid(); + int Version = 8; + // 32-bit Android is one version ahead because of the switch to dynamic + // shadow. + Version += (LongSize == 32 && isAndroid); + return Version; +} + bool AddressSanitizerModule::runOnModule(Module &M) { C = &(M.getContext()); int LongSize = M.getDataLayout().getPointerSizeInBits(); @@ -2083,9 +2204,11 @@ bool AddressSanitizerModule::runOnModule(Module &M) { // Create a module constructor. A destructor is created lazily because not all // platforms, and not all modules need it. + std::string VersionCheckName = + kAsanVersionCheckNamePrefix + std::to_string(GetAsanVersion(M)); std::tie(AsanCtorFunction, std::ignore) = createSanitizerCtorAndInitFunctions( M, kAsanModuleCtorName, kAsanInitName, /*InitArgTypes=*/{}, - /*InitArgs=*/{}, kAsanVersionCheckName); + /*InitArgs=*/{}, VersionCheckName); bool CtorComdat = true; bool Changed = false; @@ -2134,31 +2257,31 @@ void AddressSanitizer::initializeCallbacks(Module &M) { Args2.push_back(ExpType); Args1.push_back(ExpType); } - AsanErrorCallbackSized[AccessIsWrite][Exp] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanReportErrorTemplate + ExpStr + TypeStr + SuffixStr + - EndingStr, - FunctionType::get(IRB.getVoidTy(), Args2, false))); - - AsanMemoryAccessCallbackSized[AccessIsWrite][Exp] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - ClMemoryAccessCallbackPrefix + ExpStr + TypeStr + "N" + EndingStr, - FunctionType::get(IRB.getVoidTy(), Args2, false))); - - for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; - AccessSizeIndex++) { - const std::string Suffix = TypeStr + itostr(1ULL << AccessSizeIndex); - AsanErrorCallback[AccessIsWrite][Exp][AccessSizeIndex] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - kAsanReportErrorTemplate + ExpStr + Suffix + EndingStr, - FunctionType::get(IRB.getVoidTy(), Args1, false))); - - AsanMemoryAccessCallback[AccessIsWrite][Exp][AccessSizeIndex] = - checkSanitizerInterfaceFunction(M.getOrInsertFunction( - ClMemoryAccessCallbackPrefix + ExpStr + Suffix + EndingStr, - FunctionType::get(IRB.getVoidTy(), Args1, false))); - } - } + AsanErrorCallbackSized[AccessIsWrite][Exp] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + kAsanReportErrorTemplate + ExpStr + TypeStr + SuffixStr + + EndingStr, + FunctionType::get(IRB.getVoidTy(), Args2, false))); + + AsanMemoryAccessCallbackSized[AccessIsWrite][Exp] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + ClMemoryAccessCallbackPrefix + ExpStr + TypeStr + "N" + EndingStr, + FunctionType::get(IRB.getVoidTy(), Args2, false))); + + for (size_t AccessSizeIndex = 0; AccessSizeIndex < 
kNumberOfAccessSizes; + AccessSizeIndex++) { + const std::string Suffix = TypeStr + itostr(1ULL << AccessSizeIndex); + AsanErrorCallback[AccessIsWrite][Exp][AccessSizeIndex] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + kAsanReportErrorTemplate + ExpStr + Suffix + EndingStr, + FunctionType::get(IRB.getVoidTy(), Args1, false))); + + AsanMemoryAccessCallback[AccessIsWrite][Exp][AccessSizeIndex] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + ClMemoryAccessCallbackPrefix + ExpStr + Suffix + EndingStr, + FunctionType::get(IRB.getVoidTy(), Args1, false))); + } + } } const std::string MemIntrinCallbackPrefix = @@ -2184,6 +2307,9 @@ void AddressSanitizer::initializeCallbacks(Module &M) { EmptyAsm = InlineAsm::get(FunctionType::get(IRB.getVoidTy(), false), StringRef(""), StringRef(""), /*hasSideEffects=*/true); + if (Mapping.InGlobal) + AsanShadowGlobal = M.getOrInsertGlobal("__asan_shadow", + ArrayType::get(IRB.getInt8Ty(), 0)); } // virtual @@ -2229,9 +2355,25 @@ void AddressSanitizer::maybeInsertDynamicShadowAtFunctionEntry(Function &F) { return; IRBuilder<> IRB(&F.front().front()); - Value *GlobalDynamicAddress = F.getParent()->getOrInsertGlobal( - kAsanShadowMemoryDynamicAddress, IntptrTy); - LocalDynamicShadow = IRB.CreateLoad(GlobalDynamicAddress); + if (Mapping.InGlobal) { + if (ClWithIfuncSuppressRemat) { + // An empty inline asm with input reg == output reg. + // An opaque pointer-to-int cast, basically. + InlineAsm *Asm = InlineAsm::get( + FunctionType::get(IntptrTy, {AsanShadowGlobal->getType()}, false), + StringRef(""), StringRef("=r,0"), + /*hasSideEffects=*/false); + LocalDynamicShadow = + IRB.CreateCall(Asm, {AsanShadowGlobal}, ".asan.shadow"); + } else { + LocalDynamicShadow = + IRB.CreatePointerCast(AsanShadowGlobal, IntptrTy, ".asan.shadow"); + } + } else { + Value *GlobalDynamicAddress = F.getParent()->getOrInsertGlobal( + kAsanShadowMemoryDynamicAddress, IntptrTy); + LocalDynamicShadow = IRB.CreateLoad(GlobalDynamicAddress); + } } void AddressSanitizer::markEscapedLocalAllocas(Function &F) { @@ -2378,7 +2520,7 @@ bool AddressSanitizer::runOnFunction(Function &F) { bool ChangedStack = FSP.runOnFunction(); // We must unpoison the stack before every NoReturn call (throw, _exit, etc). - // See e.g. http://code.google.com/p/address-sanitizer/issues/detail?id=37 + // See e.g. https://github.com/google/sanitizers/issues/37 for (auto CI : NoReturnCalls) { IRBuilder<> IRB(CI); IRB.CreateCall(AsanHandleNoReturnFunc, {}); @@ -2546,8 +2688,13 @@ static int StackMallocSizeClass(uint64_t LocalStackSize) { } void FunctionStackPoisoner::copyArgsPassedByValToAllocas() { - BasicBlock &FirstBB = *F.begin(); - IRBuilder<> IRB(&FirstBB, FirstBB.getFirstInsertionPt()); + Instruction *CopyInsertPoint = &F.front().front(); + if (CopyInsertPoint == ASan.LocalDynamicShadow) { + // Insert after the dynamic shadow location is determined + CopyInsertPoint = CopyInsertPoint->getNextNode(); + assert(CopyInsertPoint); + } + IRBuilder<> IRB(CopyInsertPoint); const DataLayout &DL = F.getParent()->getDataLayout(); for (Argument &Arg : F.args()) { if (Arg.hasByValAttr()) { @@ -2674,9 +2821,10 @@ void FunctionStackPoisoner::processStaticAllocas() { // Minimal header size (left redzone) is 4 pointers, // i.e. 32 bytes on 64-bit platforms and 16 bytes in 32-bit platforms. 
- size_t MinHeaderSize = ASan.LongSize / 2; + size_t Granularity = 1ULL << Mapping.Scale; + size_t MinHeaderSize = std::max((size_t)ASan.LongSize / 2, Granularity); const ASanStackFrameLayout &L = - ComputeASanStackFrameLayout(SVD, 1ULL << Mapping.Scale, MinHeaderSize); + ComputeASanStackFrameLayout(SVD, Granularity, MinHeaderSize); // Build AllocaToSVDMap for ASanStackVariableDescription lookup. DenseMap<const AllocaInst *, ASanStackVariableDescription *> AllocaToSVDMap; @@ -2721,8 +2869,12 @@ void FunctionStackPoisoner::processStaticAllocas() { Value *FakeStack; Value *LocalStackBase; + Value *LocalStackBaseAlloca; + bool Deref; if (DoStackMalloc) { + LocalStackBaseAlloca = + IRB.CreateAlloca(IntptrTy, nullptr, "asan_local_stack_base"); // void *FakeStack = __asan_option_detect_stack_use_after_return // ? __asan_stack_malloc_N(LocalStackSize) // : nullptr; @@ -2753,24 +2905,31 @@ void FunctionStackPoisoner::processStaticAllocas() { IRBIf.SetCurrentDebugLocation(EntryDebugLocation); Value *AllocaValue = DoDynamicAlloca ? createAllocaForLayout(IRBIf, L, true) : StaticAlloca; + IRB.SetInsertPoint(InsBefore); IRB.SetCurrentDebugLocation(EntryDebugLocation); LocalStackBase = createPHI(IRB, NoFakeStack, AllocaValue, Term, FakeStack); + IRB.SetCurrentDebugLocation(EntryDebugLocation); + IRB.CreateStore(LocalStackBase, LocalStackBaseAlloca); + Deref = true; } else { // void *FakeStack = nullptr; // void *LocalStackBase = alloca(LocalStackSize); FakeStack = ConstantInt::get(IntptrTy, 0); LocalStackBase = DoDynamicAlloca ? createAllocaForLayout(IRB, L, true) : StaticAlloca; + LocalStackBaseAlloca = LocalStackBase; + Deref = false; } // Replace Alloca instructions with base+offset. for (const auto &Desc : SVD) { AllocaInst *AI = Desc.AI; + replaceDbgDeclareForAlloca(AI, LocalStackBaseAlloca, DIB, Deref, + Desc.Offset, DIExpression::NoDeref); Value *NewAllocaPtr = IRB.CreateIntToPtr( IRB.CreateAdd(LocalStackBase, ConstantInt::get(IntptrTy, Desc.Offset)), AI->getType()); - replaceDbgDeclareForAlloca(AI, NewAllocaPtr, DIB, DIExpression::NoDeref); AI->replaceAllUsesWith(NewAllocaPtr); } diff --git a/lib/Transforms/Instrumentation/BoundsChecking.cpp b/lib/Transforms/Instrumentation/BoundsChecking.cpp index a193efe902cf..be9a22a8681b 100644 --- a/lib/Transforms/Instrumentation/BoundsChecking.cpp +++ b/lib/Transforms/Instrumentation/BoundsChecking.cpp @@ -6,25 +6,33 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// -// -// This file implements a pass that instruments the code to perform run-time -// bounds checking on loads, stores, and other memory intrinsics. 
-// -//===----------------------------------------------------------------------===// +#include "llvm/Transforms/Instrumentation/BoundsChecking.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/Twine.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/TargetFolder.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Instrumentation.h" +#include <cstdint> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "bounds-checking" @@ -36,102 +44,30 @@ STATISTIC(ChecksAdded, "Bounds checks added"); STATISTIC(ChecksSkipped, "Bounds checks skipped"); STATISTIC(ChecksUnable, "Bounds checks unable to add"); -typedef IRBuilder<TargetFolder> BuilderTy; - -namespace { - struct BoundsChecking : public FunctionPass { - static char ID; - - BoundsChecking() : FunctionPass(ID) { - initializeBoundsCheckingPass(*PassRegistry::getPassRegistry()); - } - - bool runOnFunction(Function &F) override; - - void getAnalysisUsage(AnalysisUsage &AU) const override { - AU.addRequired<TargetLibraryInfoWrapperPass>(); - } - - private: - const TargetLibraryInfo *TLI; - ObjectSizeOffsetEvaluator *ObjSizeEval; - BuilderTy *Builder; - Instruction *Inst; - BasicBlock *TrapBB; - - BasicBlock *getTrapBB(); - void emitBranchToTrap(Value *Cmp = nullptr); - bool instrument(Value *Ptr, Value *Val, const DataLayout &DL); - }; -} - -char BoundsChecking::ID = 0; -INITIALIZE_PASS(BoundsChecking, "bounds-checking", "Run-time bounds checking", - false, false) - - -/// getTrapBB - create a basic block that traps. All overflowing conditions -/// branch to this block. There's only one trap block per function. -BasicBlock *BoundsChecking::getTrapBB() { - if (TrapBB && SingleTrapBB) - return TrapBB; - - Function *Fn = Inst->getParent()->getParent(); - IRBuilder<>::InsertPointGuard Guard(*Builder); - TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn); - Builder->SetInsertPoint(TrapBB); - - llvm::Value *F = Intrinsic::getDeclaration(Fn->getParent(), Intrinsic::trap); - CallInst *TrapCall = Builder->CreateCall(F, {}); - TrapCall->setDoesNotReturn(); - TrapCall->setDoesNotThrow(); - TrapCall->setDebugLoc(Inst->getDebugLoc()); - Builder->CreateUnreachable(); - - return TrapBB; -} - - -/// emitBranchToTrap - emit a branch instruction to a trap block. -/// If Cmp is non-null, perform a jump only if its value evaluates to true. 
-void BoundsChecking::emitBranchToTrap(Value *Cmp) { - // check if the comparison is always false - ConstantInt *C = dyn_cast_or_null<ConstantInt>(Cmp); - if (C) { - ++ChecksSkipped; - if (!C->getZExtValue()) - return; - else - Cmp = nullptr; // unconditional branch - } - ++ChecksAdded; - - BasicBlock::iterator Inst = Builder->GetInsertPoint(); - BasicBlock *OldBB = Inst->getParent(); - BasicBlock *Cont = OldBB->splitBasicBlock(Inst); - OldBB->getTerminator()->eraseFromParent(); - - if (Cmp) - BranchInst::Create(getTrapBB(), Cont, Cmp, OldBB); - else - BranchInst::Create(getTrapBB(), OldBB); -} - +using BuilderTy = IRBuilder<TargetFolder>; -/// instrument - adds run-time bounds checks to memory accessing instructions. -/// Ptr is the pointer that will be read/written, and InstVal is either the -/// result from the load or the value being stored. It is used to determine the -/// size of memory block that is touched. +/// Adds run-time bounds checks to memory accessing instructions. +/// +/// \p Ptr is the pointer that will be read/written, and \p InstVal is either +/// the result from the load or the value being stored. It is used to determine +/// the size of memory block that is touched. +/// +/// \p GetTrapBB is a callable that returns the trap BB to use on failure. +/// /// Returns true if any change was made to the IR, false otherwise. -bool BoundsChecking::instrument(Value *Ptr, Value *InstVal, - const DataLayout &DL) { +template <typename GetTrapBBT> +static bool instrumentMemAccess(Value *Ptr, Value *InstVal, + const DataLayout &DL, TargetLibraryInfo &TLI, + ObjectSizeOffsetEvaluator &ObjSizeEval, + BuilderTy &IRB, + GetTrapBBT GetTrapBB) { uint64_t NeededSize = DL.getTypeStoreSize(InstVal->getType()); DEBUG(dbgs() << "Instrument " << *Ptr << " for " << Twine(NeededSize) << " bytes\n"); - SizeOffsetEvalType SizeOffset = ObjSizeEval->compute(Ptr); + SizeOffsetEvalType SizeOffset = ObjSizeEval.compute(Ptr); - if (!ObjSizeEval->bothKnown(SizeOffset)) { + if (!ObjSizeEval.bothKnown(SizeOffset)) { ++ChecksUnable; return false; } @@ -150,56 +86,101 @@ bool BoundsChecking::instrument(Value *Ptr, Value *InstVal, // // optimization: if Size >= 0 (signed), skip 1st check // FIXME: add NSW/NUW here? -- we dont care if the subtraction overflows - Value *ObjSize = Builder->CreateSub(Size, Offset); - Value *Cmp2 = Builder->CreateICmpULT(Size, Offset); - Value *Cmp3 = Builder->CreateICmpULT(ObjSize, NeededSizeVal); - Value *Or = Builder->CreateOr(Cmp2, Cmp3); + Value *ObjSize = IRB.CreateSub(Size, Offset); + Value *Cmp2 = IRB.CreateICmpULT(Size, Offset); + Value *Cmp3 = IRB.CreateICmpULT(ObjSize, NeededSizeVal); + Value *Or = IRB.CreateOr(Cmp2, Cmp3); if (!SizeCI || SizeCI->getValue().slt(0)) { - Value *Cmp1 = Builder->CreateICmpSLT(Offset, ConstantInt::get(IntTy, 0)); - Or = Builder->CreateOr(Cmp1, Or); + Value *Cmp1 = IRB.CreateICmpSLT(Offset, ConstantInt::get(IntTy, 0)); + Or = IRB.CreateOr(Cmp1, Or); + } + + // check if the comparison is always false + ConstantInt *C = dyn_cast_or_null<ConstantInt>(Or); + if (C) { + ++ChecksSkipped; + // If non-zero, nothing to do. + if (!C->getZExtValue()) + return true; + } + ++ChecksAdded; + + BasicBlock::iterator SplitI = IRB.GetInsertPoint(); + BasicBlock *OldBB = SplitI->getParent(); + BasicBlock *Cont = OldBB->splitBasicBlock(SplitI); + OldBB->getTerminator()->eraseFromParent(); + + if (C) { + // If we have a constant zero, unconditionally branch. + // FIXME: We should really handle this differently to bypass the splitting + // the block. 
+ BranchInst::Create(GetTrapBB(IRB), OldBB); + return true; } - emitBranchToTrap(Or); + // Create the conditional branch. + BranchInst::Create(GetTrapBB(IRB), Cont, Or, OldBB); return true; } -bool BoundsChecking::runOnFunction(Function &F) { +static bool addBoundsChecking(Function &F, TargetLibraryInfo &TLI) { const DataLayout &DL = F.getParent()->getDataLayout(); - TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); - - TrapBB = nullptr; - BuilderTy TheBuilder(F.getContext(), TargetFolder(DL)); - Builder = &TheBuilder; - ObjectSizeOffsetEvaluator TheObjSizeEval(DL, TLI, F.getContext(), + ObjectSizeOffsetEvaluator ObjSizeEval(DL, &TLI, F.getContext(), /*RoundToAlign=*/true); - ObjSizeEval = &TheObjSizeEval; // check HANDLE_MEMORY_INST in include/llvm/Instruction.def for memory // touching instructions - std::vector<Instruction*> WorkList; - for (inst_iterator i = inst_begin(F), e = inst_end(F); i != e; ++i) { - Instruction *I = &*i; + std::vector<Instruction *> WorkList; + for (Instruction &I : instructions(F)) { if (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<AtomicCmpXchgInst>(I) || isa<AtomicRMWInst>(I)) - WorkList.push_back(I); + WorkList.push_back(&I); } - bool MadeChange = false; - for (Instruction *i : WorkList) { - Inst = i; + // Create a trapping basic block on demand using a callback. Depending on + // flags, this will either create a single block for the entire function or + // will create a fresh block every time it is called. + BasicBlock *TrapBB = nullptr; + auto GetTrapBB = [&TrapBB](BuilderTy &IRB) { + if (TrapBB && SingleTrapBB) + return TrapBB; + + Function *Fn = IRB.GetInsertBlock()->getParent(); + // FIXME: This debug location doesn't make a lot of sense in the + // `SingleTrapBB` case. + auto DebugLoc = IRB.getCurrentDebugLocation(); + IRBuilder<>::InsertPointGuard Guard(IRB); + TrapBB = BasicBlock::Create(Fn->getContext(), "trap", Fn); + IRB.SetInsertPoint(TrapBB); + + auto *F = Intrinsic::getDeclaration(Fn->getParent(), Intrinsic::trap); + CallInst *TrapCall = IRB.CreateCall(F, {}); + TrapCall->setDoesNotReturn(); + TrapCall->setDoesNotThrow(); + TrapCall->setDebugLoc(DebugLoc); + IRB.CreateUnreachable(); + + return TrapBB; + }; - Builder->SetInsertPoint(Inst); + bool MadeChange = false; + for (Instruction *Inst : WorkList) { + BuilderTy IRB(Inst->getParent(), BasicBlock::iterator(Inst), TargetFolder(DL)); if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { - MadeChange |= instrument(LI->getPointerOperand(), LI, DL); + MadeChange |= instrumentMemAccess(LI->getPointerOperand(), LI, DL, TLI, + ObjSizeEval, IRB, GetTrapBB); } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { MadeChange |= - instrument(SI->getPointerOperand(), SI->getValueOperand(), DL); + instrumentMemAccess(SI->getPointerOperand(), SI->getValueOperand(), + DL, TLI, ObjSizeEval, IRB, GetTrapBB); } else if (AtomicCmpXchgInst *AI = dyn_cast<AtomicCmpXchgInst>(Inst)) { MadeChange |= - instrument(AI->getPointerOperand(), AI->getCompareOperand(), DL); + instrumentMemAccess(AI->getPointerOperand(), AI->getCompareOperand(), + DL, TLI, ObjSizeEval, IRB, GetTrapBB); } else if (AtomicRMWInst *AI = dyn_cast<AtomicRMWInst>(Inst)) { MadeChange |= - instrument(AI->getPointerOperand(), AI->getValOperand(), DL); + instrumentMemAccess(AI->getPointerOperand(), AI->getValOperand(), DL, + TLI, ObjSizeEval, IRB, GetTrapBB); } else { llvm_unreachable("unknown Instruction type"); } @@ -207,6 +188,41 @@ bool BoundsChecking::runOnFunction(Function &F) { return MadeChange; } -FunctionPass 
*llvm::createBoundsCheckingPass() {
-  return new BoundsChecking();
+PreservedAnalyses BoundsCheckingPass::run(Function &F, FunctionAnalysisManager &AM) {
+  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+
+  if (!addBoundsChecking(F, TLI))
+    return PreservedAnalyses::all();
+
+  return PreservedAnalyses::none();
+}
+
+namespace {
+struct BoundsCheckingLegacyPass : public FunctionPass {
+  static char ID;
+
+  BoundsCheckingLegacyPass() : FunctionPass(ID) {
+    initializeBoundsCheckingLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+    return addBoundsChecking(F, TLI);
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+  }
+};
+} // namespace
+
+char BoundsCheckingLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(BoundsCheckingLegacyPass, "bounds-checking",
+                      "Run-time bounds checking", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(BoundsCheckingLegacyPass, "bounds-checking",
+                    "Run-time bounds checking", false, false)
+
+FunctionPass *llvm::createBoundsCheckingLegacyPass() {
+  return new BoundsCheckingLegacyPass();
 }
diff --git a/lib/Transforms/Instrumentation/CFGMST.h b/lib/Transforms/Instrumentation/CFGMST.h
index 16e2e6b4e730..075e5672cff8 100644
--- a/lib/Transforms/Instrumentation/CFGMST.h
+++ b/lib/Transforms/Instrumentation/CFGMST.h
@@ -46,6 +46,10 @@ public:
  // This map records the auxiliary information for each BB.
  DenseMap<const BasicBlock *, std::unique_ptr<BBInfo>> BBInfos;
+  // Whether the function has an exit block with no successors.
+  // (For a function with an infinite loop, this block may be absent.)
+  bool ExitBlockFound = false;
+
  // Find the root group of the G and compress the path from G to the root.
  BBInfo *findAndCompressGroup(BBInfo *G) {
    if (G->Group != G)
@@ -95,14 +99,20 @@ public:
  void buildEdges() {
    DEBUG(dbgs() << "Build Edge on " << F.getName() << "\n");
-    const BasicBlock *BB = &(F.getEntryBlock());
+    const BasicBlock *Entry = &(F.getEntryBlock());
    uint64_t EntryWeight = (BFI != nullptr ? BFI->getEntryFreq() : 2);
+    Edge *EntryIncoming = nullptr, *EntryOutgoing = nullptr,
+        *ExitOutgoing = nullptr, *ExitIncoming = nullptr;
+    uint64_t MaxEntryOutWeight = 0, MaxExitOutWeight = 0, MaxExitInWeight = 0;
+
    // Add a fake edge to the entry.
-    addEdge(nullptr, BB, EntryWeight);
+    EntryIncoming = &addEdge(nullptr, Entry, EntryWeight);
+    DEBUG(dbgs() << " Edge: from fake node to " << Entry->getName()
+                 << " w = " << EntryWeight << "\n");
    // Special handling for single BB functions.
-    if (succ_empty(Entry)) {
-      addEdge(Entry, nullptr, EntryWeight);
+    if (succ_empty(Entry)) {
+      addEdge(Entry, nullptr, EntryWeight);
      return;
    }
@@ -126,16 +136,62 @@ public:
        }
        if (BPI != nullptr)
          Weight = BPI->getEdgeProbability(&*BB, TargetBB).scale(scaleFactor);
-        addEdge(&*BB, TargetBB, Weight).IsCritical = Critical;
+        auto *E = &addEdge(&*BB, TargetBB, Weight);
+        E->IsCritical = Critical;
        DEBUG(dbgs() << " Edge: from " << BB->getName() << " to "
                     << TargetBB->getName() << " w=" << Weight << "\n");
+
+        // Keep track of entry/exit edges:
+        if (&*BB == Entry) {
+          if (Weight > MaxEntryOutWeight) {
+            MaxEntryOutWeight = Weight;
+            EntryOutgoing = E;
+          }
+        }
+
+        auto *TargetTI = TargetBB->getTerminator();
+        if (TargetTI && !TargetTI->getNumSuccessors()) {
+          if (Weight > MaxExitInWeight) {
+            MaxExitInWeight = Weight;
+            ExitIncoming = E;
+          }
+        }
      }
    } else {
-      addEdge(&*BB, nullptr, BBWeight);
-      DEBUG(dbgs() << " Edge: from " << BB->getName() << " to exit"
+      ExitBlockFound = true;
+      Edge *ExitO = &addEdge(&*BB, nullptr, BBWeight);
+      if (BBWeight > MaxExitOutWeight) {
+        MaxExitOutWeight = BBWeight;
+        ExitOutgoing = ExitO;
+      }
+      DEBUG(dbgs() << " Edge: from " << BB->getName() << " to fake exit"
                   << " w = " << BBWeight << "\n");
    }
  }
+
+  // Entry/exit edge adjustment heuristic:
+  // prefer instrumenting the entry edge over the exit edge
+  // if possible. Exit edges may never have a chance to be
+  // executed (for instance, when the program is an event handling loop)
+  // before the profile is asynchronously dumped.
+  //
+  // If EntryIncoming and ExitOutgoing have similar weights, make sure
+  // ExitOutgoing is selected as the min-edge. Similarly, if EntryOutgoing
+  // and ExitIncoming have similar weights, make sure ExitIncoming becomes
+  // the min-edge.
+  uint64_t EntryInWeight = EntryWeight;
+
+  if (EntryInWeight >= MaxExitOutWeight &&
+      EntryInWeight * 2 < MaxExitOutWeight * 3) {
+    EntryIncoming->Weight = MaxExitOutWeight;
+    ExitOutgoing->Weight = EntryInWeight + 1;
+  }
+
+  if (MaxEntryOutWeight >= MaxExitInWeight &&
+      MaxEntryOutWeight * 2 < MaxExitInWeight * 3) {
+    EntryOutgoing->Weight = MaxExitInWeight;
+    ExitIncoming->Weight = MaxEntryOutWeight + 1;
+  }
  }

  // Sort CFG edges based on their weights.
@@ -167,6 +223,10 @@ public:
    for (auto &Ei : AllEdges) {
      if (Ei->Removed)
        continue;
+      // If we detect infinite loops, force
+      // instrumenting the entry edge:
+      if (!ExitBlockFound && Ei->SrcBB == nullptr)
+        continue;
      if (unionGroups(Ei->SrcBB, Ei->DestBB))
        Ei->InMST = true;
    }
diff --git a/lib/Transforms/Instrumentation/CMakeLists.txt b/lib/Transforms/Instrumentation/CMakeLists.txt
index f2806e278e6e..66fdcb3ccc49 100644
--- a/lib/Transforms/Instrumentation/CMakeLists.txt
+++ b/lib/Transforms/Instrumentation/CMakeLists.txt
@@ -12,6 +12,7 @@ add_llvm_library(LLVMInstrumentation
  SanitizerCoverage.cpp
  ThreadSanitizer.cpp
  EfficiencySanitizer.cpp
+  HWAddressSanitizer.cpp
  ADDITIONAL_HEADER_DIRS
  ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms
diff --git a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
index ddc975cbed1a..09bcbb282653 100644
--- a/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
+++ b/lib/Transforms/Instrumentation/DataFlowSanitizer.cpp
@@ -1,4 +1,4 @@
-//===-- DataFlowSanitizer.cpp - dynamic data flow analysis ----------------===//
+//===- DataFlowSanitizer.cpp - dynamic data flow analysis -----------------===//
//
// The LLVM Compiler Infrastructure
//
@@ -6,6 +6,7 @@
// License. See LICENSE.TXT for details.
// //===----------------------------------------------------------------------===// +// /// \file /// This file is a part of DataFlowSanitizer, a generalised dynamic data flow /// analysis. @@ -43,32 +44,63 @@ /// /// For more information, please refer to the design document: /// http://clang.llvm.org/docs/DataFlowSanitizerDesign.html +// +//===----------------------------------------------------------------------===// #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" #include "llvm/IR/Type.h" +#include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/SpecialCaseList.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> #include <iterator> +#include <memory> #include <set> +#include <string> #include <utility> +#include <vector> using namespace llvm; @@ -129,10 +161,7 @@ static cl::opt<bool> ClDebugNonzeroLabels( "load or return with a nonzero label"), cl::Hidden); - -namespace { - -StringRef GetGlobalTypeString(const GlobalValue &G) { +static StringRef GetGlobalTypeString(const GlobalValue &G) { // Types of GlobalVariables are always pointer types. Type *GType = G.getValueType(); // For now we support blacklisting struct types only. @@ -143,11 +172,13 @@ StringRef GetGlobalTypeString(const GlobalValue &G) { return "<unknown type>"; } +namespace { + class DFSanABIList { std::unique_ptr<SpecialCaseList> SCL; public: - DFSanABIList() {} + DFSanABIList() = default; void set(std::unique_ptr<SpecialCaseList> List) { SCL = std::move(List); } @@ -155,7 +186,7 @@ class DFSanABIList { /// given category. bool isIn(const Function &F, StringRef Category) const { return isIn(*F.getParent(), Category) || - SCL->inSection("fun", F.getName(), Category); + SCL->inSection("dataflow", "fun", F.getName(), Category); } /// Returns whether this global alias is listed in the given category. 
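The isIn() helpers above now pass a leading "dataflow" section name to SpecialCaseList::inSection(), so a single special-case list file can hold sections for several tools without their fun:/global:/src:/type: entries colliding. The toy model below illustrates only that section-qualified lookup; the std::map layout and the sample entry are illustrative assumptions, not the SpecialCaseList implementation:

    #include <iostream>
    #include <map>
    #include <set>
    #include <string>

    // Toy stand-in for a sectioned special-case list: section -> entry prefix ->
    // names.  Real lookups go through llvm::SpecialCaseList::inSection(); this
    // only shows why the extra "dataflow" argument matters.
    using List = std::map<std::string, std::map<std::string, std::set<std::string>>>;

    static bool inSection(const List &L, const std::string &Section,
                          const std::string &Prefix, const std::string &Name) {
      auto S = L.find(Section);
      if (S == L.end())
        return false;
      auto P = S->second.find(Prefix);
      return P != S->second.end() && P->second.count(Name) != 0;
    }

    int main() {
      // One file can now carry entries for several tools: an entry under a
      // [dataflow] heading is visible to DFSan but not to some other section.
      List L = {{"dataflow", {{"fun", {"memcmp"}}}}};
      std::cout << inSection(L, "dataflow", "fun", "memcmp") << "\n"; // 1
      std::cout << inSection(L, "cfi", "fun", "memcmp") << "\n";      // 0
      return 0;
    }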
@@ -167,15 +198,16 @@ class DFSanABIList { return true; if (isa<FunctionType>(GA.getValueType())) - return SCL->inSection("fun", GA.getName(), Category); + return SCL->inSection("dataflow", "fun", GA.getName(), Category); - return SCL->inSection("global", GA.getName(), Category) || - SCL->inSection("type", GetGlobalTypeString(GA), Category); + return SCL->inSection("dataflow", "global", GA.getName(), Category) || + SCL->inSection("dataflow", "type", GetGlobalTypeString(GA), + Category); } /// Returns whether this module is listed in the given category. bool isIn(const Module &M, StringRef Category) const { - return SCL->inSection("src", M.getModuleIdentifier(), Category); + return SCL->inSection("dataflow", "src", M.getModuleIdentifier(), Category); } }; @@ -255,7 +287,7 @@ class DataFlowSanitizer : public ModulePass { DFSanABIList ABIList; DenseMap<Value *, Function *> UnwrappedFnMap; AttrBuilder ReadOnlyNoneAttrs; - bool DFSanRuntimeShadowMask; + bool DFSanRuntimeShadowMask = false; Value *getShadowAddress(Value *Addr, Instruction *Pos); bool isInstrumented(const Function *F); @@ -271,11 +303,13 @@ class DataFlowSanitizer : public ModulePass { FunctionType *NewFT); Constant *getOrBuildTrampolineFunction(FunctionType *FT, StringRef FName); - public: +public: + static char ID; + DataFlowSanitizer( const std::vector<std::string> &ABIListFiles = std::vector<std::string>(), void *(*getArgTLS)() = nullptr, void *(*getRetValTLS)() = nullptr); - static char ID; + bool doInitialization(Module &M) override; bool runOnModule(Module &M) override; }; @@ -286,12 +320,12 @@ struct DFSanFunction { DominatorTree DT; DataFlowSanitizer::InstrumentedABI IA; bool IsNativeABI; - Value *ArgTLSPtr; - Value *RetvalTLSPtr; - AllocaInst *LabelReturnAlloca; + Value *ArgTLSPtr = nullptr; + Value *RetvalTLSPtr = nullptr; + AllocaInst *LabelReturnAlloca = nullptr; DenseMap<Value *, Value *> ValShadowMap; DenseMap<AllocaInst *, AllocaInst *> AllocaShadowMap; - std::vector<std::pair<PHINode *, PHINode *> > PHIFixups; + std::vector<std::pair<PHINode *, PHINode *>> PHIFixups; DenseSet<Instruction *> SkipInsts; std::vector<Value *> NonZeroChecks; bool AvoidNewBlocks; @@ -305,14 +339,13 @@ struct DFSanFunction { DenseMap<Value *, std::set<Value *>> ShadowElements; DFSanFunction(DataFlowSanitizer &DFS, Function *F, bool IsNativeABI) - : DFS(DFS), F(F), IA(DFS.getInstrumentedABI()), - IsNativeABI(IsNativeABI), ArgTLSPtr(nullptr), RetvalTLSPtr(nullptr), - LabelReturnAlloca(nullptr) { + : DFS(DFS), F(F), IA(DFS.getInstrumentedABI()), IsNativeABI(IsNativeABI) { DT.recalculate(*F); // FIXME: Need to track down the register allocator issue which causes poor // performance in pathological cases with large numbers of basic blocks. 
AvoidNewBlocks = F->size() > 1000; } + Value *getArgTLSPtr(); Value *getArgTLS(unsigned Index, Instruction *Pos); Value *getRetvalTLS(); @@ -327,8 +360,9 @@ struct DFSanFunction { }; class DFSanVisitor : public InstVisitor<DFSanVisitor> { - public: +public: DFSanFunction &DFSF; + DFSanVisitor(DFSanFunction &DFSF) : DFSF(DFSF) {} const DataLayout &getDataLayout() const { @@ -336,7 +370,6 @@ class DFSanVisitor : public InstVisitor<DFSanVisitor> { } void visitOperandShadowInst(Instruction &I); - void visitBinaryOperator(BinaryOperator &BO); void visitCastInst(CastInst &CI); void visitCmpInst(CmpInst &CI); @@ -357,9 +390,10 @@ class DFSanVisitor : public InstVisitor<DFSanVisitor> { void visitMemTransferInst(MemTransferInst &I); }; -} +} // end anonymous namespace char DataFlowSanitizer::ID; + INITIALIZE_PASS(DataFlowSanitizer, "dfsan", "DataFlowSanitizer: dynamic data flow analysis.", false, false) @@ -373,8 +407,7 @@ llvm::createDataFlowSanitizerPass(const std::vector<std::string> &ABIListFiles, DataFlowSanitizer::DataFlowSanitizer( const std::vector<std::string> &ABIListFiles, void *(*getArgTLS)(), void *(*getRetValTLS)()) - : ModulePass(ID), GetArgTLSPtr(getArgTLS), GetRetvalTLSPtr(getRetValTLS), - DFSanRuntimeShadowMask(false) { + : ModulePass(ID), GetArgTLSPtr(getArgTLS), GetRetvalTLSPtr(getRetValTLS) { std::vector<std::string> AllABIListFiles(std::move(ABIListFiles)); AllABIListFiles.insert(AllABIListFiles.end(), ClABIListFiles.begin(), ClABIListFiles.end()); @@ -382,7 +415,7 @@ DataFlowSanitizer::DataFlowSanitizer( } FunctionType *DataFlowSanitizer::getArgsFunctionType(FunctionType *T) { - llvm::SmallVector<Type *, 4> ArgTypes(T->param_begin(), T->param_end()); + SmallVector<Type *, 4> ArgTypes(T->param_begin(), T->param_end()); ArgTypes.append(T->getNumParams(), ShadowTy); if (T->isVarArg()) ArgTypes.push_back(ShadowPtrTy); @@ -394,7 +427,7 @@ FunctionType *DataFlowSanitizer::getArgsFunctionType(FunctionType *T) { FunctionType *DataFlowSanitizer::getTrampolineFunctionType(FunctionType *T) { assert(!T->isVarArg()); - llvm::SmallVector<Type *, 4> ArgTypes; + SmallVector<Type *, 4> ArgTypes; ArgTypes.push_back(T->getPointerTo()); ArgTypes.append(T->param_begin(), T->param_end()); ArgTypes.append(T->getNumParams(), ShadowTy); @@ -405,7 +438,7 @@ FunctionType *DataFlowSanitizer::getTrampolineFunctionType(FunctionType *T) { } FunctionType *DataFlowSanitizer::getCustomFunctionType(FunctionType *T) { - llvm::SmallVector<Type *, 4> ArgTypes; + SmallVector<Type *, 4> ArgTypes; for (FunctionType::param_iterator i = T->param_begin(), e = T->param_end(); i != e; ++i) { FunctionType *FT; @@ -428,12 +461,12 @@ FunctionType *DataFlowSanitizer::getCustomFunctionType(FunctionType *T) { } bool DataFlowSanitizer::doInitialization(Module &M) { - llvm::Triple TargetTriple(M.getTargetTriple()); - bool IsX86_64 = TargetTriple.getArch() == llvm::Triple::x86_64; - bool IsMIPS64 = TargetTriple.getArch() == llvm::Triple::mips64 || - TargetTriple.getArch() == llvm::Triple::mips64el; - bool IsAArch64 = TargetTriple.getArch() == llvm::Triple::aarch64 || - TargetTriple.getArch() == llvm::Triple::aarch64_be; + Triple TargetTriple(M.getTargetTriple()); + bool IsX86_64 = TargetTriple.getArch() == Triple::x86_64; + bool IsMIPS64 = TargetTriple.getArch() == Triple::mips64 || + TargetTriple.getArch() == Triple::mips64el; + bool IsAArch64 = TargetTriple.getArch() == Triple::aarch64 || + TargetTriple.getArch() == Triple::aarch64_be; const DataLayout &DL = M.getDataLayout(); @@ -654,7 +687,7 @@ bool 
DataFlowSanitizer::runOnModule(Module &M) { DFSanVarargWrapperFnTy); std::vector<Function *> FnsToInstrument; - llvm::SmallPtrSet<Function *, 2> FnsWithNativeABI; + SmallPtrSet<Function *, 2> FnsWithNativeABI; for (Function &i : M) { if (!i.isIntrinsic() && &i != DFSanUnionFn && @@ -797,11 +830,11 @@ bool DataFlowSanitizer::runOnModule(Module &M) { // DFSanVisitor may create new basic blocks, which confuses df_iterator. // Build a copy of the list before iterating over it. - llvm::SmallVector<BasicBlock *, 4> BBList(depth_first(&i->getEntryBlock())); + SmallVector<BasicBlock *, 4> BBList(depth_first(&i->getEntryBlock())); for (BasicBlock *i : BBList) { Instruction *Inst = &i->front(); - while (1) { + while (true) { // DFSanVisitor may split the current basic block, changing the current // instruction's next pointer and moving the next instruction to the // tail block from which we should continue. @@ -821,7 +854,7 @@ bool DataFlowSanitizer::runOnModule(Module &M) { // until we have visited every block. Therefore, the code that handles phi // nodes adds them to the PHIFixups list so that they can be properly // handled here. - for (std::vector<std::pair<PHINode *, PHINode *> >::iterator + for (std::vector<std::pair<PHINode *, PHINode *>>::iterator i = DFSF.PHIFixups.begin(), e = DFSF.PHIFixups.end(); i != e; ++i) { @@ -1045,8 +1078,7 @@ void DFSanVisitor::visitOperandShadowInst(Instruction &I) { Value *DFSanFunction::loadShadow(Value *Addr, uint64_t Size, uint64_t Align, Instruction *Pos) { if (AllocaInst *AI = dyn_cast<AllocaInst>(Addr)) { - llvm::DenseMap<AllocaInst *, AllocaInst *>::iterator i = - AllocaShadowMap.find(AI); + const auto i = AllocaShadowMap.find(AI); if (i != AllocaShadowMap.end()) { IRBuilder<> IRB(Pos); return IRB.CreateLoad(i->second); @@ -1187,8 +1219,7 @@ void DFSanVisitor::visitLoadInst(LoadInst &LI) { void DFSanFunction::storeShadow(Value *Addr, uint64_t Size, uint64_t Align, Value *Shadow, Instruction *Pos) { if (AllocaInst *AI = dyn_cast<AllocaInst>(Addr)) { - llvm::DenseMap<AllocaInst *, AllocaInst *>::iterator i = - AllocaShadowMap.find(AI); + const auto i = AllocaShadowMap.find(AI); if (i != AllocaShadowMap.end()) { IRBuilder<> IRB(Pos); IRB.CreateStore(Shadow, i->second); @@ -1409,24 +1440,21 @@ void DFSanVisitor::visitCallSite(CallSite CS) { if (i != DFSF.DFS.UnwrappedFnMap.end()) { Function *F = i->second; switch (DFSF.DFS.getWrapperKind(F)) { - case DataFlowSanitizer::WK_Warning: { + case DataFlowSanitizer::WK_Warning: CS.setCalledFunction(F); IRB.CreateCall(DFSF.DFS.DFSanUnimplementedFn, IRB.CreateGlobalStringPtr(F->getName())); DFSF.setShadow(CS.getInstruction(), DFSF.DFS.ZeroShadow); return; - } - case DataFlowSanitizer::WK_Discard: { + case DataFlowSanitizer::WK_Discard: CS.setCalledFunction(F); DFSF.setShadow(CS.getInstruction(), DFSF.DFS.ZeroShadow); return; - } - case DataFlowSanitizer::WK_Functional: { + case DataFlowSanitizer::WK_Functional: CS.setCalledFunction(F); visitOperandShadowInst(*CS.getInstruction()); return; - } - case DataFlowSanitizer::WK_Custom: { + case DataFlowSanitizer::WK_Custom: // Don't try to handle invokes of custom functions, it's too complicated. // Instead, invoke the dfsw$ wrapper, which will in turn call the __dfsw_ // wrapper. 
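The wrapper kind selected in the switch above is driven by the ABI list, which this change now reads through the sectioned special-case-list interface (hence the new "dataflow" prefix in the DFSanABIList queries earlier in the diff). Purely as an illustration of what such a list can look like, assuming the usual category names rather than quoting any file from this commit:

    [dataflow]
    # the runtime supplies a __dfsw_-style custom wrapper
    fun:memcmp=custom
    # result label is simply the union of the argument labels
    fun:sqrt=functional
    # calls whose labels can be dropped entirely
    fun:malloc=discard
    # uninstrumented code with no other category falls back to WK_Warning
    fun:legacy_*=uninstrumented
    # whole translation units can be matched via the "src" entries checked above
    src:*zlib*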
@@ -1526,7 +1554,6 @@ void DFSanVisitor::visitCallSite(CallSite CS) { } break; } - } } FunctionType *FT = cast<FunctionType>( diff --git a/lib/Transforms/Instrumentation/GCOVProfiling.cpp b/lib/Transforms/Instrumentation/GCOVProfiling.cpp index 56d0f5e983ca..67ca8172b0d5 100644 --- a/lib/Transforms/Instrumentation/GCOVProfiling.cpp +++ b/lib/Transforms/Instrumentation/GCOVProfiling.cpp @@ -21,6 +21,7 @@ #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/UniqueVector.h" +#include "llvm/Analysis/EHPersonalities.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/DebugLoc.h" #include "llvm/IR/IRBuilder.h" @@ -502,6 +503,23 @@ static bool functionHasLines(Function &F) { return false; } +static bool isUsingFuncletBasedEH(Function &F) { + if (!F.hasPersonalityFn()) return false; + + EHPersonality Personality = classifyEHPersonality(F.getPersonalityFn()); + return isFuncletEHPersonality(Personality); +} + +static bool shouldKeepInEntry(BasicBlock::iterator It) { + if (isa<AllocaInst>(*It)) return true; + if (isa<DbgInfoIntrinsic>(*It)) return true; + if (auto *II = dyn_cast<IntrinsicInst>(It)) { + if (II->getIntrinsicID() == llvm::Intrinsic::localescape) return true; + } + + return false; +} + void GCOVProfiler::emitProfileNotes() { NamedMDNode *CU_Nodes = M->getNamedMetadata("llvm.dbg.cu"); if (!CU_Nodes) return; @@ -519,6 +537,12 @@ void GCOVProfiler::emitProfileNotes() { std::error_code EC; raw_fd_ostream out(mangleName(CU, GCovFileType::GCNO), EC, sys::fs::F_None); + if (EC) { + Ctx->emitError(Twine("failed to open coverage notes file for writing: ") + + EC.message()); + continue; + } + std::string EdgeDestinations; unsigned FunctionIdent = 0; @@ -526,12 +550,14 @@ void GCOVProfiler::emitProfileNotes() { DISubprogram *SP = F.getSubprogram(); if (!SP) continue; if (!functionHasLines(F)) continue; + // TODO: Functions using funclet-based EH are currently not supported. + if (isUsingFuncletBasedEH(F)) continue; // gcov expects every function to start with an entry block that has a // single successor, so split the entry block to make sure of that. BasicBlock &EntryBlock = F.getEntryBlock(); BasicBlock::iterator It = EntryBlock.begin(); - while (isa<AllocaInst>(*It) || isa<DbgInfoIntrinsic>(*It)) + while (shouldKeepInEntry(It)) ++It; EntryBlock.splitBasicBlock(It); @@ -603,7 +629,10 @@ bool GCOVProfiler::emitProfileArcs() { DISubprogram *SP = F.getSubprogram(); if (!SP) continue; if (!functionHasLines(F)) continue; + // TODO: Functions using funclet-based EH are currently not supported. + if (isUsingFuncletBasedEH(F)) continue; if (!Result) Result = true; + unsigned Edges = 0; for (auto &BB : F) { TerminatorInst *TI = BB.getTerminator(); diff --git a/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp b/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp new file mode 100644 index 000000000000..2a25423e04bd --- /dev/null +++ b/lib/Transforms/Instrumentation/HWAddressSanitizer.cpp @@ -0,0 +1,327 @@ +//===- HWAddressSanitizer.cpp - detector of uninitialized reads -------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +/// \file +/// This file is a part of HWAddressSanitizer, an address sanity checker +/// based on tagged addressing. 
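Before the implementation, a plain C++ sketch of what "tagged addressing" buys the checks this new pass emits (see instrumentMemAccessInline further down): the top byte of a pointer carries a tag, and the expected tag for the addressed granule lives in shadow memory at the untagged address shifted right by kShadowScale. This is only an illustration of the comparison; the pass itself builds the equivalent IR.

    #include <cstdint>

    // Mirrors the inline check emitted below: compare the pointer's tag (its
    // top byte) with the tag recorded in shadow memory for that granule.
    static inline bool hwasanTagsMatch(uintptr_t TaggedPtr) {
      uint8_t PtrTag = uint8_t(TaggedPtr >> 56);                  // kPointerTagShift
      uintptr_t Addr = TaggedPtr & ~(0xFFULL << 56);              // strip the tag bits
      uint8_t MemTag = *reinterpret_cast<uint8_t *>(Addr >> 4);   // kShadowScale == 4
      return PtrTag == MemTag;  // a mismatch is reported through the hlt trap below
    }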
+//===----------------------------------------------------------------------===// + +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Triple.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/IR/Function.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InlineAsm.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Instrumentation.h" +#include "llvm/Transforms/Utils/ModuleUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "hwasan" + +static const char *const kHwasanModuleCtorName = "hwasan.module_ctor"; +static const char *const kHwasanInitName = "__hwasan_init"; + +// Accesses sizes are powers of two: 1, 2, 4, 8, 16. +static const size_t kNumberOfAccessSizes = 5; + +static const size_t kShadowScale = 4; +static const unsigned kPointerTagShift = 56; + +static cl::opt<std::string> ClMemoryAccessCallbackPrefix( + "hwasan-memory-access-callback-prefix", + cl::desc("Prefix for memory access callbacks"), cl::Hidden, + cl::init("__hwasan_")); + +static cl::opt<bool> + ClInstrumentWithCalls("hwasan-instrument-with-calls", + cl::desc("instrument reads and writes with callbacks"), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> ClInstrumentReads("hwasan-instrument-reads", + cl::desc("instrument read instructions"), + cl::Hidden, cl::init(true)); + +static cl::opt<bool> ClInstrumentWrites( + "hwasan-instrument-writes", cl::desc("instrument write instructions"), + cl::Hidden, cl::init(true)); + +static cl::opt<bool> ClInstrumentAtomics( + "hwasan-instrument-atomics", + cl::desc("instrument atomic instructions (rmw, cmpxchg)"), cl::Hidden, + cl::init(true)); + +namespace { + +/// \brief An instrumentation pass implementing detection of addressability bugs +/// using tagged pointers. +class HWAddressSanitizer : public FunctionPass { +public: + // Pass identification, replacement for typeid. 
+ static char ID; + + HWAddressSanitizer() : FunctionPass(ID) {} + + StringRef getPassName() const override { return "HWAddressSanitizer"; } + + bool runOnFunction(Function &F) override; + bool doInitialization(Module &M) override; + + void initializeCallbacks(Module &M); + void instrumentMemAccessInline(Value *PtrLong, bool IsWrite, + unsigned AccessSizeIndex, + Instruction *InsertBefore); + bool instrumentMemAccess(Instruction *I); + Value *isInterestingMemoryAccess(Instruction *I, bool *IsWrite, + uint64_t *TypeSize, unsigned *Alignment, + Value **MaybeMask); + +private: + LLVMContext *C; + Type *IntptrTy; + + Function *HwasanCtorFunction; + + Function *HwasanMemoryAccessCallback[2][kNumberOfAccessSizes]; + Function *HwasanMemoryAccessCallbackSized[2]; +}; + +} // end anonymous namespace + +char HWAddressSanitizer::ID = 0; + +INITIALIZE_PASS_BEGIN( + HWAddressSanitizer, "hwasan", + "HWAddressSanitizer: detect memory bugs using tagged addressing.", false, false) +INITIALIZE_PASS_END( + HWAddressSanitizer, "hwasan", + "HWAddressSanitizer: detect memory bugs using tagged addressing.", false, false) + +FunctionPass *llvm::createHWAddressSanitizerPass() { + return new HWAddressSanitizer(); +} + +/// \brief Module-level initialization. +/// +/// inserts a call to __hwasan_init to the module's constructor list. +bool HWAddressSanitizer::doInitialization(Module &M) { + DEBUG(dbgs() << "Init " << M.getName() << "\n"); + auto &DL = M.getDataLayout(); + + Triple TargetTriple(M.getTargetTriple()); + + C = &(M.getContext()); + IRBuilder<> IRB(*C); + IntptrTy = IRB.getIntPtrTy(DL); + + std::tie(HwasanCtorFunction, std::ignore) = + createSanitizerCtorAndInitFunctions(M, kHwasanModuleCtorName, + kHwasanInitName, + /*InitArgTypes=*/{}, + /*InitArgs=*/{}); + appendToGlobalCtors(M, HwasanCtorFunction, 0); + return true; +} + +void HWAddressSanitizer::initializeCallbacks(Module &M) { + IRBuilder<> IRB(*C); + for (size_t AccessIsWrite = 0; AccessIsWrite <= 1; AccessIsWrite++) { + const std::string TypeStr = AccessIsWrite ? "store" : "load"; + + HwasanMemoryAccessCallbackSized[AccessIsWrite] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + ClMemoryAccessCallbackPrefix + TypeStr, + FunctionType::get(IRB.getVoidTy(), {IntptrTy, IntptrTy}, false))); + + for (size_t AccessSizeIndex = 0; AccessSizeIndex < kNumberOfAccessSizes; + AccessSizeIndex++) { + HwasanMemoryAccessCallback[AccessIsWrite][AccessSizeIndex] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + ClMemoryAccessCallbackPrefix + TypeStr + + itostr(1ULL << AccessSizeIndex), + FunctionType::get(IRB.getVoidTy(), {IntptrTy}, false))); + } + } +} + +Value *HWAddressSanitizer::isInterestingMemoryAccess(Instruction *I, + bool *IsWrite, + uint64_t *TypeSize, + unsigned *Alignment, + Value **MaybeMask) { + // Skip memory accesses inserted by another instrumentation. 
+ if (I->getMetadata("nosanitize")) return nullptr; + + Value *PtrOperand = nullptr; + const DataLayout &DL = I->getModule()->getDataLayout(); + if (LoadInst *LI = dyn_cast<LoadInst>(I)) { + if (!ClInstrumentReads) return nullptr; + *IsWrite = false; + *TypeSize = DL.getTypeStoreSizeInBits(LI->getType()); + *Alignment = LI->getAlignment(); + PtrOperand = LI->getPointerOperand(); + } else if (StoreInst *SI = dyn_cast<StoreInst>(I)) { + if (!ClInstrumentWrites) return nullptr; + *IsWrite = true; + *TypeSize = DL.getTypeStoreSizeInBits(SI->getValueOperand()->getType()); + *Alignment = SI->getAlignment(); + PtrOperand = SI->getPointerOperand(); + } else if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(I)) { + if (!ClInstrumentAtomics) return nullptr; + *IsWrite = true; + *TypeSize = DL.getTypeStoreSizeInBits(RMW->getValOperand()->getType()); + *Alignment = 0; + PtrOperand = RMW->getPointerOperand(); + } else if (AtomicCmpXchgInst *XCHG = dyn_cast<AtomicCmpXchgInst>(I)) { + if (!ClInstrumentAtomics) return nullptr; + *IsWrite = true; + *TypeSize = DL.getTypeStoreSizeInBits(XCHG->getCompareOperand()->getType()); + *Alignment = 0; + PtrOperand = XCHG->getPointerOperand(); + } + + if (PtrOperand) { + // Do not instrument acesses from different address spaces; we cannot deal + // with them. + Type *PtrTy = cast<PointerType>(PtrOperand->getType()->getScalarType()); + if (PtrTy->getPointerAddressSpace() != 0) + return nullptr; + + // Ignore swifterror addresses. + // swifterror memory addresses are mem2reg promoted by instruction + // selection. As such they cannot have regular uses like an instrumentation + // function and it makes no sense to track them as memory. + if (PtrOperand->isSwiftError()) + return nullptr; + } + + return PtrOperand; +} + +static size_t TypeSizeToSizeIndex(uint32_t TypeSize) { + size_t Res = countTrailingZeros(TypeSize / 8); + assert(Res < kNumberOfAccessSizes); + return Res; +} + +void HWAddressSanitizer::instrumentMemAccessInline(Value *PtrLong, bool IsWrite, + unsigned AccessSizeIndex, + Instruction *InsertBefore) { + IRBuilder<> IRB(InsertBefore); + Value *PtrTag = IRB.CreateTrunc(IRB.CreateLShr(PtrLong, kPointerTagShift), IRB.getInt8Ty()); + Value *AddrLong = + IRB.CreateAnd(PtrLong, ConstantInt::get(PtrLong->getType(), + ~(0xFFULL << kPointerTagShift))); + Value *ShadowLong = IRB.CreateLShr(AddrLong, kShadowScale); + Value *MemTag = IRB.CreateLoad(IRB.CreateIntToPtr(ShadowLong, IRB.getInt8PtrTy())); + Value *TagMismatch = IRB.CreateICmpNE(PtrTag, MemTag); + + TerminatorInst *CheckTerm = + SplitBlockAndInsertIfThen(TagMismatch, InsertBefore, false, + MDBuilder(*C).createBranchWeights(1, 100000)); + + IRB.SetInsertPoint(CheckTerm); + // The signal handler will find the data address in x0. 
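  // For reference, the immediate in the hlt below encodes what failed:
  //   0x100 + IsWrite * 0x10 + AccessSizeIndex
  // e.g. a 4-byte store (IsWrite == 1, AccessSizeIndex == 2) traps with
  // 0x112, i.e. "hlt #274", and the handler reads the faulting address
  // from x0 as noted above.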
+ InlineAsm *Asm = InlineAsm::get( + FunctionType::get(IRB.getVoidTy(), {PtrLong->getType()}, false), + "hlt #" + itostr(0x100 + IsWrite * 0x10 + AccessSizeIndex), "{x0}", + /*hasSideEffects=*/true); + IRB.CreateCall(Asm, PtrLong); +} + +bool HWAddressSanitizer::instrumentMemAccess(Instruction *I) { + DEBUG(dbgs() << "Instrumenting: " << *I << "\n"); + bool IsWrite = false; + unsigned Alignment = 0; + uint64_t TypeSize = 0; + Value *MaybeMask = nullptr; + Value *Addr = + isInterestingMemoryAccess(I, &IsWrite, &TypeSize, &Alignment, &MaybeMask); + + if (!Addr) + return false; + + if (MaybeMask) + return false; //FIXME + + IRBuilder<> IRB(I); + Value *AddrLong = IRB.CreatePointerCast(Addr, IntptrTy); + if (isPowerOf2_64(TypeSize) && + (TypeSize / 8 <= (1UL << (kNumberOfAccessSizes - 1))) && + (Alignment >= (1UL << kShadowScale) || Alignment == 0 || + Alignment >= TypeSize / 8)) { + size_t AccessSizeIndex = TypeSizeToSizeIndex(TypeSize); + if (ClInstrumentWithCalls) { + IRB.CreateCall(HwasanMemoryAccessCallback[IsWrite][AccessSizeIndex], + AddrLong); + } else { + instrumentMemAccessInline(AddrLong, IsWrite, AccessSizeIndex, I); + } + } else { + IRB.CreateCall(HwasanMemoryAccessCallbackSized[IsWrite], + {AddrLong, ConstantInt::get(IntptrTy, TypeSize / 8)}); + } + + return true; +} + +bool HWAddressSanitizer::runOnFunction(Function &F) { + if (&F == HwasanCtorFunction) + return false; + + if (!F.hasFnAttribute(Attribute::SanitizeHWAddress)) + return false; + + DEBUG(dbgs() << "Function: " << F.getName() << "\n"); + + initializeCallbacks(*F.getParent()); + + bool Changed = false; + SmallVector<Instruction*, 16> ToInstrument; + for (auto &BB : F) { + for (auto &Inst : BB) { + Value *MaybeMask = nullptr; + bool IsWrite; + unsigned Alignment; + uint64_t TypeSize; + Value *Addr = isInterestingMemoryAccess(&Inst, &IsWrite, &TypeSize, + &Alignment, &MaybeMask); + if (Addr || isa<MemIntrinsic>(Inst)) + ToInstrument.push_back(&Inst); + } + } + + for (auto Inst : ToInstrument) + Changed |= instrumentMemAccess(Inst); + + return Changed; +} diff --git a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp index 4089d81ea3e1..49b8a67a6c14 100644 --- a/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp +++ b/lib/Transforms/Instrumentation/IndirectCallPromotion.cpp @@ -1,4 +1,4 @@ -//===-- IndirectCallPromotion.cpp - Optimizations based on value profiling ===// +//===- IndirectCallPromotion.cpp - Optimizations based on value profiling -===// // // The LLVM Compiler Infrastructure // @@ -14,13 +14,15 @@ //===----------------------------------------------------------------------===// #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" -#include "llvm/ADT/Twine.h" -#include "llvm/Analysis/BlockFrequencyInfo.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/IndirectCallPromotionAnalysis.h" #include "llvm/Analysis/IndirectCallSiteVisitor.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DerivedTypes.h" @@ -34,20 +36,23 @@ #include "llvm/IR/MDBuilder.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" -#include "llvm/PassRegistry.h" -#include "llvm/PassSupport.h" #include 
"llvm/ProfileData/InstrProf.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Error.h" #include "llvm/Support/Debug.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/PGOInstrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/CallPromotionUtils.h" #include <cassert> #include <cstdint> +#include <memory> +#include <string> +#include <utility> #include <vector> using namespace llvm; @@ -110,6 +115,7 @@ static cl::opt<bool> cl::desc("Dump IR after transformation happens")); namespace { + class PGOIndirectCallPromotionLegacyPass : public ModulePass { public: static char ID; @@ -120,6 +126,10 @@ public: *PassRegistry::getPassRegistry()); } + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<ProfileSummaryInfoWrapperPass>(); + } + StringRef getPassName() const override { return "PGOIndirectCallPromotion"; } private: @@ -133,13 +143,20 @@ private: // the promoted direct call. bool SamplePGO; }; + } // end anonymous namespace char PGOIndirectCallPromotionLegacyPass::ID = 0; -INITIALIZE_PASS(PGOIndirectCallPromotionLegacyPass, "pgo-icall-prom", - "Use PGO instrumentation profile to promote indirect calls to " - "direct calls.", - false, false) + +INITIALIZE_PASS_BEGIN(PGOIndirectCallPromotionLegacyPass, "pgo-icall-prom", + "Use PGO instrumentation profile to promote indirect " + "calls to direct calls.", + false, false) +INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) +INITIALIZE_PASS_END(PGOIndirectCallPromotionLegacyPass, "pgo-icall-prom", + "Use PGO instrumentation profile to promote indirect " + "calls to direct calls.", + false, false) ModulePass *llvm::createPGOIndirectCallPromotionLegacyPass(bool InLTO, bool SamplePGO) { @@ -147,6 +164,7 @@ ModulePass *llvm::createPGOIndirectCallPromotionLegacyPass(bool InLTO, } namespace { + // The class for main data structure to promote indirect calls to conditional // direct calls. class ICallPromotionFunc { @@ -160,14 +178,13 @@ private: bool SamplePGO; - // Test if we can legally promote this direct-call of Target. - bool isPromotionLegal(Instruction *Inst, uint64_t Target, Function *&F, - const char **Reason = nullptr); + OptimizationRemarkEmitter &ORE; // A struct that records the direct target and it's call count. struct PromotionCandidate { Function *TargetFunction; uint64_t Count; + PromotionCandidate(Function *F, uint64_t C) : TargetFunction(F), Count(C) {} }; @@ -186,72 +203,17 @@ private: const std::vector<PromotionCandidate> &Candidates, uint64_t &TotalCount); - // Noncopyable - ICallPromotionFunc(const ICallPromotionFunc &other) = delete; - ICallPromotionFunc &operator=(const ICallPromotionFunc &other) = delete; - public: ICallPromotionFunc(Function &Func, Module *Modu, InstrProfSymtab *Symtab, - bool SamplePGO) - : F(Func), M(Modu), Symtab(Symtab), SamplePGO(SamplePGO) {} + bool SamplePGO, OptimizationRemarkEmitter &ORE) + : F(Func), M(Modu), Symtab(Symtab), SamplePGO(SamplePGO), ORE(ORE) {} + ICallPromotionFunc(const ICallPromotionFunc &) = delete; + ICallPromotionFunc &operator=(const ICallPromotionFunc &) = delete; - bool processFunction(); + bool processFunction(ProfileSummaryInfo *PSI); }; -} // end anonymous namespace -bool llvm::isLegalToPromote(Instruction *Inst, Function *F, - const char **Reason) { - // Check the return type. 
- Type *CallRetType = Inst->getType(); - if (!CallRetType->isVoidTy()) { - Type *FuncRetType = F->getReturnType(); - if (FuncRetType != CallRetType && - !CastInst::isBitCastable(FuncRetType, CallRetType)) { - if (Reason) - *Reason = "Return type mismatch"; - return false; - } - } - - // Check if the arguments are compatible with the parameters - FunctionType *DirectCalleeType = F->getFunctionType(); - unsigned ParamNum = DirectCalleeType->getFunctionNumParams(); - CallSite CS(Inst); - unsigned ArgNum = CS.arg_size(); - - if (ParamNum != ArgNum && !DirectCalleeType->isVarArg()) { - if (Reason) - *Reason = "The number of arguments mismatch"; - return false; - } - - for (unsigned I = 0; I < ParamNum; ++I) { - Type *PTy = DirectCalleeType->getFunctionParamType(I); - Type *ATy = CS.getArgument(I)->getType(); - if (PTy == ATy) - continue; - if (!CastInst::castIsValid(Instruction::BitCast, CS.getArgument(I), PTy)) { - if (Reason) - *Reason = "Argument type mismatch"; - return false; - } - } - - DEBUG(dbgs() << " #" << NumOfPGOICallPromotion << " Promote the icall to " - << F->getName() << "\n"); - return true; -} - -bool ICallPromotionFunc::isPromotionLegal(Instruction *Inst, uint64_t Target, - Function *&TargetFunction, - const char **Reason) { - TargetFunction = Symtab->getFunction(Target); - if (TargetFunction == nullptr) { - *Reason = "Cannot find the target"; - return false; - } - return isLegalToPromote(Inst, TargetFunction, Reason); -} +} // end anonymous namespace // Indirect-call promotion heuristic. The direct targets are sorted based on // the count. Stop at the first target that is not promoted. @@ -279,51 +241,63 @@ ICallPromotionFunc::getPromotionCandidatesForCallSite( if (ICPInvokeOnly && dyn_cast<CallInst>(Inst)) { DEBUG(dbgs() << " Not promote: User options.\n"); + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "UserOptions", Inst) + << " Not promote: User options"; + }); break; } if (ICPCallOnly && dyn_cast<InvokeInst>(Inst)) { DEBUG(dbgs() << " Not promote: User option.\n"); + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "UserOptions", Inst) + << " Not promote: User options"; + }); break; } if (ICPCutOff != 0 && NumOfPGOICallPromotion >= ICPCutOff) { DEBUG(dbgs() << " Not promote: Cutoff reached.\n"); + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "CutOffReached", Inst) + << " Not promote: Cutoff reached"; + }); + break; + } + + Function *TargetFunction = Symtab->getFunction(Target); + if (TargetFunction == nullptr) { + DEBUG(dbgs() << " Not promote: Cannot find the target\n"); + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "UnableToFindTarget", Inst) + << "Cannot promote indirect call: target not found"; + }); break; } - Function *TargetFunction = nullptr; + const char *Reason = nullptr; - if (!isPromotionLegal(Inst, Target, TargetFunction, &Reason)) { - StringRef TargetFuncName = Symtab->getFuncName(Target); - DEBUG(dbgs() << " Not promote: " << Reason << "\n"); - emitOptimizationRemarkMissed( - F.getContext(), "pgo-icall-prom", F, Inst->getDebugLoc(), - Twine("Cannot promote indirect call to ") + - (TargetFuncName.empty() ? 
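For context, the rewrite itself now comes from promoteCallWithIfThenElse (declared in the newly included CallPromotionUtils.h and called further down), which builds the same guarded diamond that the hand-rolled helpers deleted around this hunk used to construct. In C-like pseudocode, with the block names the old helper assigned:

    // Before:
    //   Ret = (*Foo)(Args);
    // After promotion against a profiled hot target DirectCallee:
    //   if (Foo == DirectCallee)        // weighted Count : TotalCount - Count
    //     Ret1 = DirectCallee(Args);    //   if.true.direct_targ
    //   else
    //     Ret2 = (*Foo)(Args);          //   if.false.orig_indirect
    //   Ret = phi(Ret1, Ret2);          //   if.end.icp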
Twine(Target) : Twine(TargetFuncName)) + - Twine(" with count of ") + Twine(Count) + ": " + Reason); + if (!isLegalToPromote(CallSite(Inst), TargetFunction, &Reason)) { + using namespace ore; + + ORE.emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "UnableToPromote", Inst) + << "Cannot promote indirect call to " + << NV("TargetFunction", TargetFunction) << " with count of " + << NV("Count", Count) << ": " << Reason; + }); break; } + Ret.push_back(PromotionCandidate(TargetFunction, Count)); TotalCount -= Count; } return Ret; } -// Create a diamond structure for If_Then_Else. Also update the profile -// count. Do the fix-up for the invoke instruction. -static void createIfThenElse(Instruction *Inst, Function *DirectCallee, - uint64_t Count, uint64_t TotalCount, - BasicBlock **DirectCallBB, - BasicBlock **IndirectCallBB, - BasicBlock **MergeBB) { - CallSite CS(Inst); - Value *OrigCallee = CS.getCalledValue(); - - IRBuilder<> BBBuilder(Inst); - LLVMContext &Ctx = Inst->getContext(); - Value *BCI1 = - BBBuilder.CreateBitCast(OrigCallee, Type::getInt8PtrTy(Ctx), ""); - Value *BCI2 = - BBBuilder.CreateBitCast(DirectCallee, Type::getInt8PtrTy(Ctx), ""); - Value *PtrCmp = BBBuilder.CreateICmpEQ(BCI1, BCI2, ""); +Instruction *llvm::pgo::promoteIndirectCall(Instruction *Inst, + Function *DirectCallee, + uint64_t Count, uint64_t TotalCount, + bool AttachProfToDirectCall, + OptimizationRemarkEmitter *ORE) { uint64_t ElseCount = TotalCount - Count; uint64_t MaxCount = (Count >= ElseCount ? Count : ElseCount); @@ -331,261 +305,26 @@ static void createIfThenElse(Instruction *Inst, Function *DirectCallee, MDBuilder MDB(Inst->getContext()); MDNode *BranchWeights = MDB.createBranchWeights( scaleBranchCount(Count, Scale), scaleBranchCount(ElseCount, Scale)); - TerminatorInst *ThenTerm, *ElseTerm; - SplitBlockAndInsertIfThenElse(PtrCmp, Inst, &ThenTerm, &ElseTerm, - BranchWeights); - *DirectCallBB = ThenTerm->getParent(); - (*DirectCallBB)->setName("if.true.direct_targ"); - *IndirectCallBB = ElseTerm->getParent(); - (*IndirectCallBB)->setName("if.false.orig_indirect"); - *MergeBB = Inst->getParent(); - (*MergeBB)->setName("if.end.icp"); - - // Special handing of Invoke instructions. - InvokeInst *II = dyn_cast<InvokeInst>(Inst); - if (!II) - return; - - // We don't need branch instructions for invoke. - ThenTerm->eraseFromParent(); - ElseTerm->eraseFromParent(); - - // Add jump from Merge BB to the NormalDest. This is needed for the newly - // created direct invoke stmt -- as its NormalDst will be fixed up to MergeBB. - BranchInst::Create(II->getNormalDest(), *MergeBB); -} - -// Find the PHI in BB that have the CallResult as the operand. -static bool getCallRetPHINode(BasicBlock *BB, Instruction *Inst) { - BasicBlock *From = Inst->getParent(); - for (auto &I : *BB) { - PHINode *PHI = dyn_cast<PHINode>(&I); - if (!PHI) - continue; - int IX = PHI->getBasicBlockIndex(From); - if (IX == -1) - continue; - Value *V = PHI->getIncomingValue(IX); - if (dyn_cast<Instruction>(V) == Inst) - return true; - } - return false; -} - -// This method fixes up PHI nodes in BB where BB is the UnwindDest of an -// invoke instruction. In BB, there may be PHIs with incoming block being -// OrigBB (the MergeBB after if-then-else splitting). After moving the invoke -// instructions to its own BB, OrigBB is no longer the predecessor block of BB. -// Instead two new predecessors are added: IndirectCallBB and DirectCallBB, -// so the PHI node's incoming BBs need to be fixed up accordingly. 
-static void fixupPHINodeForUnwind(Instruction *Inst, BasicBlock *BB, - BasicBlock *OrigBB, - BasicBlock *IndirectCallBB, - BasicBlock *DirectCallBB) { - for (auto &I : *BB) { - PHINode *PHI = dyn_cast<PHINode>(&I); - if (!PHI) - continue; - int IX = PHI->getBasicBlockIndex(OrigBB); - if (IX == -1) - continue; - Value *V = PHI->getIncomingValue(IX); - PHI->addIncoming(V, IndirectCallBB); - PHI->setIncomingBlock(IX, DirectCallBB); - } -} - -// This method fixes up PHI nodes in BB where BB is the NormalDest of an -// invoke instruction. In BB, there may be PHIs with incoming block being -// OrigBB (the MergeBB after if-then-else splitting). After moving the invoke -// instructions to its own BB, a new incoming edge will be added to the original -// NormalDstBB from the IndirectCallBB. -static void fixupPHINodeForNormalDest(Instruction *Inst, BasicBlock *BB, - BasicBlock *OrigBB, - BasicBlock *IndirectCallBB, - Instruction *NewInst) { - for (auto &I : *BB) { - PHINode *PHI = dyn_cast<PHINode>(&I); - if (!PHI) - continue; - int IX = PHI->getBasicBlockIndex(OrigBB); - if (IX == -1) - continue; - Value *V = PHI->getIncomingValue(IX); - if (dyn_cast<Instruction>(V) == Inst) { - PHI->setIncomingBlock(IX, IndirectCallBB); - PHI->addIncoming(NewInst, OrigBB); - continue; - } - PHI->addIncoming(V, IndirectCallBB); - } -} - -// Add a bitcast instruction to the direct-call return value if needed. -static Instruction *insertCallRetCast(const Instruction *Inst, - Instruction *DirectCallInst, - Function *DirectCallee) { - if (Inst->getType()->isVoidTy()) - return DirectCallInst; - - Type *CallRetType = Inst->getType(); - Type *FuncRetType = DirectCallee->getReturnType(); - if (FuncRetType == CallRetType) - return DirectCallInst; - - BasicBlock *InsertionBB; - if (CallInst *CI = dyn_cast<CallInst>(DirectCallInst)) - InsertionBB = CI->getParent(); - else - InsertionBB = (dyn_cast<InvokeInst>(DirectCallInst))->getNormalDest(); - - return (new BitCastInst(DirectCallInst, CallRetType, "", - InsertionBB->getTerminator())); -} - -// Create a DirectCall instruction in the DirectCallBB. -// Parameter Inst is the indirect-call (invoke) instruction. -// DirectCallee is the decl of the direct-call (invoke) target. -// DirecallBB is the BB that the direct-call (invoke) instruction is inserted. -// MergeBB is the bottom BB of the if-then-else-diamond after the -// transformation. For invoke instruction, the edges from DirectCallBB and -// IndirectCallBB to MergeBB are removed before this call (during -// createIfThenElse). -static Instruction *createDirectCallInst(const Instruction *Inst, - Function *DirectCallee, - BasicBlock *DirectCallBB, - BasicBlock *MergeBB) { - Instruction *NewInst = Inst->clone(); - if (CallInst *CI = dyn_cast<CallInst>(NewInst)) { - CI->setCalledFunction(DirectCallee); - CI->mutateFunctionType(DirectCallee->getFunctionType()); - } else { - // Must be an invoke instruction. Direct invoke's normal destination is - // fixed up to MergeBB. MergeBB is the place where return cast is inserted. - // Also since IndirectCallBB does not have an edge to MergeBB, there is no - // need to insert new PHIs into MergeBB. - InvokeInst *II = dyn_cast<InvokeInst>(NewInst); - assert(II); - II->setCalledFunction(DirectCallee); - II->mutateFunctionType(DirectCallee->getFunctionType()); - II->setNormalDest(MergeBB); - } - - DirectCallBB->getInstList().insert(DirectCallBB->getFirstInsertionPt(), - NewInst); - - // Clear the value profile data. 
- NewInst->setMetadata(LLVMContext::MD_prof, nullptr); - CallSite NewCS(NewInst); - FunctionType *DirectCalleeType = DirectCallee->getFunctionType(); - unsigned ParamNum = DirectCalleeType->getFunctionNumParams(); - for (unsigned I = 0; I < ParamNum; ++I) { - Type *ATy = NewCS.getArgument(I)->getType(); - Type *PTy = DirectCalleeType->getParamType(I); - if (ATy != PTy) { - BitCastInst *BI = new BitCastInst(NewCS.getArgument(I), PTy, "", NewInst); - NewCS.setArgument(I, BI); - } - } - - return insertCallRetCast(Inst, NewInst, DirectCallee); -} - -// Create a PHI to unify the return values of calls. -static void insertCallRetPHI(Instruction *Inst, Instruction *CallResult, - Function *DirectCallee) { - if (Inst->getType()->isVoidTy()) - return; - - BasicBlock *RetValBB = CallResult->getParent(); - - BasicBlock *PHIBB; - if (InvokeInst *II = dyn_cast<InvokeInst>(CallResult)) - RetValBB = II->getNormalDest(); - - PHIBB = RetValBB->getSingleSuccessor(); - if (getCallRetPHINode(PHIBB, Inst)) - return; - - PHINode *CallRetPHI = PHINode::Create(Inst->getType(), 0); - PHIBB->getInstList().push_front(CallRetPHI); - Inst->replaceAllUsesWith(CallRetPHI); - CallRetPHI->addIncoming(Inst, Inst->getParent()); - CallRetPHI->addIncoming(CallResult, RetValBB); -} - -// This function does the actual indirect-call promotion transformation: -// For an indirect-call like: -// Ret = (*Foo)(Args); -// It transforms to: -// if (Foo == DirectCallee) -// Ret1 = DirectCallee(Args); -// else -// Ret2 = (*Foo)(Args); -// Ret = phi(Ret1, Ret2); -// It adds type casts for the args do not match the parameters and the return -// value. Branch weights metadata also updated. -// If \p AttachProfToDirectCall is true, a prof metadata is attached to the -// new direct call to contain \p Count. This is used by SamplePGO inliner to -// check callsite hotness. -// Returns the promoted direct call instruction. -Instruction *llvm::promoteIndirectCall(Instruction *Inst, - Function *DirectCallee, uint64_t Count, - uint64_t TotalCount, - bool AttachProfToDirectCall) { - assert(DirectCallee != nullptr); - BasicBlock *BB = Inst->getParent(); - // Just to suppress the non-debug build warning. - (void)BB; - DEBUG(dbgs() << "\n\n== Basic Block Before ==\n"); - DEBUG(dbgs() << *BB << "\n"); - - BasicBlock *DirectCallBB, *IndirectCallBB, *MergeBB; - createIfThenElse(Inst, DirectCallee, Count, TotalCount, &DirectCallBB, - &IndirectCallBB, &MergeBB); Instruction *NewInst = - createDirectCallInst(Inst, DirectCallee, DirectCallBB, MergeBB); + promoteCallWithIfThenElse(CallSite(Inst), DirectCallee, BranchWeights); if (AttachProfToDirectCall) { SmallVector<uint32_t, 1> Weights; Weights.push_back(Count); MDBuilder MDB(NewInst->getContext()); - dyn_cast<Instruction>(NewInst->stripPointerCasts()) - ->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); - } - - // Move Inst from MergeBB to IndirectCallBB. - Inst->removeFromParent(); - IndirectCallBB->getInstList().insert(IndirectCallBB->getFirstInsertionPt(), - Inst); - - if (InvokeInst *II = dyn_cast<InvokeInst>(Inst)) { - // At this point, the original indirect invoke instruction has the original - // UnwindDest and NormalDest. For the direct invoke instruction, the - // NormalDest points to MergeBB, and MergeBB jumps to the original - // NormalDest. MergeBB might have a new bitcast instruction for the return - // value. The PHIs are with the original NormalDest. Since we now have two - // incoming edges to NormalDest and UnwindDest, we have to do some fixups. 
- // - // UnwindDest will not use the return value. So pass nullptr here. - fixupPHINodeForUnwind(Inst, II->getUnwindDest(), MergeBB, IndirectCallBB, - DirectCallBB); - // We don't need to update the operand from NormalDest for DirectCallBB. - // Pass nullptr here. - fixupPHINodeForNormalDest(Inst, II->getNormalDest(), MergeBB, - IndirectCallBB, NewInst); + NewInst->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); } - insertCallRetPHI(Inst, NewInst, DirectCallee); + using namespace ore; - DEBUG(dbgs() << "\n== Basic Blocks After ==\n"); - DEBUG(dbgs() << *BB << *DirectCallBB << *IndirectCallBB << *MergeBB << "\n"); - - emitOptimizationRemark( - BB->getContext(), "pgo-icall-prom", *BB->getParent(), Inst->getDebugLoc(), - Twine("Promote indirect call to ") + DirectCallee->getName() + - " with count " + Twine(Count) + " out of " + Twine(TotalCount)); + if (ORE) + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "Promoted", Inst) + << "Promote indirect call to " << NV("DirectCallee", DirectCallee) + << " with count " << NV("Count", Count) << " out of " + << NV("TotalCount", TotalCount); + }); return NewInst; } @@ -597,7 +336,8 @@ uint32_t ICallPromotionFunc::tryToPromote( for (auto &C : Candidates) { uint64_t Count = C.Count; - promoteIndirectCall(Inst, C.TargetFunction, Count, TotalCount, SamplePGO); + pgo::promoteIndirectCall(Inst, C.TargetFunction, Count, TotalCount, + SamplePGO, &ORE); assert(TotalCount >= Count); TotalCount -= Count; NumOfPGOICallPromotion++; @@ -608,7 +348,7 @@ uint32_t ICallPromotionFunc::tryToPromote( // Traverse all the indirect-call callsite and get the value profile // annotation to perform indirect-call promotion. -bool ICallPromotionFunc::processFunction() { +bool ICallPromotionFunc::processFunction(ProfileSummaryInfo *PSI) { bool Changed = false; ICallPromotionAnalysis ICallAnalysis; for (auto &I : findIndirectCallSites(F)) { @@ -616,7 +356,8 @@ bool ICallPromotionFunc::processFunction() { uint64_t TotalCount; auto ICallProfDataRef = ICallAnalysis.getPromotionCandidatesForInstruction( I, NumVals, TotalCount, NumCandidates); - if (!NumCandidates) + if (!NumCandidates || + (PSI && PSI->hasProfileSummary() && !PSI->isHotCount(TotalCount))) continue; auto PromotionCandidates = getPromotionCandidatesForCallSite( I, ICallProfDataRef, TotalCount, NumCandidates); @@ -638,7 +379,9 @@ bool ICallPromotionFunc::processFunction() { } // A wrapper function that does the actual work. 
-static bool promoteIndirectCalls(Module &M, bool InLTO, bool SamplePGO) { +static bool promoteIndirectCalls(Module &M, ProfileSummaryInfo *PSI, + bool InLTO, bool SamplePGO, + ModuleAnalysisManager *AM = nullptr) { if (DisableICP) return false; InstrProfSymtab Symtab; @@ -654,8 +397,20 @@ static bool promoteIndirectCalls(Module &M, bool InLTO, bool SamplePGO) { continue; if (F.hasFnAttribute(Attribute::OptimizeNone)) continue; - ICallPromotionFunc ICallPromotion(F, &M, &Symtab, SamplePGO); - bool FuncChanged = ICallPromotion.processFunction(); + + std::unique_ptr<OptimizationRemarkEmitter> OwnedORE; + OptimizationRemarkEmitter *ORE; + if (AM) { + auto &FAM = + AM->getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + ORE = &FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); + } else { + OwnedORE = llvm::make_unique<OptimizationRemarkEmitter>(&F); + ORE = OwnedORE.get(); + } + + ICallPromotionFunc ICallPromotion(F, &M, &Symtab, SamplePGO, *ORE); + bool FuncChanged = ICallPromotion.processFunction(PSI); if (ICPDUMPAFTER && FuncChanged) { DEBUG(dbgs() << "\n== IR Dump After =="; F.print(dbgs())); DEBUG(dbgs() << "\n"); @@ -670,15 +425,20 @@ static bool promoteIndirectCalls(Module &M, bool InLTO, bool SamplePGO) { } bool PGOIndirectCallPromotionLegacyPass::runOnModule(Module &M) { + ProfileSummaryInfo *PSI = + getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); + // Command-line option has the priority for InLTO. - return promoteIndirectCalls(M, InLTO | ICPLTOMode, + return promoteIndirectCalls(M, PSI, InLTO | ICPLTOMode, SamplePGO | ICPSamplePGOMode); } PreservedAnalyses PGOIndirectCallPromotion::run(Module &M, ModuleAnalysisManager &AM) { - if (!promoteIndirectCalls(M, InLTO | ICPLTOMode, - SamplePGO | ICPSamplePGOMode)) + ProfileSummaryInfo *PSI = &AM.getResult<ProfileSummaryAnalysis>(M); + + if (!promoteIndirectCalls(M, PSI, InLTO | ICPLTOMode, + SamplePGO | ICPSamplePGOMode, &AM)) return PreservedAnalyses::all(); return PreservedAnalyses::none(); diff --git a/lib/Transforms/Instrumentation/InstrProfiling.cpp b/lib/Transforms/Instrumentation/InstrProfiling.cpp index db8fa8977947..9b70f95480e4 100644 --- a/lib/Transforms/Instrumentation/InstrProfiling.cpp +++ b/lib/Transforms/Instrumentation/InstrProfiling.cpp @@ -43,7 +43,6 @@ #include "llvm/Support/Error.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/ModuleUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include <algorithm> @@ -245,6 +244,9 @@ public: } bool run(int64_t *NumPromoted) { + // Skip 'infinite' loops: + if (ExitBlocks.size() == 0) + return false; unsigned MaxProm = getMaxNumOfPromotionsInLoop(&L); if (MaxProm == 0) return false; diff --git a/lib/Transforms/Instrumentation/Instrumentation.cpp b/lib/Transforms/Instrumentation/Instrumentation.cpp index 7bb62d2c8455..8e9eea96ced7 100644 --- a/lib/Transforms/Instrumentation/Instrumentation.cpp +++ b/lib/Transforms/Instrumentation/Instrumentation.cpp @@ -58,7 +58,7 @@ BasicBlock::iterator llvm::PrepareToSplitEntryBlock(BasicBlock &BB, void llvm::initializeInstrumentation(PassRegistry &Registry) { initializeAddressSanitizerPass(Registry); initializeAddressSanitizerModulePass(Registry); - initializeBoundsCheckingPass(Registry); + initializeBoundsCheckingLegacyPassPass(Registry); initializeGCOVProfilerLegacyPassPass(Registry); initializePGOInstrumentationGenLegacyPassPass(Registry); 
initializePGOInstrumentationUseLegacyPassPass(Registry); @@ -66,6 +66,7 @@ void llvm::initializeInstrumentation(PassRegistry &Registry) { initializePGOMemOPSizeOptLegacyPassPass(Registry); initializeInstrProfilingLegacyPassPass(Registry); initializeMemorySanitizerPass(Registry); + initializeHWAddressSanitizerPass(Registry); initializeThreadSanitizerPass(Registry); initializeSanitizerCoverageModulePass(Registry); initializeDataFlowSanitizerPass(Registry); diff --git a/lib/Transforms/Instrumentation/MemorySanitizer.cpp b/lib/Transforms/Instrumentation/MemorySanitizer.cpp index b7c6271869cd..b3c39b5b1665 100644 --- a/lib/Transforms/Instrumentation/MemorySanitizer.cpp +++ b/lib/Transforms/Instrumentation/MemorySanitizer.cpp @@ -1,4 +1,4 @@ -//===-- MemorySanitizer.cpp - detector of uninitialized reads -------------===// +//===- MemorySanitizer.cpp - detector of uninitialized reads --------------===// // // The LLVM Compiler Infrastructure // @@ -6,6 +6,7 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// /// \file /// This file is a part of MemorySanitizer, a detector of uninitialized /// reads. @@ -88,32 +89,64 @@ /// implementation ignores the load aspect of CAS/RMW, always returning a clean /// value. It implements the store part as a simple atomic store by storing a /// clean shadow. - +// //===----------------------------------------------------------------------===// +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/CallingConv.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" #include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" #include "llvm/IR/ValueMap.h" +#include "llvm/Pass.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ModuleUtils.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <memory> +#include <string> +#include <tuple> using namespace llvm; @@ -137,18 +170,23 @@ static const size_t kNumberOfAccessSizes = 4; static cl::opt<int> ClTrackOrigins("msan-track-origins", cl::desc("Track origins (allocation sites) of poisoned memory"), cl::Hidden, cl::init(0)); + static cl::opt<bool> 
ClKeepGoing("msan-keep-going", cl::desc("keep going after reporting a UMR"), cl::Hidden, cl::init(false)); + static cl::opt<bool> ClPoisonStack("msan-poison-stack", cl::desc("poison uninitialized stack variables"), cl::Hidden, cl::init(true)); + static cl::opt<bool> ClPoisonStackWithCall("msan-poison-stack-with-call", cl::desc("poison uninitialized stack variables with a call"), cl::Hidden, cl::init(false)); + static cl::opt<int> ClPoisonStackPattern("msan-poison-stack-pattern", cl::desc("poison uninitialized stack variables with the given pattern"), cl::Hidden, cl::init(0xff)); + static cl::opt<bool> ClPoisonUndef("msan-poison-undef", cl::desc("poison undef temps"), cl::Hidden, cl::init(true)); @@ -217,6 +255,8 @@ struct PlatformMemoryMapParams { const MemoryMapParams *bits64; }; +} // end anonymous namespace + // i386 Linux static const MemoryMapParams Linux_I386_MemoryMapParams = { 0x000080000000, // AndMask @@ -250,7 +290,7 @@ static const MemoryMapParams Linux_MIPS64_MemoryMapParams = { // ppc64 Linux static const MemoryMapParams Linux_PowerPC64_MemoryMapParams = { - 0x200000000000, // AndMask + 0xE00000000000, // AndMask 0x100000000000, // XorMask 0x080000000000, // ShadowBase 0x1C0000000000, // OriginBase @@ -280,6 +320,14 @@ static const MemoryMapParams FreeBSD_X86_64_MemoryMapParams = { 0x380000000000, // OriginBase }; +// x86_64 NetBSD +static const MemoryMapParams NetBSD_X86_64_MemoryMapParams = { + 0, // AndMask + 0x500000000000, // XorMask + 0, // ShadowBase + 0x100000000000, // OriginBase +}; + static const PlatformMemoryMapParams Linux_X86_MemoryMapParams = { &Linux_I386_MemoryMapParams, &Linux_X86_64_MemoryMapParams, @@ -305,27 +353,44 @@ static const PlatformMemoryMapParams FreeBSD_X86_MemoryMapParams = { &FreeBSD_X86_64_MemoryMapParams, }; +static const PlatformMemoryMapParams NetBSD_X86_MemoryMapParams = { + nullptr, + &NetBSD_X86_64_MemoryMapParams, +}; + +namespace { + /// \brief An instrumentation pass implementing detection of uninitialized /// reads. /// /// MemorySanitizer: instrument the code in module to find /// uninitialized reads. class MemorySanitizer : public FunctionPass { - public: +public: + // Pass identification, replacement for typeid. + static char ID; + MemorySanitizer(int TrackOrigins = 0, bool Recover = false) : FunctionPass(ID), TrackOrigins(std::max(TrackOrigins, (int)ClTrackOrigins)), - Recover(Recover || ClKeepGoing), - WarningFn(nullptr) {} + Recover(Recover || ClKeepGoing) {} + StringRef getPassName() const override { return "MemorySanitizer"; } + void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<TargetLibraryInfoWrapperPass>(); } + bool runOnFunction(Function &F) override; bool doInitialization(Module &M) override; - static char ID; // Pass identification, replacement for typeid. - private: +private: + friend struct MemorySanitizerVisitor; + friend struct VarArgAMD64Helper; + friend struct VarArgMIPS64Helper; + friend struct VarArgAArch64Helper; + friend struct VarArgPowerPC64Helper; + void initializeCallbacks(Module &M); /// \brief Track origins (allocation points) of uninitialized values. @@ -335,26 +400,34 @@ class MemorySanitizer : public FunctionPass { LLVMContext *C; Type *IntptrTy; Type *OriginTy; + /// \brief Thread-local shadow storage for function parameters. GlobalVariable *ParamTLS; + /// \brief Thread-local origin storage for function parameters. GlobalVariable *ParamOriginTLS; + /// \brief Thread-local shadow storage for function return value. 
GlobalVariable *RetvalTLS; + /// \brief Thread-local origin storage for function return value. GlobalVariable *RetvalOriginTLS; + /// \brief Thread-local shadow storage for in-register va_arg function /// parameters (x86_64-specific). GlobalVariable *VAArgTLS; + /// \brief Thread-local shadow storage for va_arg overflow area /// (x86_64-specific). GlobalVariable *VAArgOverflowSizeTLS; + /// \brief Thread-local space used to pass origin value to the UMR reporting /// function. GlobalVariable *OriginTLS; /// \brief The run-time callback to print a warning. - Value *WarningFn; + Value *WarningFn = nullptr; + // These arrays are indexed by log2(AccessSize). Value *MaybeWarningFn[kNumberOfAccessSizes]; Value *MaybeStoreOriginFn[kNumberOfAccessSizes]; @@ -362,11 +435,14 @@ class MemorySanitizer : public FunctionPass { /// \brief Run-time helper that generates a new origin value for a stack /// allocation. Value *MsanSetAllocaOrigin4Fn; + /// \brief Run-time helper that poisons stack on function entry. Value *MsanPoisonStackFn; + /// \brief Run-time helper that records a store (or any event) of an /// uninitialized value and returns an updated origin id encoding this info. Value *MsanChainOriginFn; + /// \brief MSan runtime replacements for memmove, memcpy and memset. Value *MemmoveFn, *MemcpyFn, *MemsetFn; @@ -374,21 +450,20 @@ class MemorySanitizer : public FunctionPass { const MemoryMapParams *MapParams; MDNode *ColdCallWeights; + /// \brief Branch weights for origin store. MDNode *OriginStoreWeights; + /// \brief An empty volatile inline asm that prevents callback merge. InlineAsm *EmptyAsm; - Function *MsanCtorFunction; - friend struct MemorySanitizerVisitor; - friend struct VarArgAMD64Helper; - friend struct VarArgMIPS64Helper; - friend struct VarArgAArch64Helper; - friend struct VarArgPowerPC64Helper; + Function *MsanCtorFunction; }; -} // anonymous namespace + +} // end anonymous namespace char MemorySanitizer::ID = 0; + INITIALIZE_PASS_BEGIN( MemorySanitizer, "msan", "MemorySanitizer: detects uninitialized reads.", false, false) @@ -515,6 +590,15 @@ bool MemorySanitizer::doInitialization(Module &M) { report_fatal_error("unsupported architecture"); } break; + case Triple::NetBSD: + switch (TargetTriple.getArch()) { + case Triple::x86_64: + MapParams = NetBSD_X86_MemoryMapParams.bits64; + break; + default: + report_fatal_error("unsupported architecture"); + } + break; case Triple::Linux: switch (TargetTriple.getArch()) { case Triple::x86_64: @@ -586,6 +670,8 @@ namespace { /// the function, and should avoid creating new basic blocks. A new /// instance of this class is created for each instrumented function. struct VarArgHelper { + virtual ~VarArgHelper() = default; + /// \brief Visit a CallSite. virtual void visitCallSite(CallSite &CS, IRBuilder<> &IRB) = 0; @@ -600,21 +686,22 @@ struct VarArgHelper { /// This method is called after visiting all interesting (see above) /// instructions in a function. 
virtual void finalizeInstrumentation() = 0; - - virtual ~VarArgHelper() {} }; struct MemorySanitizerVisitor; -VarArgHelper* -CreateVarArgHelper(Function &Func, MemorySanitizer &Msan, - MemorySanitizerVisitor &Visitor); +} // end anonymous namespace -unsigned TypeSizeToSizeIndex(unsigned TypeSize) { +static VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan, + MemorySanitizerVisitor &Visitor); + +static unsigned TypeSizeToSizeIndex(unsigned TypeSize) { if (TypeSize <= 8) return 0; return Log2_32_Ceil((TypeSize + 7) / 8); } +namespace { + /// This class does all the work for a given function. Store and Load /// instructions store and load corresponding shadow and origin /// values. Most instructions propagate shadow from arguments to their @@ -641,8 +728,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *Shadow; Value *Origin; Instruction *OrigIns; + ShadowOriginAndInsertPoint(Value *S, Value *O, Instruction *I) - : Shadow(S), Origin(O), OrigIns(I) { } + : Shadow(S), Origin(O), OrigIns(I) {} }; SmallVector<ShadowOriginAndInsertPoint, 16> InstrumentationList; SmallVector<StoreInst *, 16> StoreList; @@ -711,21 +799,19 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } void storeOrigin(IRBuilder<> &IRB, Value *Addr, Value *Shadow, Value *Origin, - unsigned Alignment, bool AsCall) { + Value *OriginPtr, unsigned Alignment, bool AsCall) { const DataLayout &DL = F.getParent()->getDataLayout(); unsigned OriginAlignment = std::max(kMinOriginAlignment, Alignment); unsigned StoreSize = DL.getTypeStoreSize(Shadow->getType()); if (Shadow->getType()->isAggregateType()) { - paintOrigin(IRB, updateOrigin(Origin, IRB), - getOriginPtr(Addr, IRB, Alignment), StoreSize, + paintOrigin(IRB, updateOrigin(Origin, IRB), OriginPtr, StoreSize, OriginAlignment); } else { Value *ConvertedShadow = convertToShadowTyNoVec(Shadow, IRB); Constant *ConstantShadow = dyn_cast_or_null<Constant>(ConvertedShadow); if (ConstantShadow) { if (ClCheckConstantShadow && !ConstantShadow->isZeroValue()) - paintOrigin(IRB, updateOrigin(Origin, IRB), - getOriginPtr(Addr, IRB, Alignment), StoreSize, + paintOrigin(IRB, updateOrigin(Origin, IRB), OriginPtr, StoreSize, OriginAlignment); return; } @@ -746,8 +832,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Instruction *CheckTerm = SplitBlockAndInsertIfThen( Cmp, &*IRB.GetInsertPoint(), false, MS.OriginStoreWeights); IRBuilder<> IRBNew(CheckTerm); - paintOrigin(IRBNew, updateOrigin(Origin, IRBNew), - getOriginPtr(Addr, IRBNew, Alignment), StoreSize, + paintOrigin(IRBNew, updateOrigin(Origin, IRBNew), OriginPtr, StoreSize, OriginAlignment); } } @@ -759,22 +844,25 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *Val = SI->getValueOperand(); Value *Addr = SI->getPointerOperand(); Value *Shadow = SI->isAtomic() ? 
getCleanShadow(Val) : getShadow(Val); - Value *ShadowPtr = getShadowPtr(Addr, Shadow->getType(), IRB); - - StoreInst *NewSI = - IRB.CreateAlignedStore(Shadow, ShadowPtr, SI->getAlignment()); + Value *ShadowPtr, *OriginPtr; + Type *ShadowTy = Shadow->getType(); + unsigned Alignment = SI->getAlignment(); + unsigned OriginAlignment = std::max(kMinOriginAlignment, Alignment); + std::tie(ShadowPtr, OriginPtr) = + getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment); + + StoreInst *NewSI = IRB.CreateAlignedStore(Shadow, ShadowPtr, Alignment); DEBUG(dbgs() << " STORE: " << *NewSI << "\n"); - (void)NewSI; if (ClCheckAccessAddress) - insertShadowCheck(Addr, SI); + insertShadowCheck(Addr, NewSI); if (SI->isAtomic()) SI->setOrdering(addReleaseOrdering(SI->getOrdering())); if (MS.TrackOrigins && !SI->isAtomic()) - storeOrigin(IRB, Addr, Shadow, getOrigin(Val), SI->getAlignment(), - InstrumentWithCalls); + storeOrigin(IRB, Addr, Shadow, getOrigin(Val), OriginPtr, + OriginAlignment, InstrumentWithCalls); } } @@ -856,7 +944,6 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { for (BasicBlock *BB : depth_first(&F.getEntryBlock())) visit(*BB); - // Finalize PHI nodes. for (PHINode *PN : ShadowPHINodes) { PHINode *PNS = cast<PHINode>(getShadow(PN)); @@ -954,39 +1041,50 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { return OffsetLong; } - /// \brief Compute the shadow address that corresponds to a given application - /// address. + /// \brief Compute the shadow and origin addresses corresponding to a given + /// application address. /// /// Shadow = ShadowBase + Offset - Value *getShadowPtr(Value *Addr, Type *ShadowTy, - IRBuilder<> &IRB) { - Value *ShadowLong = getShadowPtrOffset(Addr, IRB); + /// Origin = (OriginBase + Offset) & ~3ULL + std::pair<Value *, Value *> getShadowOriginPtrUserspace( + Value *Addr, IRBuilder<> &IRB, Type *ShadowTy, unsigned Alignment, + Instruction **FirstInsn) { + Value *ShadowOffset = getShadowPtrOffset(Addr, IRB); + Value *ShadowLong = ShadowOffset; uint64_t ShadowBase = MS.MapParams->ShadowBase; - if (ShadowBase != 0) + *FirstInsn = dyn_cast<Instruction>(ShadowLong); + if (ShadowBase != 0) { ShadowLong = IRB.CreateAdd(ShadowLong, ConstantInt::get(MS.IntptrTy, ShadowBase)); - return IRB.CreateIntToPtr(ShadowLong, PointerType::get(ShadowTy, 0)); + } + Value *ShadowPtr = + IRB.CreateIntToPtr(ShadowLong, PointerType::get(ShadowTy, 0)); + Value *OriginPtr = nullptr; + if (MS.TrackOrigins) { + Value *OriginLong = ShadowOffset; + uint64_t OriginBase = MS.MapParams->OriginBase; + if (OriginBase != 0) + OriginLong = IRB.CreateAdd(OriginLong, + ConstantInt::get(MS.IntptrTy, OriginBase)); + if (Alignment < kMinOriginAlignment) { + uint64_t Mask = kMinOriginAlignment - 1; + OriginLong = + IRB.CreateAnd(OriginLong, ConstantInt::get(MS.IntptrTy, ~Mask)); + } + OriginPtr = + IRB.CreateIntToPtr(OriginLong, PointerType::get(IRB.getInt32Ty(), 0)); + } + return std::make_pair(ShadowPtr, OriginPtr); } - /// \brief Compute the origin address that corresponds to a given application - /// address. 
- /// - /// OriginAddr = (OriginBase + Offset) & ~3ULL - Value *getOriginPtr(Value *Addr, IRBuilder<> &IRB, unsigned Alignment) { - Value *OriginLong = getShadowPtrOffset(Addr, IRB); - uint64_t OriginBase = MS.MapParams->OriginBase; - if (OriginBase != 0) - OriginLong = - IRB.CreateAdd(OriginLong, - ConstantInt::get(MS.IntptrTy, OriginBase)); - if (Alignment < kMinOriginAlignment) { - uint64_t Mask = kMinOriginAlignment - 1; - OriginLong = IRB.CreateAnd(OriginLong, - ConstantInt::get(MS.IntptrTy, ~Mask)); - } - return IRB.CreateIntToPtr(OriginLong, - PointerType::get(IRB.getInt32Ty(), 0)); + std::pair<Value *, Value *> getShadowOriginPtr(Value *Addr, IRBuilder<> &IRB, + Type *ShadowTy, + unsigned Alignment) { + Instruction *FirstInsn = nullptr; + std::pair<Value *, Value *> ret = + getShadowOriginPtrUserspace(Addr, IRB, ShadowTy, Alignment, &FirstInsn); + return ret; } /// \brief Compute the shadow address for a given function argument. @@ -1012,9 +1110,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// \brief Compute the shadow address for a retval. Value *getShadowPtrForRetval(Value *A, IRBuilder<> &IRB) { - Value *Base = IRB.CreatePointerCast(MS.RetvalTLS, MS.IntptrTy); - return IRB.CreateIntToPtr(Base, PointerType::get(getShadowTy(A), 0), - "_msret"); + return IRB.CreatePointerCast(MS.RetvalTLS, + PointerType::get(getShadowTy(A), 0), + "_msret"); } /// \brief Compute the origin address for a retval. @@ -1091,6 +1189,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *getShadow(Value *V) { if (!PropagateShadow) return getCleanShadow(V); if (Instruction *I = dyn_cast<Instruction>(V)) { + if (I->getMetadata("nosanitize")) + return getCleanShadow(V); // For instructions the shadow is already stored in the map. Value *Shadow = ShadowMap[V]; if (!Shadow) { @@ -1136,16 +1236,18 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Type *EltType = A->getType()->getPointerElementType(); ArgAlign = DL.getABITypeAlignment(EltType); } + Value *CpShadowPtr = + getShadowOriginPtr(V, EntryIRB, EntryIRB.getInt8Ty(), ArgAlign) + .first; if (Overflow) { // ParamTLS overflow. EntryIRB.CreateMemSet( - getShadowPtr(V, EntryIRB.getInt8Ty(), EntryIRB), - Constant::getNullValue(EntryIRB.getInt8Ty()), Size, ArgAlign); + CpShadowPtr, Constant::getNullValue(EntryIRB.getInt8Ty()), + Size, ArgAlign); } else { unsigned CopyAlign = std::min(ArgAlign, kShadowTLSAlignment); - Value *Cpy = EntryIRB.CreateMemCpy( - getShadowPtr(V, EntryIRB.getInt8Ty(), EntryIRB), Base, Size, - CopyAlign); + Value *Cpy = + EntryIRB.CreateMemCpy(CpShadowPtr, Base, Size, CopyAlign); DEBUG(dbgs() << " ByValCpy: " << *Cpy << "\n"); (void)Cpy; } @@ -1190,6 +1292,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (isa<Constant>(V)) return getCleanOrigin(); assert((isa<Instruction>(V) || isa<Argument>(V)) && "Unexpected value type in getOrigin()"); + if (Instruction *I = dyn_cast<Instruction>(V)) { + if (I->getMetadata("nosanitize")) + return getCleanOrigin(); + } Value *Origin = OriginMap[V]; assert(Origin && "Missing origin"); return Origin; @@ -1270,6 +1376,11 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } // ------------------- Visitors. 
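To make the refactoring above easier to follow: getShadowOriginPtrUserspace now derives both the shadow and the origin pointer from the same application-address offset, which is what lets the old getOriginPtr helper be deleted. A rough standalone model of that arithmetic follows; it is not the pass itself, the base constants are invented placeholders for the MS.MapParams values (which differ per platform), and the Offset parameter stands in for whatever getShadowPtrOffset(Addr) computes.

// Rough model of the userspace shadow/origin mapping, NOT MSan itself.
// kShadowBase/kOriginBase are invented placeholders for MS.MapParams values.
#include <cstdint>
#include <cstdio>
#include <utility>

static constexpr uint64_t kShadowBase = 0x100000000000ULL; // placeholder
static constexpr uint64_t kOriginBase = 0x200000000000ULL; // placeholder
static constexpr uint64_t kMinOriginAlignment = 4;         // origins are 32-bit slots

// Offset stands in for getShadowPtrOffset(Addr).
static std::pair<uint64_t, uint64_t> shadowAndOriginAddr(uint64_t Offset,
                                                         unsigned Alignment) {
  uint64_t Shadow = Offset + kShadowBase;
  uint64_t Origin = Offset + kOriginBase;
  if (Alignment < kMinOriginAlignment)
    Origin &= ~(kMinOriginAlignment - 1); // the "& ~3ULL" rounding above
  return {Shadow, Origin};
}

int main() {
  auto [S, O] = shadowAndOriginAddr(0x1237, /*Alignment=*/1);
  std::printf("shadow=0x%llx origin=0x%llx\n",
              (unsigned long long)S, (unsigned long long)O);
  return 0;
}

Rounding the origin address down to a 4-byte boundary is what lets the 32-bit origin id be stored and loaded with ordinary aligned accesses even when the application access itself is only byte-aligned.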
+ using InstVisitor<MemorySanitizerVisitor>::visit; + void visit(Instruction &I) { + if (!I.getMetadata("nosanitize")) + InstVisitor<MemorySanitizerVisitor>::visit(I); + } /// \brief Instrument LoadInst /// @@ -1277,13 +1388,16 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// Optionally, checks that the load address is fully defined. void visitLoadInst(LoadInst &I) { assert(I.getType()->isSized() && "Load type must have size"); + assert(!I.getMetadata("nosanitize")); IRBuilder<> IRB(I.getNextNode()); Type *ShadowTy = getShadowTy(&I); Value *Addr = I.getPointerOperand(); - if (PropagateShadow && !I.getMetadata("nosanitize")) { - Value *ShadowPtr = getShadowPtr(Addr, ShadowTy, IRB); - setShadow(&I, - IRB.CreateAlignedLoad(ShadowPtr, I.getAlignment(), "_msld")); + Value *ShadowPtr, *OriginPtr; + unsigned Alignment = I.getAlignment(); + if (PropagateShadow) { + std::tie(ShadowPtr, OriginPtr) = + getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment); + setShadow(&I, IRB.CreateAlignedLoad(ShadowPtr, Alignment, "_msld")); } else { setShadow(&I, getCleanShadow(&I)); } @@ -1296,10 +1410,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (MS.TrackOrigins) { if (PropagateShadow) { - unsigned Alignment = I.getAlignment(); unsigned OriginAlignment = std::max(kMinOriginAlignment, Alignment); - setOrigin(&I, IRB.CreateAlignedLoad(getOriginPtr(Addr, IRB, Alignment), - OriginAlignment)); + setOrigin(&I, IRB.CreateAlignedLoad(OriginPtr, OriginAlignment)); } else { setOrigin(&I, getCleanOrigin()); } @@ -1319,7 +1431,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> IRB(&I); Value *Addr = I.getOperand(0); - Value *ShadowPtr = getShadowPtr(Addr, I.getType(), IRB); + Value *ShadowPtr = + getShadowOriginPtr(Addr, IRB, I.getType(), /*Alignment*/ 1).first; if (ClCheckAccessAddress) insertShadowCheck(Addr, &I); @@ -1489,14 +1602,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { /// arguments are initialized. template <bool CombineShadow> class Combiner { - Value *Shadow; - Value *Origin; + Value *Shadow = nullptr; + Value *Origin = nullptr; IRBuilder<> &IRB; MemorySanitizerVisitor *MSV; public: - Combiner(MemorySanitizerVisitor *MSV, IRBuilder<> &IRB) : - Shadow(nullptr), Origin(nullptr), IRB(IRB), MSV(MSV) {} + Combiner(MemorySanitizerVisitor *MSV, IRBuilder<> &IRB) + : IRB(IRB), MSV(MSV) {} /// \brief Add a pair of shadow and origin values to the mix. Combiner &Add(Value *OpShadow, Value *OpOrigin) { @@ -1550,8 +1663,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { } }; - typedef Combiner<true> ShadowAndOriginCombiner; - typedef Combiner<false> OriginCombiner; + using ShadowAndOriginCombiner = Combiner<true>; + using OriginCombiner = Combiner<false>; /// \brief Propagate origin for arbitrary operation. void setOriginForNaryOp(Instruction &I) { @@ -1940,18 +2053,19 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> IRB(&I); Value* Addr = I.getArgOperand(0); Value *Shadow = getShadow(&I, 1); - Value *ShadowPtr = getShadowPtr(Addr, Shadow->getType(), IRB); + Value *ShadowPtr, *OriginPtr; // We don't know the pointer alignment (could be unaligned SSE store!). // Have to assume to worst case. 
+ std::tie(ShadowPtr, OriginPtr) = + getShadowOriginPtr(Addr, IRB, Shadow->getType(), /*Alignment*/ 1); IRB.CreateAlignedStore(Shadow, ShadowPtr, 1); if (ClCheckAccessAddress) insertShadowCheck(Addr, &I); // FIXME: factor out common code from materializeStores - if (MS.TrackOrigins) - IRB.CreateStore(getOrigin(&I, 1), getOriginPtr(Addr, IRB, 1)); + if (MS.TrackOrigins) IRB.CreateStore(getOrigin(&I, 1), OriginPtr); return true; } @@ -1964,11 +2078,14 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *Addr = I.getArgOperand(0); Type *ShadowTy = getShadowTy(&I); + Value *ShadowPtr, *OriginPtr; if (PropagateShadow) { - Value *ShadowPtr = getShadowPtr(Addr, ShadowTy, IRB); // We don't know the pointer alignment (could be unaligned SSE load!). // Have to assume to worst case. - setShadow(&I, IRB.CreateAlignedLoad(ShadowPtr, 1, "_msld")); + unsigned Alignment = 1; + std::tie(ShadowPtr, OriginPtr) = + getShadowOriginPtr(Addr, IRB, ShadowTy, Alignment); + setShadow(&I, IRB.CreateAlignedLoad(ShadowPtr, Alignment, "_msld")); } else { setShadow(&I, getCleanShadow(&I)); } @@ -1978,7 +2095,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (MS.TrackOrigins) { if (PropagateShadow) - setOrigin(&I, IRB.CreateLoad(getOriginPtr(Addr, IRB, 1))); + setOrigin(&I, IRB.CreateLoad(OriginPtr)); else setOrigin(&I, getCleanOrigin()); } @@ -2204,28 +2321,28 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { // intrinsic. Intrinsic::ID getSignedPackIntrinsic(Intrinsic::ID id) { switch (id) { - case llvm::Intrinsic::x86_sse2_packsswb_128: - case llvm::Intrinsic::x86_sse2_packuswb_128: - return llvm::Intrinsic::x86_sse2_packsswb_128; + case Intrinsic::x86_sse2_packsswb_128: + case Intrinsic::x86_sse2_packuswb_128: + return Intrinsic::x86_sse2_packsswb_128; - case llvm::Intrinsic::x86_sse2_packssdw_128: - case llvm::Intrinsic::x86_sse41_packusdw: - return llvm::Intrinsic::x86_sse2_packssdw_128; + case Intrinsic::x86_sse2_packssdw_128: + case Intrinsic::x86_sse41_packusdw: + return Intrinsic::x86_sse2_packssdw_128; - case llvm::Intrinsic::x86_avx2_packsswb: - case llvm::Intrinsic::x86_avx2_packuswb: - return llvm::Intrinsic::x86_avx2_packsswb; + case Intrinsic::x86_avx2_packsswb: + case Intrinsic::x86_avx2_packuswb: + return Intrinsic::x86_avx2_packsswb; - case llvm::Intrinsic::x86_avx2_packssdw: - case llvm::Intrinsic::x86_avx2_packusdw: - return llvm::Intrinsic::x86_avx2_packssdw; + case Intrinsic::x86_avx2_packssdw: + case Intrinsic::x86_avx2_packusdw: + return Intrinsic::x86_avx2_packssdw; - case llvm::Intrinsic::x86_mmx_packsswb: - case llvm::Intrinsic::x86_mmx_packuswb: - return llvm::Intrinsic::x86_mmx_packsswb; + case Intrinsic::x86_mmx_packsswb: + case Intrinsic::x86_mmx_packuswb: + return Intrinsic::x86_mmx_packsswb; - case llvm::Intrinsic::x86_mmx_packssdw: - return llvm::Intrinsic::x86_mmx_packssdw; + case Intrinsic::x86_mmx_packssdw: + return Intrinsic::x86_mmx_packssdw; default: llvm_unreachable("unexpected intrinsic id"); } @@ -2255,9 +2372,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { S2 = IRB.CreateBitCast(S2, T); } Value *S1_ext = IRB.CreateSExt( - IRB.CreateICmpNE(S1, llvm::Constant::getNullValue(T)), T); + IRB.CreateICmpNE(S1, Constant::getNullValue(T)), T); Value *S2_ext = IRB.CreateSExt( - IRB.CreateICmpNE(S2, llvm::Constant::getNullValue(T)), T); + IRB.CreateICmpNE(S2, Constant::getNullValue(T)), T); if (isX86_MMX) { Type *X86_MMXTy = Type::getX86_MMXTy(*MS.C); S1_ext 
= IRB.CreateBitCast(S1_ext, X86_MMXTy); @@ -2336,7 +2453,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRBuilder<> IRB(&I); Value* Addr = I.getArgOperand(0); Type *Ty = IRB.getInt32Ty(); - Value *ShadowPtr = getShadowPtr(Addr, Ty, IRB); + Value *ShadowPtr = getShadowOriginPtr(Addr, IRB, Ty, /*Alignment*/ 1).first; IRB.CreateStore(getCleanShadow(Ty), IRB.CreatePointerCast(ShadowPtr, Ty->getPointerTo())); @@ -2352,227 +2469,228 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { Value *Addr = I.getArgOperand(0); Type *Ty = IRB.getInt32Ty(); unsigned Alignment = 1; + Value *ShadowPtr, *OriginPtr; + std::tie(ShadowPtr, OriginPtr) = + getShadowOriginPtr(Addr, IRB, Ty, Alignment); if (ClCheckAccessAddress) insertShadowCheck(Addr, &I); - Value *Shadow = IRB.CreateAlignedLoad(getShadowPtr(Addr, Ty, IRB), - Alignment, "_ldmxcsr"); - Value *Origin = MS.TrackOrigins - ? IRB.CreateLoad(getOriginPtr(Addr, IRB, Alignment)) - : getCleanOrigin(); + Value *Shadow = IRB.CreateAlignedLoad(ShadowPtr, Alignment, "_ldmxcsr"); + Value *Origin = + MS.TrackOrigins ? IRB.CreateLoad(OriginPtr) : getCleanOrigin(); insertShadowCheck(Shadow, Origin, &I); } void visitIntrinsicInst(IntrinsicInst &I) { switch (I.getIntrinsicID()) { - case llvm::Intrinsic::bswap: + case Intrinsic::bswap: handleBswap(I); break; - case llvm::Intrinsic::x86_sse_stmxcsr: + case Intrinsic::x86_sse_stmxcsr: handleStmxcsr(I); break; - case llvm::Intrinsic::x86_sse_ldmxcsr: + case Intrinsic::x86_sse_ldmxcsr: handleLdmxcsr(I); break; - case llvm::Intrinsic::x86_avx512_vcvtsd2usi64: - case llvm::Intrinsic::x86_avx512_vcvtsd2usi32: - case llvm::Intrinsic::x86_avx512_vcvtss2usi64: - case llvm::Intrinsic::x86_avx512_vcvtss2usi32: - case llvm::Intrinsic::x86_avx512_cvttss2usi64: - case llvm::Intrinsic::x86_avx512_cvttss2usi: - case llvm::Intrinsic::x86_avx512_cvttsd2usi64: - case llvm::Intrinsic::x86_avx512_cvttsd2usi: - case llvm::Intrinsic::x86_avx512_cvtusi2sd: - case llvm::Intrinsic::x86_avx512_cvtusi2ss: - case llvm::Intrinsic::x86_avx512_cvtusi642sd: - case llvm::Intrinsic::x86_avx512_cvtusi642ss: - case llvm::Intrinsic::x86_sse2_cvtsd2si64: - case llvm::Intrinsic::x86_sse2_cvtsd2si: - case llvm::Intrinsic::x86_sse2_cvtsd2ss: - case llvm::Intrinsic::x86_sse2_cvtsi2sd: - case llvm::Intrinsic::x86_sse2_cvtsi642sd: - case llvm::Intrinsic::x86_sse2_cvtss2sd: - case llvm::Intrinsic::x86_sse2_cvttsd2si64: - case llvm::Intrinsic::x86_sse2_cvttsd2si: - case llvm::Intrinsic::x86_sse_cvtsi2ss: - case llvm::Intrinsic::x86_sse_cvtsi642ss: - case llvm::Intrinsic::x86_sse_cvtss2si64: - case llvm::Intrinsic::x86_sse_cvtss2si: - case llvm::Intrinsic::x86_sse_cvttss2si64: - case llvm::Intrinsic::x86_sse_cvttss2si: + case Intrinsic::x86_avx512_vcvtsd2usi64: + case Intrinsic::x86_avx512_vcvtsd2usi32: + case Intrinsic::x86_avx512_vcvtss2usi64: + case Intrinsic::x86_avx512_vcvtss2usi32: + case Intrinsic::x86_avx512_cvttss2usi64: + case Intrinsic::x86_avx512_cvttss2usi: + case Intrinsic::x86_avx512_cvttsd2usi64: + case Intrinsic::x86_avx512_cvttsd2usi: + case Intrinsic::x86_avx512_cvtusi2sd: + case Intrinsic::x86_avx512_cvtusi2ss: + case Intrinsic::x86_avx512_cvtusi642sd: + case Intrinsic::x86_avx512_cvtusi642ss: + case Intrinsic::x86_sse2_cvtsd2si64: + case Intrinsic::x86_sse2_cvtsd2si: + case Intrinsic::x86_sse2_cvtsd2ss: + case Intrinsic::x86_sse2_cvtsi2sd: + case Intrinsic::x86_sse2_cvtsi642sd: + case Intrinsic::x86_sse2_cvtss2sd: + case Intrinsic::x86_sse2_cvttsd2si64: + case 
Intrinsic::x86_sse2_cvttsd2si: + case Intrinsic::x86_sse_cvtsi2ss: + case Intrinsic::x86_sse_cvtsi642ss: + case Intrinsic::x86_sse_cvtss2si64: + case Intrinsic::x86_sse_cvtss2si: + case Intrinsic::x86_sse_cvttss2si64: + case Intrinsic::x86_sse_cvttss2si: handleVectorConvertIntrinsic(I, 1); break; - case llvm::Intrinsic::x86_sse_cvtps2pi: - case llvm::Intrinsic::x86_sse_cvttps2pi: + case Intrinsic::x86_sse_cvtps2pi: + case Intrinsic::x86_sse_cvttps2pi: handleVectorConvertIntrinsic(I, 2); break; - case llvm::Intrinsic::x86_avx512_psll_w_512: - case llvm::Intrinsic::x86_avx512_psll_d_512: - case llvm::Intrinsic::x86_avx512_psll_q_512: - case llvm::Intrinsic::x86_avx512_pslli_w_512: - case llvm::Intrinsic::x86_avx512_pslli_d_512: - case llvm::Intrinsic::x86_avx512_pslli_q_512: - case llvm::Intrinsic::x86_avx512_psrl_w_512: - case llvm::Intrinsic::x86_avx512_psrl_d_512: - case llvm::Intrinsic::x86_avx512_psrl_q_512: - case llvm::Intrinsic::x86_avx512_psra_w_512: - case llvm::Intrinsic::x86_avx512_psra_d_512: - case llvm::Intrinsic::x86_avx512_psra_q_512: - case llvm::Intrinsic::x86_avx512_psrli_w_512: - case llvm::Intrinsic::x86_avx512_psrli_d_512: - case llvm::Intrinsic::x86_avx512_psrli_q_512: - case llvm::Intrinsic::x86_avx512_psrai_w_512: - case llvm::Intrinsic::x86_avx512_psrai_d_512: - case llvm::Intrinsic::x86_avx512_psrai_q_512: - case llvm::Intrinsic::x86_avx512_psra_q_256: - case llvm::Intrinsic::x86_avx512_psra_q_128: - case llvm::Intrinsic::x86_avx512_psrai_q_256: - case llvm::Intrinsic::x86_avx512_psrai_q_128: - case llvm::Intrinsic::x86_avx2_psll_w: - case llvm::Intrinsic::x86_avx2_psll_d: - case llvm::Intrinsic::x86_avx2_psll_q: - case llvm::Intrinsic::x86_avx2_pslli_w: - case llvm::Intrinsic::x86_avx2_pslli_d: - case llvm::Intrinsic::x86_avx2_pslli_q: - case llvm::Intrinsic::x86_avx2_psrl_w: - case llvm::Intrinsic::x86_avx2_psrl_d: - case llvm::Intrinsic::x86_avx2_psrl_q: - case llvm::Intrinsic::x86_avx2_psra_w: - case llvm::Intrinsic::x86_avx2_psra_d: - case llvm::Intrinsic::x86_avx2_psrli_w: - case llvm::Intrinsic::x86_avx2_psrli_d: - case llvm::Intrinsic::x86_avx2_psrli_q: - case llvm::Intrinsic::x86_avx2_psrai_w: - case llvm::Intrinsic::x86_avx2_psrai_d: - case llvm::Intrinsic::x86_sse2_psll_w: - case llvm::Intrinsic::x86_sse2_psll_d: - case llvm::Intrinsic::x86_sse2_psll_q: - case llvm::Intrinsic::x86_sse2_pslli_w: - case llvm::Intrinsic::x86_sse2_pslli_d: - case llvm::Intrinsic::x86_sse2_pslli_q: - case llvm::Intrinsic::x86_sse2_psrl_w: - case llvm::Intrinsic::x86_sse2_psrl_d: - case llvm::Intrinsic::x86_sse2_psrl_q: - case llvm::Intrinsic::x86_sse2_psra_w: - case llvm::Intrinsic::x86_sse2_psra_d: - case llvm::Intrinsic::x86_sse2_psrli_w: - case llvm::Intrinsic::x86_sse2_psrli_d: - case llvm::Intrinsic::x86_sse2_psrli_q: - case llvm::Intrinsic::x86_sse2_psrai_w: - case llvm::Intrinsic::x86_sse2_psrai_d: - case llvm::Intrinsic::x86_mmx_psll_w: - case llvm::Intrinsic::x86_mmx_psll_d: - case llvm::Intrinsic::x86_mmx_psll_q: - case llvm::Intrinsic::x86_mmx_pslli_w: - case llvm::Intrinsic::x86_mmx_pslli_d: - case llvm::Intrinsic::x86_mmx_pslli_q: - case llvm::Intrinsic::x86_mmx_psrl_w: - case llvm::Intrinsic::x86_mmx_psrl_d: - case llvm::Intrinsic::x86_mmx_psrl_q: - case llvm::Intrinsic::x86_mmx_psra_w: - case llvm::Intrinsic::x86_mmx_psra_d: - case llvm::Intrinsic::x86_mmx_psrli_w: - case llvm::Intrinsic::x86_mmx_psrli_d: - case llvm::Intrinsic::x86_mmx_psrli_q: - case llvm::Intrinsic::x86_mmx_psrai_w: - case llvm::Intrinsic::x86_mmx_psrai_d: + case 
Intrinsic::x86_avx512_psll_w_512: + case Intrinsic::x86_avx512_psll_d_512: + case Intrinsic::x86_avx512_psll_q_512: + case Intrinsic::x86_avx512_pslli_w_512: + case Intrinsic::x86_avx512_pslli_d_512: + case Intrinsic::x86_avx512_pslli_q_512: + case Intrinsic::x86_avx512_psrl_w_512: + case Intrinsic::x86_avx512_psrl_d_512: + case Intrinsic::x86_avx512_psrl_q_512: + case Intrinsic::x86_avx512_psra_w_512: + case Intrinsic::x86_avx512_psra_d_512: + case Intrinsic::x86_avx512_psra_q_512: + case Intrinsic::x86_avx512_psrli_w_512: + case Intrinsic::x86_avx512_psrli_d_512: + case Intrinsic::x86_avx512_psrli_q_512: + case Intrinsic::x86_avx512_psrai_w_512: + case Intrinsic::x86_avx512_psrai_d_512: + case Intrinsic::x86_avx512_psrai_q_512: + case Intrinsic::x86_avx512_psra_q_256: + case Intrinsic::x86_avx512_psra_q_128: + case Intrinsic::x86_avx512_psrai_q_256: + case Intrinsic::x86_avx512_psrai_q_128: + case Intrinsic::x86_avx2_psll_w: + case Intrinsic::x86_avx2_psll_d: + case Intrinsic::x86_avx2_psll_q: + case Intrinsic::x86_avx2_pslli_w: + case Intrinsic::x86_avx2_pslli_d: + case Intrinsic::x86_avx2_pslli_q: + case Intrinsic::x86_avx2_psrl_w: + case Intrinsic::x86_avx2_psrl_d: + case Intrinsic::x86_avx2_psrl_q: + case Intrinsic::x86_avx2_psra_w: + case Intrinsic::x86_avx2_psra_d: + case Intrinsic::x86_avx2_psrli_w: + case Intrinsic::x86_avx2_psrli_d: + case Intrinsic::x86_avx2_psrli_q: + case Intrinsic::x86_avx2_psrai_w: + case Intrinsic::x86_avx2_psrai_d: + case Intrinsic::x86_sse2_psll_w: + case Intrinsic::x86_sse2_psll_d: + case Intrinsic::x86_sse2_psll_q: + case Intrinsic::x86_sse2_pslli_w: + case Intrinsic::x86_sse2_pslli_d: + case Intrinsic::x86_sse2_pslli_q: + case Intrinsic::x86_sse2_psrl_w: + case Intrinsic::x86_sse2_psrl_d: + case Intrinsic::x86_sse2_psrl_q: + case Intrinsic::x86_sse2_psra_w: + case Intrinsic::x86_sse2_psra_d: + case Intrinsic::x86_sse2_psrli_w: + case Intrinsic::x86_sse2_psrli_d: + case Intrinsic::x86_sse2_psrli_q: + case Intrinsic::x86_sse2_psrai_w: + case Intrinsic::x86_sse2_psrai_d: + case Intrinsic::x86_mmx_psll_w: + case Intrinsic::x86_mmx_psll_d: + case Intrinsic::x86_mmx_psll_q: + case Intrinsic::x86_mmx_pslli_w: + case Intrinsic::x86_mmx_pslli_d: + case Intrinsic::x86_mmx_pslli_q: + case Intrinsic::x86_mmx_psrl_w: + case Intrinsic::x86_mmx_psrl_d: + case Intrinsic::x86_mmx_psrl_q: + case Intrinsic::x86_mmx_psra_w: + case Intrinsic::x86_mmx_psra_d: + case Intrinsic::x86_mmx_psrli_w: + case Intrinsic::x86_mmx_psrli_d: + case Intrinsic::x86_mmx_psrli_q: + case Intrinsic::x86_mmx_psrai_w: + case Intrinsic::x86_mmx_psrai_d: handleVectorShiftIntrinsic(I, /* Variable */ false); break; - case llvm::Intrinsic::x86_avx2_psllv_d: - case llvm::Intrinsic::x86_avx2_psllv_d_256: - case llvm::Intrinsic::x86_avx512_psllv_d_512: - case llvm::Intrinsic::x86_avx2_psllv_q: - case llvm::Intrinsic::x86_avx2_psllv_q_256: - case llvm::Intrinsic::x86_avx512_psllv_q_512: - case llvm::Intrinsic::x86_avx2_psrlv_d: - case llvm::Intrinsic::x86_avx2_psrlv_d_256: - case llvm::Intrinsic::x86_avx512_psrlv_d_512: - case llvm::Intrinsic::x86_avx2_psrlv_q: - case llvm::Intrinsic::x86_avx2_psrlv_q_256: - case llvm::Intrinsic::x86_avx512_psrlv_q_512: - case llvm::Intrinsic::x86_avx2_psrav_d: - case llvm::Intrinsic::x86_avx2_psrav_d_256: - case llvm::Intrinsic::x86_avx512_psrav_d_512: - case llvm::Intrinsic::x86_avx512_psrav_q_128: - case llvm::Intrinsic::x86_avx512_psrav_q_256: - case llvm::Intrinsic::x86_avx512_psrav_q_512: + case Intrinsic::x86_avx2_psllv_d: + case Intrinsic::x86_avx2_psllv_d_256: 
+ case Intrinsic::x86_avx512_psllv_d_512: + case Intrinsic::x86_avx2_psllv_q: + case Intrinsic::x86_avx2_psllv_q_256: + case Intrinsic::x86_avx512_psllv_q_512: + case Intrinsic::x86_avx2_psrlv_d: + case Intrinsic::x86_avx2_psrlv_d_256: + case Intrinsic::x86_avx512_psrlv_d_512: + case Intrinsic::x86_avx2_psrlv_q: + case Intrinsic::x86_avx2_psrlv_q_256: + case Intrinsic::x86_avx512_psrlv_q_512: + case Intrinsic::x86_avx2_psrav_d: + case Intrinsic::x86_avx2_psrav_d_256: + case Intrinsic::x86_avx512_psrav_d_512: + case Intrinsic::x86_avx512_psrav_q_128: + case Intrinsic::x86_avx512_psrav_q_256: + case Intrinsic::x86_avx512_psrav_q_512: handleVectorShiftIntrinsic(I, /* Variable */ true); break; - case llvm::Intrinsic::x86_sse2_packsswb_128: - case llvm::Intrinsic::x86_sse2_packssdw_128: - case llvm::Intrinsic::x86_sse2_packuswb_128: - case llvm::Intrinsic::x86_sse41_packusdw: - case llvm::Intrinsic::x86_avx2_packsswb: - case llvm::Intrinsic::x86_avx2_packssdw: - case llvm::Intrinsic::x86_avx2_packuswb: - case llvm::Intrinsic::x86_avx2_packusdw: + case Intrinsic::x86_sse2_packsswb_128: + case Intrinsic::x86_sse2_packssdw_128: + case Intrinsic::x86_sse2_packuswb_128: + case Intrinsic::x86_sse41_packusdw: + case Intrinsic::x86_avx2_packsswb: + case Intrinsic::x86_avx2_packssdw: + case Intrinsic::x86_avx2_packuswb: + case Intrinsic::x86_avx2_packusdw: handleVectorPackIntrinsic(I); break; - case llvm::Intrinsic::x86_mmx_packsswb: - case llvm::Intrinsic::x86_mmx_packuswb: + case Intrinsic::x86_mmx_packsswb: + case Intrinsic::x86_mmx_packuswb: handleVectorPackIntrinsic(I, 16); break; - case llvm::Intrinsic::x86_mmx_packssdw: + case Intrinsic::x86_mmx_packssdw: handleVectorPackIntrinsic(I, 32); break; - case llvm::Intrinsic::x86_mmx_psad_bw: - case llvm::Intrinsic::x86_sse2_psad_bw: - case llvm::Intrinsic::x86_avx2_psad_bw: + case Intrinsic::x86_mmx_psad_bw: + case Intrinsic::x86_sse2_psad_bw: + case Intrinsic::x86_avx2_psad_bw: handleVectorSadIntrinsic(I); break; - case llvm::Intrinsic::x86_sse2_pmadd_wd: - case llvm::Intrinsic::x86_avx2_pmadd_wd: - case llvm::Intrinsic::x86_ssse3_pmadd_ub_sw_128: - case llvm::Intrinsic::x86_avx2_pmadd_ub_sw: + case Intrinsic::x86_sse2_pmadd_wd: + case Intrinsic::x86_avx2_pmadd_wd: + case Intrinsic::x86_ssse3_pmadd_ub_sw_128: + case Intrinsic::x86_avx2_pmadd_ub_sw: handleVectorPmaddIntrinsic(I); break; - case llvm::Intrinsic::x86_ssse3_pmadd_ub_sw: + case Intrinsic::x86_ssse3_pmadd_ub_sw: handleVectorPmaddIntrinsic(I, 8); break; - case llvm::Intrinsic::x86_mmx_pmadd_wd: + case Intrinsic::x86_mmx_pmadd_wd: handleVectorPmaddIntrinsic(I, 16); break; - case llvm::Intrinsic::x86_sse_cmp_ss: - case llvm::Intrinsic::x86_sse2_cmp_sd: - case llvm::Intrinsic::x86_sse_comieq_ss: - case llvm::Intrinsic::x86_sse_comilt_ss: - case llvm::Intrinsic::x86_sse_comile_ss: - case llvm::Intrinsic::x86_sse_comigt_ss: - case llvm::Intrinsic::x86_sse_comige_ss: - case llvm::Intrinsic::x86_sse_comineq_ss: - case llvm::Intrinsic::x86_sse_ucomieq_ss: - case llvm::Intrinsic::x86_sse_ucomilt_ss: - case llvm::Intrinsic::x86_sse_ucomile_ss: - case llvm::Intrinsic::x86_sse_ucomigt_ss: - case llvm::Intrinsic::x86_sse_ucomige_ss: - case llvm::Intrinsic::x86_sse_ucomineq_ss: - case llvm::Intrinsic::x86_sse2_comieq_sd: - case llvm::Intrinsic::x86_sse2_comilt_sd: - case llvm::Intrinsic::x86_sse2_comile_sd: - case llvm::Intrinsic::x86_sse2_comigt_sd: - case llvm::Intrinsic::x86_sse2_comige_sd: - case llvm::Intrinsic::x86_sse2_comineq_sd: - case llvm::Intrinsic::x86_sse2_ucomieq_sd: - case 
llvm::Intrinsic::x86_sse2_ucomilt_sd: - case llvm::Intrinsic::x86_sse2_ucomile_sd: - case llvm::Intrinsic::x86_sse2_ucomigt_sd: - case llvm::Intrinsic::x86_sse2_ucomige_sd: - case llvm::Intrinsic::x86_sse2_ucomineq_sd: + case Intrinsic::x86_sse_cmp_ss: + case Intrinsic::x86_sse2_cmp_sd: + case Intrinsic::x86_sse_comieq_ss: + case Intrinsic::x86_sse_comilt_ss: + case Intrinsic::x86_sse_comile_ss: + case Intrinsic::x86_sse_comigt_ss: + case Intrinsic::x86_sse_comige_ss: + case Intrinsic::x86_sse_comineq_ss: + case Intrinsic::x86_sse_ucomieq_ss: + case Intrinsic::x86_sse_ucomilt_ss: + case Intrinsic::x86_sse_ucomile_ss: + case Intrinsic::x86_sse_ucomigt_ss: + case Intrinsic::x86_sse_ucomige_ss: + case Intrinsic::x86_sse_ucomineq_ss: + case Intrinsic::x86_sse2_comieq_sd: + case Intrinsic::x86_sse2_comilt_sd: + case Intrinsic::x86_sse2_comile_sd: + case Intrinsic::x86_sse2_comigt_sd: + case Intrinsic::x86_sse2_comige_sd: + case Intrinsic::x86_sse2_comineq_sd: + case Intrinsic::x86_sse2_ucomieq_sd: + case Intrinsic::x86_sse2_ucomilt_sd: + case Intrinsic::x86_sse2_ucomile_sd: + case Intrinsic::x86_sse2_ucomigt_sd: + case Intrinsic::x86_sse2_ucomige_sd: + case Intrinsic::x86_sse2_ucomineq_sd: handleVectorCompareScalarIntrinsic(I); break; - case llvm::Intrinsic::x86_sse_cmp_ps: - case llvm::Intrinsic::x86_sse2_cmp_pd: + case Intrinsic::x86_sse_cmp_ps: + case Intrinsic::x86_sse2_cmp_pd: // FIXME: For x86_avx_cmp_pd_256 and x86_avx_cmp_ps_256 this function // generates reasonably looking IR that fails in the backend with "Do not // know how to split the result of this operator!". @@ -2588,6 +2706,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void visitCallSite(CallSite CS) { Instruction &I = *CS.getInstruction(); + assert(!I.getMetadata("nosanitize")); assert((CS.isCall() || CS.isInvoke()) && "Unknown type of CallSite"); if (CS.isCall()) { CallInst *Call = cast<CallInst>(&I); @@ -2646,9 +2765,10 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { if (ArgOffset + Size > kParamTLSSize) break; unsigned ParamAlignment = CS.getParamAlignment(i); unsigned Alignment = std::min(ParamAlignment, kShadowTLSAlignment); - Store = IRB.CreateMemCpy(ArgShadowBase, - getShadowPtr(A, Type::getInt8Ty(*MS.C), IRB), - Size, Alignment); + Value *AShadowPtr = + getShadowOriginPtr(A, IRB, IRB.getInt8Ty(), Alignment).first; + + Store = IRB.CreateMemCpy(ArgShadowBase, AShadowPtr, Size, Alignment); } else { Size = DL.getTypeAllocSize(A->getType()); if (ArgOffset + Size > kParamTLSSize) break; @@ -2695,6 +2815,8 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { setOrigin(&I, getCleanOrigin()); return; } + // FIXME: NextInsn is likely in a basic block that has not been visited yet. + // Anything inserted there will be instrumented by MSan later! NextInsn = NormalDest->getFirstInsertionPt(); assert(NextInsn != NormalDest->end() && "Could not find insertion point for retval shadow load"); @@ -2766,7 +2888,9 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { IRB.CreateCall(MS.MsanPoisonStackFn, {IRB.CreatePointerCast(&I, IRB.getInt8PtrTy()), Len}); } else { - Value *ShadowBase = getShadowPtr(&I, Type::getInt8PtrTy(*MS.C), IRB); + Value *ShadowBase = + getShadowOriginPtr(&I, IRB, IRB.getInt8Ty(), I.getAlignment()).first; + Value *PoisonValue = IRB.getInt8(PoisonStack ? 
ClPoisonStackPattern : 0); IRB.CreateMemSet(ShadowBase, PoisonValue, Len, I.getAlignment()); } @@ -2845,7 +2969,7 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> { void visitLandingPadInst(LandingPadInst &I) { // Do nothing. - // See http://code.google.com/p/memory-sanitizer/issues/detail?id=1 + // See https://github.com/google/sanitizers/issues/504 setShadow(&I, getCleanShadow(&I)); setOrigin(&I, getCleanOrigin()); } @@ -2938,18 +3062,16 @@ struct VarArgAMD64Helper : public VarArgHelper { Function &F; MemorySanitizer &MS; MemorySanitizerVisitor &MSV; - Value *VAArgTLSCopy; - Value *VAArgOverflowSize; + Value *VAArgTLSCopy = nullptr; + Value *VAArgOverflowSize = nullptr; SmallVector<CallInst*, 16> VAStartInstrumentationList; - VarArgAMD64Helper(Function &F, MemorySanitizer &MS, - MemorySanitizerVisitor &MSV) - : F(F), MS(MS), MSV(MSV), VAArgTLSCopy(nullptr), - VAArgOverflowSize(nullptr) {} - enum ArgKind { AK_GeneralPurpose, AK_FloatingPoint, AK_Memory }; + VarArgAMD64Helper(Function &F, MemorySanitizer &MS, + MemorySanitizerVisitor &MSV) : F(F), MS(MS), MSV(MSV) {} + ArgKind classifyArgument(Value* arg) { // A very rough approximation of X86_64 argument classification rules. Type *T = arg->getType(); @@ -2990,38 +3112,44 @@ struct VarArgAMD64Helper : public VarArgHelper { assert(A->getType()->isPointerTy()); Type *RealTy = A->getType()->getPointerElementType(); uint64_t ArgSize = DL.getTypeAllocSize(RealTy); - Value *Base = getShadowPtrForVAArgument(RealTy, IRB, OverflowOffset); + Value *ShadowBase = + getShadowPtrForVAArgument(RealTy, IRB, OverflowOffset); OverflowOffset += alignTo(ArgSize, 8); - IRB.CreateMemCpy(Base, MSV.getShadowPtr(A, IRB.getInt8Ty(), IRB), - ArgSize, kShadowTLSAlignment); + Value *ShadowPtr, *OriginPtr; + std::tie(ShadowPtr, OriginPtr) = MSV.getShadowOriginPtr( + A, IRB, IRB.getInt8Ty(), kShadowTLSAlignment); + + IRB.CreateMemCpy(ShadowBase, ShadowPtr, ArgSize, kShadowTLSAlignment); } else { ArgKind AK = classifyArgument(A); if (AK == AK_GeneralPurpose && GpOffset >= AMD64GpEndOffset) AK = AK_Memory; if (AK == AK_FloatingPoint && FpOffset >= AMD64FpEndOffset) AK = AK_Memory; - Value *Base; + Value *ShadowBase; switch (AK) { case AK_GeneralPurpose: - Base = getShadowPtrForVAArgument(A->getType(), IRB, GpOffset); + ShadowBase = getShadowPtrForVAArgument(A->getType(), IRB, GpOffset); GpOffset += 8; break; case AK_FloatingPoint: - Base = getShadowPtrForVAArgument(A->getType(), IRB, FpOffset); + ShadowBase = getShadowPtrForVAArgument(A->getType(), IRB, FpOffset); FpOffset += 16; break; case AK_Memory: if (IsFixed) continue; uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); - Base = getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset); + ShadowBase = + getShadowPtrForVAArgument(A->getType(), IRB, OverflowOffset); OverflowOffset += alignTo(ArgSize, 8); } // Take fixed arguments into account for GpOffset and FpOffset, // but don't actually store shadows for them. 
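The offset bookkeeping in this loop mirrors the System V AMD64 register save area: six 8-byte general-purpose slots followed by eight 16-byte vector slots, with everything else spilled to an overflow area. The end offsets the code compares against (AMD64GpEndOffset and AMD64FpEndOffset) are defined outside this hunk; reading them as 48 and 176 gives the toy model below, which only illustrates the layout and is not the helper itself.

// Toy model of the AMD64 vararg shadow layout used by VarArgAMD64Helper:
// six 8-byte general-purpose slots (offsets 0..48), eight 16-byte
// floating-point slots (offsets 48..176), and an overflow area after that.
// The 48/176 end offsets are an assumption based on the System V AMD64
// register save area; the named constants live outside this hunk.
#include <cstdint>
#include <cstdio>

enum ArgKind { AK_GeneralPurpose, AK_FloatingPoint, AK_Memory };

struct VarArgLayout {
  uint64_t GpOffset = 0;          // next GP slot in the TLS copy
  uint64_t FpOffset = 48;         // FP slots start after the 6 GP slots
  uint64_t OverflowOffset = 176;  // stack-passed args follow the FP slots

  uint64_t place(ArgKind AK, uint64_t Size) {
    if (AK == AK_GeneralPurpose && GpOffset < 48) {
      uint64_t Off = GpOffset; GpOffset += 8; return Off;
    }
    if (AK == AK_FloatingPoint && FpOffset < 176) {
      uint64_t Off = FpOffset; FpOffset += 16; return Off;
    }
    uint64_t Off = OverflowOffset;
    OverflowOffset += (Size + 7) & ~uint64_t(7); // alignTo(Size, 8)
    return Off;
  }
};

int main() {
  VarArgLayout L;
  std::printf("int arg at %llu\n", (unsigned long long)L.place(AK_GeneralPurpose, 8));
  std::printf("double arg at %llu\n", (unsigned long long)L.place(AK_FloatingPoint, 8));
  std::printf("big struct at %llu\n", (unsigned long long)L.place(AK_Memory, 24));
  return 0;
}

In the finalizeInstrumentation changes a bit further down, the first part of this TLS copy is memcpy'd over the shadow of the va_list's reg_save_area (loaded from offset 16 of the tag) and the remainder over its overflow_arg_area (offset 8), which matches the usual 24-byte __va_list_tag layout of two 32-bit offsets followed by those two pointers.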
if (IsFixed) continue; - IRB.CreateAlignedStore(MSV.getShadow(A), Base, kShadowTLSAlignment); + IRB.CreateAlignedStore(MSV.getShadow(A), ShadowBase, + kShadowTLSAlignment); } } Constant *OverflowSize = @@ -3038,31 +3166,32 @@ struct VarArgAMD64Helper : public VarArgHelper { "_msarg"); } - void visitVAStartInst(VAStartInst &I) override { - if (F.getCallingConv() == CallingConv::Win64) - return; + void unpoisonVAListTagForInst(IntrinsicInst &I) { IRBuilder<> IRB(&I); - VAStartInstrumentationList.push_back(&I); Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr = MSV.getShadowPtr(VAListTag, IRB.getInt8Ty(), IRB); + Value *ShadowPtr, *OriginPtr; + unsigned Alignment = 8; + std::tie(ShadowPtr, OriginPtr) = + MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); // Unpoison the whole __va_list_tag. // FIXME: magic ABI constants. IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */24, /* alignment */8, false); + /* size */ 24, Alignment, false); + // We shouldn't need to zero out the origins, as they're only checked for + // nonzero shadow. } - void visitVACopyInst(VACopyInst &I) override { + void visitVAStartInst(VAStartInst &I) override { if (F.getCallingConv() == CallingConv::Win64) return; - IRBuilder<> IRB(&I); - Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr = MSV.getShadowPtr(VAListTag, IRB.getInt8Ty(), IRB); + VAStartInstrumentationList.push_back(&I); + unpoisonVAListTagForInst(I); + } - // Unpoison the whole __va_list_tag. - // FIXME: magic ABI constants. - IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */24, /* alignment */8, false); + void visitVACopyInst(VACopyInst &I) override { + if (F.getCallingConv() == CallingConv::Win64) return; + unpoisonVAListTagForInst(I); } void finalizeInstrumentation() override { @@ -3087,28 +3216,31 @@ struct VarArgAMD64Helper : public VarArgHelper { IRBuilder<> IRB(OrigInst->getNextNode()); Value *VAListTag = OrigInst->getArgOperand(0); - Value *RegSaveAreaPtrPtr = - IRB.CreateIntToPtr( + Value *RegSaveAreaPtrPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, 16)), Type::getInt64PtrTy(*MS.C)); Value *RegSaveAreaPtr = IRB.CreateLoad(RegSaveAreaPtrPtr); - Value *RegSaveAreaShadowPtr = - MSV.getShadowPtr(RegSaveAreaPtr, IRB.getInt8Ty(), IRB); - IRB.CreateMemCpy(RegSaveAreaShadowPtr, VAArgTLSCopy, - AMD64FpEndOffset, 16); - - Value *OverflowArgAreaPtrPtr = - IRB.CreateIntToPtr( + Value *RegSaveAreaShadowPtr, *RegSaveAreaOriginPtr; + unsigned Alignment = 16; + std::tie(RegSaveAreaShadowPtr, RegSaveAreaOriginPtr) = + MSV.getShadowOriginPtr(RegSaveAreaPtr, IRB, IRB.getInt8Ty(), + Alignment); + IRB.CreateMemCpy(RegSaveAreaShadowPtr, VAArgTLSCopy, AMD64FpEndOffset, + Alignment); + Value *OverflowArgAreaPtrPtr = IRB.CreateIntToPtr( IRB.CreateAdd(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), ConstantInt::get(MS.IntptrTy, 8)), Type::getInt64PtrTy(*MS.C)); Value *OverflowArgAreaPtr = IRB.CreateLoad(OverflowArgAreaPtrPtr); - Value *OverflowArgAreaShadowPtr = - MSV.getShadowPtr(OverflowArgAreaPtr, IRB.getInt8Ty(), IRB); + Value *OverflowArgAreaShadowPtr, *OverflowArgAreaOriginPtr; + std::tie(OverflowArgAreaShadowPtr, OverflowArgAreaOriginPtr) = + MSV.getShadowOriginPtr(OverflowArgAreaPtr, IRB, IRB.getInt8Ty(), + Alignment); Value *SrcPtr = IRB.CreateConstGEP1_32(IRB.getInt8Ty(), VAArgTLSCopy, AMD64FpEndOffset); - IRB.CreateMemCpy(OverflowArgAreaShadowPtr, SrcPtr, VAArgOverflowSize, 16); + 
IRB.CreateMemCpy(OverflowArgAreaShadowPtr, SrcPtr, VAArgOverflowSize, + Alignment); } } }; @@ -3118,15 +3250,13 @@ struct VarArgMIPS64Helper : public VarArgHelper { Function &F; MemorySanitizer &MS; MemorySanitizerVisitor &MSV; - Value *VAArgTLSCopy; - Value *VAArgSize; + Value *VAArgTLSCopy = nullptr; + Value *VAArgSize = nullptr; SmallVector<CallInst*, 16> VAStartInstrumentationList; VarArgMIPS64Helper(Function &F, MemorySanitizer &MS, - MemorySanitizerVisitor &MSV) - : F(F), MS(MS), MSV(MSV), VAArgTLSCopy(nullptr), - VAArgSize(nullptr) {} + MemorySanitizerVisitor &MSV) : F(F), MS(MS), MSV(MSV) {} void visitCallSite(CallSite &CS, IRBuilder<> &IRB) override { unsigned VAArgOffset = 0; @@ -3134,11 +3264,11 @@ struct VarArgMIPS64Helper : public VarArgHelper { for (CallSite::arg_iterator ArgIt = CS.arg_begin() + CS.getFunctionType()->getNumParams(), End = CS.arg_end(); ArgIt != End; ++ArgIt) { - llvm::Triple TargetTriple(F.getParent()->getTargetTriple()); + Triple TargetTriple(F.getParent()->getTargetTriple()); Value *A = *ArgIt; Value *Base; uint64_t ArgSize = DL.getTypeAllocSize(A->getType()); - if (TargetTriple.getArch() == llvm::Triple::mips64) { + if (TargetTriple.getArch() == Triple::mips64) { // Adjusting the shadow for argument with size < 8 to match the placement // of bits in big endian system if (ArgSize < 8) @@ -3169,19 +3299,24 @@ struct VarArgMIPS64Helper : public VarArgHelper { IRBuilder<> IRB(&I); VAStartInstrumentationList.push_back(&I); Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr = MSV.getShadowPtr(VAListTag, IRB.getInt8Ty(), IRB); + Value *ShadowPtr, *OriginPtr; + unsigned Alignment = 8; + std::tie(ShadowPtr, OriginPtr) = + MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */8, /* alignment */8, false); + /* size */ 8, Alignment, false); } void visitVACopyInst(VACopyInst &I) override { IRBuilder<> IRB(&I); + VAStartInstrumentationList.push_back(&I); Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr = MSV.getShadowPtr(VAListTag, IRB.getInt8Ty(), IRB); - // Unpoison the whole __va_list_tag. - // FIXME: magic ABI constants. + Value *ShadowPtr, *OriginPtr; + unsigned Alignment = 8; + std::tie(ShadowPtr, OriginPtr) = + MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */8, /* alignment */8, false); + /* size */ 8, Alignment, false); } void finalizeInstrumentation() override { @@ -3209,14 +3344,16 @@ struct VarArgMIPS64Helper : public VarArgHelper { IRB.CreateIntToPtr(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), Type::getInt64PtrTy(*MS.C)); Value *RegSaveAreaPtr = IRB.CreateLoad(RegSaveAreaPtrPtr); - Value *RegSaveAreaShadowPtr = - MSV.getShadowPtr(RegSaveAreaPtr, IRB.getInt8Ty(), IRB); - IRB.CreateMemCpy(RegSaveAreaShadowPtr, VAArgTLSCopy, CopySize, 8); + Value *RegSaveAreaShadowPtr, *RegSaveAreaOriginPtr; + unsigned Alignment = 8; + std::tie(RegSaveAreaShadowPtr, RegSaveAreaOriginPtr) = + MSV.getShadowOriginPtr(RegSaveAreaPtr, IRB, IRB.getInt8Ty(), + Alignment); + IRB.CreateMemCpy(RegSaveAreaShadowPtr, VAArgTLSCopy, CopySize, Alignment); } } }; - /// \brief AArch64-specific implementation of VarArgHelper. 
struct VarArgAArch64Helper : public VarArgHelper { static const unsigned kAArch64GrArgSize = 64; @@ -3233,18 +3370,16 @@ struct VarArgAArch64Helper : public VarArgHelper { Function &F; MemorySanitizer &MS; MemorySanitizerVisitor &MSV; - Value *VAArgTLSCopy; - Value *VAArgOverflowSize; + Value *VAArgTLSCopy = nullptr; + Value *VAArgOverflowSize = nullptr; SmallVector<CallInst*, 16> VAStartInstrumentationList; - VarArgAArch64Helper(Function &F, MemorySanitizer &MS, - MemorySanitizerVisitor &MSV) - : F(F), MS(MS), MSV(MSV), VAArgTLSCopy(nullptr), - VAArgOverflowSize(nullptr) {} - enum ArgKind { AK_GeneralPurpose, AK_FloatingPoint, AK_Memory }; + VarArgAArch64Helper(Function &F, MemorySanitizer &MS, + MemorySanitizerVisitor &MSV) : F(F), MS(MS), MSV(MSV) {} + ArgKind classifyArgument(Value* arg) { Type *T = arg->getType(); if (T->isFPOrFPVectorTy()) @@ -3324,21 +3459,24 @@ struct VarArgAArch64Helper : public VarArgHelper { IRBuilder<> IRB(&I); VAStartInstrumentationList.push_back(&I); Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr = MSV.getShadowPtr(VAListTag, IRB.getInt8Ty(), IRB); - // Unpoison the whole __va_list_tag. - // FIXME: magic ABI constants (size of va_list). + Value *ShadowPtr, *OriginPtr; + unsigned Alignment = 8; + std::tie(ShadowPtr, OriginPtr) = + MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */32, /* alignment */8, false); + /* size */ 32, Alignment, false); } void visitVACopyInst(VACopyInst &I) override { IRBuilder<> IRB(&I); + VAStartInstrumentationList.push_back(&I); Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr = MSV.getShadowPtr(VAListTag, IRB.getInt8Ty(), IRB); - // Unpoison the whole __va_list_tag. - // FIXME: magic ABI constants (size of va_list). + Value *ShadowPtr, *OriginPtr; + unsigned Alignment = 8; + std::tie(ShadowPtr, OriginPtr) = + MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */32, /* alignment */8, false); + /* size */ 32, Alignment, false); } // Retrieve a va_list field of 'void*' size. @@ -3424,7 +3562,9 @@ struct VarArgAArch64Helper : public VarArgHelper { IRB.CreateAdd(GrArgSize, GrOffSaveArea); Value *GrRegSaveAreaShadowPtr = - MSV.getShadowPtr(GrRegSaveAreaPtr, IRB.getInt8Ty(), IRB); + MSV.getShadowOriginPtr(GrRegSaveAreaPtr, IRB, IRB.getInt8Ty(), + /*Alignment*/ 8) + .first; Value *GrSrcPtr = IRB.CreateInBoundsGEP(IRB.getInt8Ty(), VAArgTLSCopy, GrRegSaveAreaShadowPtrOff); @@ -3437,7 +3577,9 @@ struct VarArgAArch64Helper : public VarArgHelper { IRB.CreateAdd(VrArgSize, VrOffSaveArea); Value *VrRegSaveAreaShadowPtr = - MSV.getShadowPtr(VrRegSaveAreaPtr, IRB.getInt8Ty(), IRB); + MSV.getShadowOriginPtr(VrRegSaveAreaPtr, IRB, IRB.getInt8Ty(), + /*Alignment*/ 8) + .first; Value *VrSrcPtr = IRB.CreateInBoundsGEP( IRB.getInt8Ty(), @@ -3450,7 +3592,9 @@ struct VarArgAArch64Helper : public VarArgHelper { // And finally for remaining arguments. 
Value *StackSaveAreaShadowPtr = - MSV.getShadowPtr(StackSaveAreaPtr, IRB.getInt8Ty(), IRB); + MSV.getShadowOriginPtr(StackSaveAreaPtr, IRB, IRB.getInt8Ty(), + /*Alignment*/ 16) + .first; Value *StackSrcPtr = IRB.CreateInBoundsGEP(IRB.getInt8Ty(), VAArgTLSCopy, @@ -3467,15 +3611,13 @@ struct VarArgPowerPC64Helper : public VarArgHelper { Function &F; MemorySanitizer &MS; MemorySanitizerVisitor &MSV; - Value *VAArgTLSCopy; - Value *VAArgSize; + Value *VAArgTLSCopy = nullptr; + Value *VAArgSize = nullptr; SmallVector<CallInst*, 16> VAStartInstrumentationList; VarArgPowerPC64Helper(Function &F, MemorySanitizer &MS, - MemorySanitizerVisitor &MSV) - : F(F), MS(MS), MSV(MSV), VAArgTLSCopy(nullptr), - VAArgSize(nullptr) {} + MemorySanitizerVisitor &MSV) : F(F), MS(MS), MSV(MSV) {} void visitCallSite(CallSite &CS, IRBuilder<> &IRB) override { // For PowerPC, we need to deal with alignment of stack arguments - @@ -3485,12 +3627,12 @@ struct VarArgPowerPC64Helper : public VarArgHelper { // compute current offset from stack pointer (which is always properly // aligned), and offset for the first vararg, then subtract them. unsigned VAArgBase; - llvm::Triple TargetTriple(F.getParent()->getTargetTriple()); + Triple TargetTriple(F.getParent()->getTargetTriple()); // Parameter save area starts at 48 bytes from frame pointer for ABIv1, // and 32 bytes for ABIv2. This is usually determined by target // endianness, but in theory could be overriden by function attribute. // For simplicity, we ignore it here (it'd only matter for QPX vectors). - if (TargetTriple.getArch() == llvm::Triple::ppc64) + if (TargetTriple.getArch() == Triple::ppc64) VAArgBase = 48; else VAArgBase = 32; @@ -3513,8 +3655,11 @@ struct VarArgPowerPC64Helper : public VarArgHelper { if (!IsFixed) { Value *Base = getShadowPtrForVAArgument(RealTy, IRB, VAArgOffset - VAArgBase); - IRB.CreateMemCpy(Base, MSV.getShadowPtr(A, IRB.getInt8Ty(), IRB), - ArgSize, kShadowTLSAlignment); + Value *AShadowPtr, *AOriginPtr; + std::tie(AShadowPtr, AOriginPtr) = MSV.getShadowOriginPtr( + A, IRB, IRB.getInt8Ty(), kShadowTLSAlignment); + + IRB.CreateMemCpy(Base, AShadowPtr, ArgSize, kShadowTLSAlignment); } VAArgOffset += alignTo(ArgSize, 8); } else { @@ -3572,19 +3717,25 @@ struct VarArgPowerPC64Helper : public VarArgHelper { IRBuilder<> IRB(&I); VAStartInstrumentationList.push_back(&I); Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr = MSV.getShadowPtr(VAListTag, IRB.getInt8Ty(), IRB); + Value *ShadowPtr, *OriginPtr; + unsigned Alignment = 8; + std::tie(ShadowPtr, OriginPtr) = + MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */8, /* alignment */8, false); + /* size */ 8, Alignment, false); } void visitVACopyInst(VACopyInst &I) override { IRBuilder<> IRB(&I); Value *VAListTag = I.getArgOperand(0); - Value *ShadowPtr = MSV.getShadowPtr(VAListTag, IRB.getInt8Ty(), IRB); + Value *ShadowPtr, *OriginPtr; + unsigned Alignment = 8; + std::tie(ShadowPtr, OriginPtr) = + MSV.getShadowOriginPtr(VAListTag, IRB, IRB.getInt8Ty(), Alignment); // Unpoison the whole __va_list_tag. // FIXME: magic ABI constants. 
IRB.CreateMemSet(ShadowPtr, Constant::getNullValue(IRB.getInt8Ty()), - /* size */8, /* alignment */8, false); + /* size */ 8, Alignment, false); } void finalizeInstrumentation() override { @@ -3612,9 +3763,12 @@ struct VarArgPowerPC64Helper : public VarArgHelper { IRB.CreateIntToPtr(IRB.CreatePtrToInt(VAListTag, MS.IntptrTy), Type::getInt64PtrTy(*MS.C)); Value *RegSaveAreaPtr = IRB.CreateLoad(RegSaveAreaPtrPtr); - Value *RegSaveAreaShadowPtr = - MSV.getShadowPtr(RegSaveAreaPtr, IRB.getInt8Ty(), IRB); - IRB.CreateMemCpy(RegSaveAreaShadowPtr, VAArgTLSCopy, CopySize, 8); + Value *RegSaveAreaShadowPtr, *RegSaveAreaOriginPtr; + unsigned Alignment = 8; + std::tie(RegSaveAreaShadowPtr, RegSaveAreaOriginPtr) = + MSV.getShadowOriginPtr(RegSaveAreaPtr, IRB, IRB.getInt8Ty(), + Alignment); + IRB.CreateMemCpy(RegSaveAreaShadowPtr, VAArgTLSCopy, CopySize, Alignment); } } }; @@ -3633,27 +3787,27 @@ struct VarArgNoOpHelper : public VarArgHelper { void finalizeInstrumentation() override {} }; -VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan, - MemorySanitizerVisitor &Visitor) { +} // end anonymous namespace + +static VarArgHelper *CreateVarArgHelper(Function &Func, MemorySanitizer &Msan, + MemorySanitizerVisitor &Visitor) { // VarArg handling is only implemented on AMD64. False positives are possible // on other platforms. - llvm::Triple TargetTriple(Func.getParent()->getTargetTriple()); - if (TargetTriple.getArch() == llvm::Triple::x86_64) + Triple TargetTriple(Func.getParent()->getTargetTriple()); + if (TargetTriple.getArch() == Triple::x86_64) return new VarArgAMD64Helper(Func, Msan, Visitor); - else if (TargetTriple.getArch() == llvm::Triple::mips64 || - TargetTriple.getArch() == llvm::Triple::mips64el) + else if (TargetTriple.getArch() == Triple::mips64 || + TargetTriple.getArch() == Triple::mips64el) return new VarArgMIPS64Helper(Func, Msan, Visitor); - else if (TargetTriple.getArch() == llvm::Triple::aarch64) + else if (TargetTriple.getArch() == Triple::aarch64) return new VarArgAArch64Helper(Func, Msan, Visitor); - else if (TargetTriple.getArch() == llvm::Triple::ppc64 || - TargetTriple.getArch() == llvm::Triple::ppc64le) + else if (TargetTriple.getArch() == Triple::ppc64 || + TargetTriple.getArch() == Triple::ppc64le) return new VarArgPowerPC64Helper(Func, Msan, Visitor); else return new VarArgNoOpHelper(Func, Msan, Visitor); } -} // anonymous namespace - bool MemorySanitizer::runOnFunction(Function &F) { if (&F == MsanCtorFunction) return false; diff --git a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp index 8e4bfc0b91bc..cb4b3a9c2545 100644 --- a/lib/Transforms/Instrumentation/PGOInstrumentation.cpp +++ b/lib/Transforms/Instrumentation/PGOInstrumentation.cpp @@ -1,4 +1,4 @@ -//===-- PGOInstrumentation.cpp - MST-based PGO Instrumentation ------------===// +//===- PGOInstrumentation.cpp - MST-based PGO Instrumentation -------------===// // // The LLVM Compiler Infrastructure // @@ -50,36 +50,69 @@ #include "llvm/Transforms/PGOInstrumentation.h" #include "CFGMST.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" #include "llvm/ADT/Triple.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/iterator.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" #include 
"llvm/Analysis/IndirectCallSiteVisitor.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/Comdat.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/ProfileSummary.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/ProfileData/InstrProf.h" #include "llvm/ProfileData/InstrProfReader.h" -#include "llvm/ProfileData/ProfileCommon.h" #include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/DOTGraphTraits.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/Error.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/GraphWriter.h" #include "llvm/Support/JamCRC.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Instrumentation.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include <algorithm> +#include <cassert> +#include <cstdint> +#include <memory> +#include <numeric> #include <string> #include <unordered_map> #include <utility> @@ -166,15 +199,18 @@ static cl::opt<bool> cl::desc("Use this option to turn on/off SELECT " "instruction instrumentation. ")); -// Command line option to turn on CFG dot dump of raw profile counts -static cl::opt<bool> - PGOViewRawCounts("pgo-view-raw-counts", cl::init(false), cl::Hidden, - cl::desc("A boolean option to show CFG dag " - "with raw profile counts from " - "profile data. See also option " - "-pgo-view-counts. To limit graph " - "display to only one function, use " - "filtering option -view-bfi-func-name.")); +// Command line option to turn on CFG dot or text dump of raw profile counts +static cl::opt<PGOViewCountsType> PGOViewRawCounts( + "pgo-view-raw-counts", cl::Hidden, + cl::desc("A boolean option to show CFG dag or text " + "with raw profile counts from " + "profile data. See also option " + "-pgo-view-counts. To limit graph " + "display to only one function, use " + "filtering option -view-bfi-func-name."), + cl::values(clEnumValN(PGOVCT_None, "none", "do not show."), + clEnumValN(PGOVCT_Graph, "graph", "show a graph."), + clEnumValN(PGOVCT_Text, "text", "show in text."))); // Command line option to enable/disable memop intrinsic call.size profiling. static cl::opt<bool> @@ -192,17 +228,15 @@ static cl::opt<bool> // Command line option to turn on CFG dot dump after profile annotation. 
// Defined in Analysis/BlockFrequencyInfo.cpp: -pgo-view-counts -extern cl::opt<bool> PGOViewCounts; +extern cl::opt<PGOViewCountsType> PGOViewCounts; // Command line option to specify the name of the function for CFG dump // Defined in Analysis/BlockFrequencyInfo.cpp: -view-bfi-func-name= extern cl::opt<std::string> ViewBlockFreqFuncName; -namespace { - // Return a string describing the branch condition that can be // used in static branch probability heuristics: -std::string getBranchCondString(Instruction *TI) { +static std::string getBranchCondString(Instruction *TI) { BranchInst *BI = dyn_cast<BranchInst>(TI); if (!BI || !BI->isConditional()) return std::string(); @@ -233,6 +267,8 @@ std::string getBranchCondString(Instruction *TI) { return result; } +namespace { + /// The select instruction visitor plays three roles specified /// by the mode. In \c VM_counting mode, it simply counts the number of /// select instructions. In \c VM_instrument mode, it inserts code to count @@ -259,6 +295,7 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> { Mode = VM_counting; visit(Func); } + // Visit the IR stream and instrument all select instructions. \p // Ind is a pointer to the counter index variable; \p TotalNC // is the total number of counters; \p FNV is the pointer to the @@ -283,8 +320,10 @@ struct SelectInstVisitor : public InstVisitor<SelectInstVisitor> { void instrumentOneSelectInst(SelectInst &SI); void annotateOneSelectInst(SelectInst &SI); + // Visit \p SI instruction and perform tasks according to visit mode. void visitSelectInst(SelectInst &SI); + // Return the number of select instructions. This needs be called after // countSelects(). unsigned getNumOfSelectInsts() const { return NSIs; } @@ -328,8 +367,10 @@ struct MemIntrinsicVisitor : public InstVisitor<MemIntrinsicVisitor> { // Visit the IR stream and annotate all mem intrinsic call instructions. void instrumentOneMemIntrinsic(MemIntrinsic &MI); + // Visit \p MI instruction and perform tasks according to visit mode. 
void visitMemIntrinsic(MemIntrinsic &SI); + unsigned getNumOfMemIntrinsics() const { return NMemIs; } }; @@ -371,6 +412,7 @@ private: std::string ProfileFileName; bool runOnModule(Module &M) override; + void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<BlockFrequencyInfoWrapperPass>(); } @@ -379,6 +421,7 @@ private: } // end anonymous namespace char PGOInstrumentationGenLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(PGOInstrumentationGenLegacyPass, "pgo-instr-gen", "PGO instrumentation.", false, false) INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) @@ -391,6 +434,7 @@ ModulePass *llvm::createPGOInstrumentationGenLegacyPass() { } char PGOInstrumentationUseLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(PGOInstrumentationUseLegacyPass, "pgo-instr-use", "Read PGO instrumentation profile.", false, false) INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) @@ -403,6 +447,7 @@ ModulePass *llvm::createPGOInstrumentationUseLegacyPass(StringRef Filename) { } namespace { + /// \brief An MST based instrumentation for PGO /// /// Implements a Minimum Spanning Tree (MST) based instrumentation for PGO @@ -413,12 +458,13 @@ struct PGOEdge { const BasicBlock *SrcBB; const BasicBlock *DestBB; uint64_t Weight; - bool InMST; - bool Removed; - bool IsCritical; - PGOEdge(const BasicBlock *Src, const BasicBlock *Dest, unsigned W = 1) - : SrcBB(Src), DestBB(Dest), Weight(W), InMST(false), Removed(false), - IsCritical(false) {} + bool InMST = false; + bool Removed = false; + bool IsCritical = false; + + PGOEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W = 1) + : SrcBB(Src), DestBB(Dest), Weight(W) {} + // Return the information string of an edge. const std::string infoString() const { return (Twine(Removed ? "-" : " ") + (InMST ? " " : "*") + @@ -430,9 +476,9 @@ struct PGOEdge { struct BBInfo { BBInfo *Group; uint32_t Index; - uint32_t Rank; + uint32_t Rank = 0; - BBInfo(unsigned IX) : Group(this), Index(IX), Rank(0) {} + BBInfo(unsigned IX) : Group(this), Index(IX) {} // Return the information string of this object. const std::string infoString() const { @@ -444,19 +490,22 @@ struct BBInfo { template <class Edge, class BBInfo> class FuncPGOInstrumentation { private: Function &F; - void computeCFGHash(); - void renameComdatFunction(); + // A map that stores the Comdat group in function F. std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers; + void computeCFGHash(); + void renameComdatFunction(); + public: std::vector<std::vector<Instruction *>> ValueSites; SelectInstVisitor SIVisitor; MemIntrinsicVisitor MIVisitor; std::string FuncName; GlobalVariable *FuncNameVar; + // CFG hash value for this function. - uint64_t FunctionHash; + uint64_t FunctionHash = 0; // The Minimum Spanning Tree of function CFG. CFGMST<Edge, BBInfo> MST; @@ -483,8 +532,7 @@ public: bool CreateGlobalVar = false, BranchProbabilityInfo *BPI = nullptr, BlockFrequencyInfo *BFI = nullptr) : F(Func), ComdatMembers(ComdatMembers), ValueSites(IPVK_Last + 1), - SIVisitor(Func), MIVisitor(Func), FunctionHash(0), MST(F, BPI, BFI) { - + SIVisitor(Func), MIVisitor(Func), MST(F, BPI, BFI) { // This should be done before CFG hash computation. 
SIVisitor.countSelects(Func); MIVisitor.countMemIntrinsics(Func); @@ -495,7 +543,7 @@ public: FuncName = getPGOFuncName(F); computeCFGHash(); - if (ComdatMembers.size()) + if (!ComdatMembers.empty()) renameComdatFunction(); DEBUG(dumpInfo("after CFGMST")); @@ -523,6 +571,8 @@ public: } }; +} // end anonymous namespace + // Compute Hash value for the CFG: the lower 32 bits are CRC32 of the index // value of each BB in the CFG. The higher 32 bits record the number of edges. template <class Edge, class BBInfo> @@ -545,6 +595,12 @@ void FuncPGOInstrumentation<Edge, BBInfo>::computeCFGHash() { FunctionHash = (uint64_t)SIVisitor.getNumOfSelectInsts() << 56 | (uint64_t)ValueSites[IPVK_IndirectCallTarget].size() << 48 | (uint64_t)MST.AllEdges.size() << 32 | JC.getCRC(); + DEBUG(dbgs() << "Function Hash Computation for " << F.getName() << ":\n" + << " CRC = " << JC.getCRC() + << ", Selects = " << SIVisitor.getNumOfSelectInsts() + << ", Edges = " << MST.AllEdges.size() + << ", ICSites = " << ValueSites[IPVK_IndirectCallTarget].size() + << ", Hash = " << FunctionHash << "\n";); } // Check if we can safely rename this Comdat function. @@ -660,6 +716,9 @@ BasicBlock *FuncPGOInstrumentation<Edge, BBInfo>::getInstrBB(Edge *E) { static void instrumentOneFunc( Function &F, Module *M, BranchProbabilityInfo *BPI, BlockFrequencyInfo *BFI, std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers) { + // Split indirectbr critical edges here before computing the MST rather than + // later in getInstrBB() to avoid invalidating it. + SplitIndirectBrCriticalEdges(F, BPI, BFI); FuncPGOInstrumentation<PGOEdge, BBInfo> FuncInfo(F, ComdatMembers, true, BPI, BFI); unsigned NumCounters = FuncInfo.getNumCounters(); @@ -676,7 +735,7 @@ static void instrumentOneFunc( "Cannot get the Instrumentation point"); Builder.CreateCall( Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment), - {llvm::ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy), + {ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy), Builder.getInt64(FuncInfo.FunctionHash), Builder.getInt32(NumCounters), Builder.getInt32(I++)}); } @@ -700,7 +759,7 @@ static void instrumentOneFunc( "Cannot get the Instrumentation point"); Builder.CreateCall( Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile), - {llvm::ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy), + {ConstantExpr::getBitCast(FuncInfo.FuncNameVar, I8PtrTy), Builder.getInt64(FuncInfo.FunctionHash), Builder.CreatePtrToInt(Callee, Builder.getInt64Ty()), Builder.getInt32(IPVK_IndirectCallTarget), @@ -713,12 +772,15 @@ static void instrumentOneFunc( F, NumCounters, FuncInfo.FuncNameVar, FuncInfo.FunctionHash); } +namespace { + // This class represents a CFG edge in profile use compilation. struct PGOUseEdge : public PGOEdge { - bool CountValid; - uint64_t CountValue; - PGOUseEdge(const BasicBlock *Src, const BasicBlock *Dest, unsigned W = 1) - : PGOEdge(Src, Dest, W), CountValid(false), CountValue(0) {} + bool CountValid = false; + uint64_t CountValue = 0; + + PGOUseEdge(const BasicBlock *Src, const BasicBlock *Dest, uint64_t W = 1) + : PGOEdge(Src, Dest, W) {} // Set edge count value void setEdgeCount(uint64_t Value) { @@ -735,22 +797,21 @@ struct PGOUseEdge : public PGOEdge { } }; -typedef SmallVector<PGOUseEdge *, 2> DirectEdges; +using DirectEdges = SmallVector<PGOUseEdge *, 2>; // This class stores the auxiliary information for each BB. 
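Returning to computeCFGHash for a moment: the new DEBUG line prints the ingredients of FunctionHash, which is a simple bit-packed value, the select count in the top byte, the indirect-call-site count in the next byte, the MST edge count in the next 16 bits, and the JamCRC of the per-block index stream in the low 32 bits. A minimal sketch of that packing, with a dummy CRC stand-in, is:

// Sketch of how the pieces printed by the DEBUG line above combine into the
// 64-bit function hash.  The CRC value here is a dummy stand-in for
// JC.getCRC(); only the packing is being illustrated.
#include <cstdint>
#include <cstdio>

static uint64_t packFunctionHash(uint64_t NumSelects, uint64_t NumICSites,
                                 uint64_t NumEdges, uint32_t CRC) {
  return NumSelects << 56 | NumICSites << 48 | NumEdges << 32 | CRC;
}

int main() {
  // e.g. 2 selects, 3 indirect-call sites, 17 MST edges, CRC 0xDEADBEEF
  uint64_t H = packFunctionHash(2, 3, 17, 0xDEADBEEFu);
  std::printf("FunctionHash = 0x%016llx\n", (unsigned long long)H);
  return 0;
}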
struct UseBBInfo : public BBInfo { - uint64_t CountValue; + uint64_t CountValue = 0; bool CountValid; - int32_t UnknownCountInEdge; - int32_t UnknownCountOutEdge; + int32_t UnknownCountInEdge = 0; + int32_t UnknownCountOutEdge = 0; DirectEdges InEdges; DirectEdges OutEdges; - UseBBInfo(unsigned IX) - : BBInfo(IX), CountValue(0), CountValid(false), UnknownCountInEdge(0), - UnknownCountOutEdge(0) {} + + UseBBInfo(unsigned IX) : BBInfo(IX), CountValid(false) {} + UseBBInfo(unsigned IX, uint64_t C) - : BBInfo(IX), CountValue(C), CountValid(true), UnknownCountInEdge(0), - UnknownCountOutEdge(0) {} + : BBInfo(IX), CountValue(C), CountValid(true) {} // Set the profile count value for this BB. void setBBInfoCount(uint64_t Value) { @@ -766,6 +827,8 @@ struct UseBBInfo : public BBInfo { } }; +} // end anonymous namespace + // Sum up the count values for all the edges. static uint64_t sumEdgeCount(const ArrayRef<PGOUseEdge *> Edges) { uint64_t Total = 0; @@ -777,14 +840,17 @@ static uint64_t sumEdgeCount(const ArrayRef<PGOUseEdge *> Edges) { return Total; } +namespace { + class PGOUseFunc { public: PGOUseFunc(Function &Func, Module *Modu, std::unordered_multimap<Comdat *, GlobalValue *> &ComdatMembers, BranchProbabilityInfo *BPI = nullptr, - BlockFrequencyInfo *BFI = nullptr) - : F(Func), M(Modu), FuncInfo(Func, ComdatMembers, false, BPI, BFI), - CountPosition(0), ProfileCountSize(0), FreqAttr(FFA_Normal) {} + BlockFrequencyInfo *BFIin = nullptr) + : F(Func), M(Modu), BFI(BFIin), + FuncInfo(Func, ComdatMembers, false, BPI, BFIin), + FreqAttr(FFA_Normal) {} // Read counts for the instrumented BB from profile. bool readCounters(IndexedInstrProfReader *PGOReader); @@ -801,6 +867,9 @@ public: // Annotate the value profile call sites for one value kind. void annotateValueSites(uint32_t Kind); + // Annotate the irreducible loop header weights. + void annotateIrrLoopHeaderWeights(); + // The hotness of the function from the profile count. enum FuncFreqAttr { FFA_Normal, FFA_Cold, FFA_Hot }; @@ -809,6 +878,7 @@ public: // Return the function hash. uint64_t getFuncHash() const { return FuncInfo.FunctionHash; } + // Return the profile record for this function; InstrProfRecord &getProfileRecord() { return ProfileRecord; } @@ -824,9 +894,15 @@ public: Function &getFunc() const { return F; } + void dumpInfo(std::string Str = "") const { + FuncInfo.dumpInfo(Str); + } + private: Function &F; Module *M; + BlockFrequencyInfo *BFI; + // This member stores the shared information with class PGOGenFunc. FuncPGOInstrumentation<PGOUseEdge, UseBBInfo> FuncInfo; @@ -835,10 +911,10 @@ private: uint64_t ProgramMaxCount; // Position of counter that remains to be read. - uint32_t CountPosition; + uint32_t CountPosition = 0; // Total size of the profile count for this function. - uint32_t ProfileCountSize; + uint32_t ProfileCountSize = 0; // ProfileRecord for this function. InstrProfRecord ProfileRecord; @@ -873,11 +949,12 @@ private: } }; +} // end anonymous namespace + // Visit all the edges and assign the count value for the instrumented // edges and the BB. void PGOUseFunc::setInstrumentedCounts( const std::vector<uint64_t> &CountFromProfile) { - assert(FuncInfo.getNumCounters() == CountFromProfile.size()); // Use a worklist as we will update the vector during the iteration. 
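
The UnknownCountInEdge / UnknownCountOutEdge counters above are bookkeeping for the usual flow-conservation step: a block's count equals the sum of its incoming edge counts and also of its outgoing ones, so whenever only one edge around a block is still unknown it can be solved for. A simplified sketch of that single step, using stand-in types rather than the LLVM classes, and assuming the profile counts are self-consistent:

    #include <cstdint>
    #include <vector>

    struct ToyEdge { uint64_t Count = 0; bool Known = false; };

    // Solve the one unknown edge around a node from the node's total, if possible.
    bool solveOneUnknown(std::vector<ToyEdge *> &Edges, uint64_t NodeCount) {
      uint64_t KnownSum = 0;
      ToyEdge *Unknown = nullptr;
      for (ToyEdge *E : Edges) {
        if (E->Known)
          KnownSum += E->Count;
        else if (Unknown)
          return false;        // more than one unknown: cannot solve yet
        else
          Unknown = E;
      }
      if (!Unknown)
        return false;          // nothing left to solve here
      Unknown->Count = NodeCount - KnownSum;
      Unknown->Known = true;
      return true;
    }
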
std::vector<PGOUseEdge *> WorkList; @@ -1087,7 +1164,8 @@ void PGOUseFunc::setBranchWeights() { TerminatorInst *TI = BB.getTerminator(); if (TI->getNumSuccessors() < 2) continue; - if (!isa<BranchInst>(TI) && !isa<SwitchInst>(TI)) + if (!(isa<BranchInst>(TI) || isa<SwitchInst>(TI) || + isa<IndirectBrInst>(TI))) continue; if (getBBInfo(&BB).CountValue == 0) continue; @@ -1113,6 +1191,29 @@ void PGOUseFunc::setBranchWeights() { } } +static bool isIndirectBrTarget(BasicBlock *BB) { + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { + if (isa<IndirectBrInst>((*PI)->getTerminator())) + return true; + } + return false; +} + +void PGOUseFunc::annotateIrrLoopHeaderWeights() { + DEBUG(dbgs() << "\nAnnotating irreducible loop header weights.\n"); + // Find irr loop headers + for (auto &BB : F) { + // As a heuristic also annotate indrectbr targets as they have a high chance + // to become an irreducible loop header after the indirectbr tail + // duplication. + if (BFI->isIrrLoopHeader(&BB) || isIndirectBrTarget(&BB)) { + TerminatorInst *TI = BB.getTerminator(); + const UseBBInfo &BBCountInfo = getBBInfo(&BB); + setIrrLoopHeaderMetadata(M, TI, BBCountInfo.CountValue); + } + } +} + void SelectInstVisitor::instrumentOneSelectInst(SelectInst &SI) { Module *M = F.getParent(); IRBuilder<> Builder(&SI); @@ -1121,7 +1222,7 @@ void SelectInstVisitor::instrumentOneSelectInst(SelectInst &SI) { auto *Step = Builder.CreateZExt(SI.getCondition(), Int64Ty); Builder.CreateCall( Intrinsic::getDeclaration(M, Intrinsic::instrprof_increment_step), - {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), + {ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), Builder.getInt64(FuncHash), Builder.getInt32(TotalNumCtrs), Builder.getInt32(*CurCtrIdx), Step}); ++(*CurCtrIdx); @@ -1176,7 +1277,7 @@ void MemIntrinsicVisitor::instrumentOneMemIntrinsic(MemIntrinsic &MI) { assert(!dyn_cast<ConstantInt>(Length)); Builder.CreateCall( Intrinsic::getDeclaration(M, Intrinsic::instrprof_value_profile), - {llvm::ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), + {ConstantExpr::getBitCast(FuncNameVar, I8PtrTy), Builder.getInt64(FuncHash), Builder.CreateZExtOrTrunc(Length, Int64Ty), Builder.getInt32(IPVK_MemOPSize), Builder.getInt32(CurCtrId)}); ++CurCtrId; @@ -1242,7 +1343,6 @@ void PGOUseFunc::annotateValueSites(uint32_t Kind) { ValueSiteIndex++; } } -} // end anonymous namespace // Create a COMDAT variable INSTR_PROF_RAW_VERSION_VAR to make the runtime // aware this is an ir_level profile so it can set the version flag. @@ -1312,7 +1412,6 @@ bool PGOInstrumentationGenLegacyPass::runOnModule(Module &M) { PreservedAnalyses PGOInstrumentationGen::run(Module &M, ModuleAnalysisManager &AM) { - auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); auto LookupBPI = [&FAM](Function &F) { return &FAM.getResult<BranchProbabilityAnalysis>(F); @@ -1367,33 +1466,48 @@ static bool annotateAllFunctions( continue; auto *BPI = LookupBPI(F); auto *BFI = LookupBFI(F); + // Split indirectbr critical edges here before computing the MST rather than + // later in getInstrBB() to avoid invalidating it. 
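
The new isIndirectBrTarget() above walks predecessor iterators by hand; an equivalent formulation with the range-based predecessors() helper from llvm/IR/CFG.h would read as follows (a sketch, not part of the patch):

    #include "llvm/IR/CFG.h"          // predecessors()
    #include "llvm/IR/Instructions.h" // IndirectBrInst
    #include "llvm/Support/Casting.h" // isa<>
    using namespace llvm;

    static bool isIndirectBrTargetAlt(BasicBlock *BB) {
      for (BasicBlock *Pred : predecessors(BB))
        if (isa<IndirectBrInst>(Pred->getTerminator()))
          return true;
      return false;
    }
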
+ SplitIndirectBrCriticalEdges(F, BPI, BFI); PGOUseFunc Func(F, &M, ComdatMembers, BPI, BFI); if (!Func.readCounters(PGOReader.get())) continue; Func.populateCounters(); Func.setBranchWeights(); Func.annotateValueSites(); + Func.annotateIrrLoopHeaderWeights(); PGOUseFunc::FuncFreqAttr FreqAttr = Func.getFuncFreqAttr(); if (FreqAttr == PGOUseFunc::FFA_Cold) ColdFunctions.push_back(&F); else if (FreqAttr == PGOUseFunc::FFA_Hot) HotFunctions.push_back(&F); - if (PGOViewCounts && (ViewBlockFreqFuncName.empty() || - F.getName().equals(ViewBlockFreqFuncName))) { + if (PGOViewCounts != PGOVCT_None && + (ViewBlockFreqFuncName.empty() || + F.getName().equals(ViewBlockFreqFuncName))) { LoopInfo LI{DominatorTree(F)}; std::unique_ptr<BranchProbabilityInfo> NewBPI = llvm::make_unique<BranchProbabilityInfo>(F, LI); std::unique_ptr<BlockFrequencyInfo> NewBFI = llvm::make_unique<BlockFrequencyInfo>(F, *NewBPI, LI); - - NewBFI->view(); + if (PGOViewCounts == PGOVCT_Graph) + NewBFI->view(); + else if (PGOViewCounts == PGOVCT_Text) { + dbgs() << "pgo-view-counts: " << Func.getFunc().getName() << "\n"; + NewBFI->print(dbgs()); + } } - if (PGOViewRawCounts && (ViewBlockFreqFuncName.empty() || - F.getName().equals(ViewBlockFreqFuncName))) { - if (ViewBlockFreqFuncName.empty()) - WriteGraph(&Func, Twine("PGORawCounts_") + Func.getFunc().getName()); - else - ViewGraph(&Func, Twine("PGORawCounts_") + Func.getFunc().getName()); + if (PGOViewRawCounts != PGOVCT_None && + (ViewBlockFreqFuncName.empty() || + F.getName().equals(ViewBlockFreqFuncName))) { + if (PGOViewRawCounts == PGOVCT_Graph) + if (ViewBlockFreqFuncName.empty()) + WriteGraph(&Func, Twine("PGORawCounts_") + Func.getFunc().getName()); + else + ViewGraph(&Func, Twine("PGORawCounts_") + Func.getFunc().getName()); + else if (PGOViewRawCounts == PGOVCT_Text) { + dbgs() << "pgo-view-raw-counts: " << Func.getFunc().getName() << "\n"; + Func.dumpInfo(); + } } } M.setProfileSummary(PGOReader->getSummary().getMD(M.getContext())); @@ -1402,12 +1516,12 @@ static bool annotateAllFunctions( // can affect the BranchProbabilityInfo of any callers, resulting in an // inconsistent MST between prof-gen and prof-use. 
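
PGOViewCounts and PGOViewRawCounts are now compared against PGOVCT_* values rather than used as booleans, so they are presumably declared as enum-valued command-line options. A hypothetical declaration — the variable name, option string, and descriptions here are invented for illustration; the real declaration is outside the hunks shown:

    #include "llvm/Support/CommandLine.h"
    using namespace llvm;

    enum PGOViewCountsType { PGOVCT_None, PGOVCT_Graph, PGOVCT_Text };

    static cl::opt<PGOViewCountsType> PGOViewCountsOpt(
        "pgo-view-counts", cl::Hidden, cl::init(PGOVCT_None),
        cl::desc("Show profile counts after annotation (none, graph, or text)"),
        cl::values(clEnumValN(PGOVCT_None, "none", "do not show counts"),
                   clEnumValN(PGOVCT_Graph, "graph", "display counts as a CFG graph"),
                   clEnumValN(PGOVCT_Text, "text", "print counts to the debug stream")));
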
for (auto &F : HotFunctions) { - F->addFnAttr(llvm::Attribute::InlineHint); + F->addFnAttr(Attribute::InlineHint); DEBUG(dbgs() << "Set inline attribute to function: " << F->getName() << "\n"); } for (auto &F : ColdFunctions) { - F->addFnAttr(llvm::Attribute::Cold); + F->addFnAttr(Attribute::Cold); DEBUG(dbgs() << "Set cold attribute to function: " << F->getName() << "\n"); } return true; @@ -1451,9 +1565,19 @@ bool PGOInstrumentationUseLegacyPass::runOnModule(Module &M) { return annotateAllFunctions(M, ProfileFileName, LookupBPI, LookupBFI); } -namespace llvm { -void setProfMetadata(Module *M, Instruction *TI, ArrayRef<uint64_t> EdgeCounts, - uint64_t MaxCount) { +static std::string getSimpleNodeName(const BasicBlock *Node) { + if (!Node->getName().empty()) + return Node->getName(); + + std::string SimpleNodeName; + raw_string_ostream OS(SimpleNodeName); + Node->printAsOperand(OS, false); + return OS.str(); +} + +void llvm::setProfMetadata(Module *M, Instruction *TI, + ArrayRef<uint64_t> EdgeCounts, + uint64_t MaxCount) { MDBuilder MDB(M->getContext()); assert(MaxCount > 0 && "Bad max count"); uint64_t Scale = calculateCountScale(MaxCount); @@ -1464,7 +1588,7 @@ void setProfMetadata(Module *M, Instruction *TI, ArrayRef<uint64_t> EdgeCounts, DEBUG(dbgs() << "Weight is: "; for (const auto &W : Weights) { dbgs() << W << " "; } dbgs() << "\n";); - TI->setMetadata(llvm::LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); + TI->setMetadata(LLVMContext::MD_prof, MDB.createBranchWeights(Weights)); if (EmitBranchProbability) { std::string BrCondStr = getBranchCondString(TI); if (BrCondStr.empty()) @@ -1483,43 +1607,46 @@ void setProfMetadata(Module *M, Instruction *TI, ArrayRef<uint64_t> EdgeCounts, OS << " (total count : " << TotalCount << ")"; OS.flush(); Function *F = TI->getParent()->getParent(); - emitOptimizationRemarkAnalysis( - F->getContext(), "pgo-use-annot", *F, TI->getDebugLoc(), - Twine(BrCondStr) + - " is true with probability : " + Twine(BranchProbStr)); + OptimizationRemarkEmitter ORE(F); + ORE.emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "pgo-instrumentation", TI) + << BrCondStr << " is true with probability : " << BranchProbStr; + }); } } +namespace llvm { + +void setIrrLoopHeaderMetadata(Module *M, Instruction *TI, uint64_t Count) { + MDBuilder MDB(M->getContext()); + TI->setMetadata(llvm::LLVMContext::MD_irr_loop, + MDB.createIrrLoopHeaderWeight(Count)); +} + template <> struct GraphTraits<PGOUseFunc *> { - typedef const BasicBlock *NodeRef; - typedef succ_const_iterator ChildIteratorType; - typedef pointer_iterator<Function::const_iterator> nodes_iterator; + using NodeRef = const BasicBlock *; + using ChildIteratorType = succ_const_iterator; + using nodes_iterator = pointer_iterator<Function::const_iterator>; static NodeRef getEntryNode(const PGOUseFunc *G) { return &G->getFunc().front(); } + static ChildIteratorType child_begin(const NodeRef N) { return succ_begin(N); } + static ChildIteratorType child_end(const NodeRef N) { return succ_end(N); } + static nodes_iterator nodes_begin(const PGOUseFunc *G) { return nodes_iterator(G->getFunc().begin()); } + static nodes_iterator nodes_end(const PGOUseFunc *G) { return nodes_iterator(G->getFunc().end()); } }; -static std::string getSimpleNodeName(const BasicBlock *Node) { - if (!Node->getName().empty()) - return Node->getName(); - - std::string SimpleNodeName; - raw_string_ostream OS(SimpleNodeName); - Node->printAsOperand(OS, false); - return OS.str(); -} - template <> struct DOTGraphTraits<PGOUseFunc *> : 
DefaultDOTGraphTraits { explicit DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {} @@ -1559,4 +1686,5 @@ template <> struct DOTGraphTraits<PGOUseFunc *> : DefaultDOTGraphTraits { return Result; } }; -} // namespace llvm + +} // end namespace llvm diff --git a/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp b/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp index 0bc9ddfbe4d3..95eb3680403a 100644 --- a/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp +++ b/lib/Transforms/Instrumentation/PGOMemOPSizeOpt.cpp @@ -21,16 +21,16 @@ #include "llvm/ADT/Twine.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstVisitor.h" #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/InstVisitor.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" @@ -110,6 +110,7 @@ private: bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<BlockFrequencyInfoWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } }; @@ -131,8 +132,9 @@ FunctionPass *llvm::createPGOMemOPSizeOptLegacyPass() { namespace { class MemOPSizeOpt : public InstVisitor<MemOPSizeOpt> { public: - MemOPSizeOpt(Function &Func, BlockFrequencyInfo &BFI) - : Func(Func), BFI(BFI), Changed(false) { + MemOPSizeOpt(Function &Func, BlockFrequencyInfo &BFI, + OptimizationRemarkEmitter &ORE) + : Func(Func), BFI(BFI), ORE(ORE), Changed(false) { ValueDataArray = llvm::make_unique<InstrProfValueData[]>(MemOPMaxVersion + 2); // Get the MemOPSize range information from option MemOPSizeRange, @@ -166,6 +168,7 @@ public: private: Function &Func; BlockFrequencyInfo &BFI; + OptimizationRemarkEmitter &ORE; bool Changed; std::vector<MemIntrinsic *> WorkList; // Start of the previse range. @@ -358,12 +361,15 @@ bool MemOPSizeOpt::perform(MemIntrinsic *MI) { DEBUG(dbgs() << "\n\n== Basic Block After==\n"); for (uint64_t SizeId : SizeIds) { - ConstantInt *CaseSizeId = ConstantInt::get(Type::getInt64Ty(Ctx), SizeId); BasicBlock *CaseBB = BasicBlock::Create( Ctx, Twine("MemOP.Case.") + Twine(SizeId), &Func, DefaultBB); Instruction *NewInst = MI->clone(); // Fix the argument. 
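
For the surrounding PGOMemOPSizeOpt::perform() hunks, the source-level effect of the rewrite is easier to see than the IR plumbing: hot sizes from the value profile become switch cases with constant lengths, which the backend can then lower inline. A sketch of the end result; the sizes 8 and 32 are invented for illustration:

    #include <cstddef>
    #include <cstring>

    void copyVersioned(void *Dst, const void *Src, std::size_t N) {
      switch (N) {
      case 8:
        std::memcpy(Dst, Src, 8);   // hot size taken from the value profile
        break;
      case 32:
        std::memcpy(Dst, Src, 32);  // another profiled hot size
        break;
      default:
        std::memcpy(Dst, Src, N);   // original variable-length call
        break;
      }
    }
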
- dyn_cast<MemIntrinsic>(NewInst)->setLength(CaseSizeId); + MemIntrinsic * MemI = dyn_cast<MemIntrinsic>(NewInst); + IntegerType *SizeType = dyn_cast<IntegerType>(MemI->getLength()->getType()); + assert(SizeType && "Expected integer type size argument."); + ConstantInt *CaseSizeId = ConstantInt::get(SizeType, SizeId); + MemI->setLength(CaseSizeId); CaseBB->getInstList().push_back(NewInst); IRBuilder<> IRBCase(CaseBB); IRBCase.CreateBr(MergeBB); @@ -376,23 +382,27 @@ bool MemOPSizeOpt::perform(MemIntrinsic *MI) { DEBUG(dbgs() << *DefaultBB << "\n"); DEBUG(dbgs() << *MergeBB << "\n"); - emitOptimizationRemark(Func.getContext(), "memop-opt", Func, - MI->getDebugLoc(), - Twine("optimize ") + getMIName(MI) + " with count " + - Twine(SumForOpt) + " out of " + Twine(TotalCount) + - " for " + Twine(Version) + " versions"); + ORE.emit([&]() { + using namespace ore; + return OptimizationRemark(DEBUG_TYPE, "memopt-opt", MI) + << "optimized " << NV("Intrinsic", StringRef(getMIName(MI))) + << " with count " << NV("Count", SumForOpt) << " out of " + << NV("Total", TotalCount) << " for " << NV("Versions", Version) + << " versions"; + }); return true; } } // namespace -static bool PGOMemOPSizeOptImpl(Function &F, BlockFrequencyInfo &BFI) { +static bool PGOMemOPSizeOptImpl(Function &F, BlockFrequencyInfo &BFI, + OptimizationRemarkEmitter &ORE) { if (DisableMemOPOPT) return false; if (F.hasFnAttribute(Attribute::OptimizeForSize)) return false; - MemOPSizeOpt MemOPSizeOpt(F, BFI); + MemOPSizeOpt MemOPSizeOpt(F, BFI, ORE); MemOPSizeOpt.perform(); return MemOPSizeOpt.isChanged(); } @@ -400,7 +410,8 @@ static bool PGOMemOPSizeOptImpl(Function &F, BlockFrequencyInfo &BFI) { bool PGOMemOPSizeOptLegacyPass::runOnFunction(Function &F) { BlockFrequencyInfo &BFI = getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(); - return PGOMemOPSizeOptImpl(F, BFI); + auto &ORE = getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); + return PGOMemOPSizeOptImpl(F, BFI, ORE); } namespace llvm { @@ -409,7 +420,8 @@ char &PGOMemOPSizeOptID = PGOMemOPSizeOptLegacyPass::ID; PreservedAnalyses PGOMemOPSizeOpt::run(Function &F, FunctionAnalysisManager &FAM) { auto &BFI = FAM.getResult<BlockFrequencyAnalysis>(F); - bool Changed = PGOMemOPSizeOptImpl(F, BFI); + auto &ORE = FAM.getResult<OptimizationRemarkEmitterAnalysis>(F); + bool Changed = PGOMemOPSizeOptImpl(F, BFI, ORE); if (!Changed) return PreservedAnalyses::all(); auto PA = PreservedAnalyses(); diff --git a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp index 06fe07598374..d950e2e730f2 100644 --- a/lib/Transforms/Instrumentation/SanitizerCoverage.cpp +++ b/lib/Transforms/Instrumentation/SanitizerCoverage.cpp @@ -17,12 +17,16 @@ #include "llvm/Analysis/PostDominators.h" #include "llvm/IR/CFG.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DebugInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InlineAsm.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" @@ -46,6 +50,14 @@ static const char *const SanCovTraceCmp1 = "__sanitizer_cov_trace_cmp1"; static const char *const SanCovTraceCmp2 = "__sanitizer_cov_trace_cmp2"; static const char *const SanCovTraceCmp4 = "__sanitizer_cov_trace_cmp4"; static const char *const SanCovTraceCmp8 = 
"__sanitizer_cov_trace_cmp8"; +static const char *const SanCovTraceConstCmp1 = + "__sanitizer_cov_trace_const_cmp1"; +static const char *const SanCovTraceConstCmp2 = + "__sanitizer_cov_trace_const_cmp2"; +static const char *const SanCovTraceConstCmp4 = + "__sanitizer_cov_trace_const_cmp4"; +static const char *const SanCovTraceConstCmp8 = + "__sanitizer_cov_trace_const_cmp8"; static const char *const SanCovTraceDiv4 = "__sanitizer_cov_trace_div4"; static const char *const SanCovTraceDiv8 = "__sanitizer_cov_trace_div8"; static const char *const SanCovTraceGep = "__sanitizer_cov_trace_gep"; @@ -57,11 +69,15 @@ static const char *const SanCovTracePCGuardName = "__sanitizer_cov_trace_pc_guard"; static const char *const SanCovTracePCGuardInitName = "__sanitizer_cov_trace_pc_guard_init"; -static const char *const SanCov8bitCountersInitName = +static const char *const SanCov8bitCountersInitName = "__sanitizer_cov_8bit_counters_init"; +static const char *const SanCovPCsInitName = "__sanitizer_cov_pcs_init"; static const char *const SanCovGuardsSectionName = "sancov_guards"; static const char *const SanCovCountersSectionName = "sancov_cntrs"; +static const char *const SanCovPCsSectionName = "sancov_pcs"; + +static const char *const SanCovLowestStackName = "__sancov_lowest_stack"; static cl::opt<int> ClCoverageLevel( "sanitizer-coverage-level", @@ -77,9 +93,19 @@ static cl::opt<bool> ClTracePCGuard("sanitizer-coverage-trace-pc-guard", cl::desc("pc tracing with a guard"), cl::Hidden, cl::init(false)); -static cl::opt<bool> ClInline8bitCounters("sanitizer-coverage-inline-8bit-counters", - cl::desc("increments 8-bit counter for every edge"), - cl::Hidden, cl::init(false)); +// If true, we create a global variable that contains PCs of all instrumented +// BBs, put this global into a named section, and pass this section's bounds +// to __sanitizer_cov_pcs_init. +// This way the coverage instrumentation does not need to acquire the PCs +// at run-time. Works with trace-pc-guard and inline-8bit-counters. +static cl::opt<bool> ClCreatePCTable("sanitizer-coverage-pc-table", + cl::desc("create a static PC table"), + cl::Hidden, cl::init(false)); + +static cl::opt<bool> + ClInline8bitCounters("sanitizer-coverage-inline-8bit-counters", + cl::desc("increments 8-bit counter for every edge"), + cl::Hidden, cl::init(false)); static cl::opt<bool> ClCMPTracing("sanitizer-coverage-trace-compares", @@ -99,6 +125,10 @@ static cl::opt<bool> cl::desc("Reduce the number of instrumented blocks"), cl::Hidden, cl::init(true)); +static cl::opt<bool> ClStackDepth("sanitizer-coverage-stack-depth", + cl::desc("max stack depth tracing"), + cl::Hidden, cl::init(false)); + namespace { SanitizerCoverageOptions getOptions(int LegacyCoverageLevel) { @@ -135,9 +165,12 @@ SanitizerCoverageOptions OverrideFromCL(SanitizerCoverageOptions Options) { Options.TracePC |= ClTracePC; Options.TracePCGuard |= ClTracePCGuard; Options.Inline8bitCounters |= ClInline8bitCounters; - if (!Options.TracePCGuard && !Options.TracePC && !Options.Inline8bitCounters) - Options.TracePCGuard = true; // TracePCGuard is default. + Options.PCTable |= ClCreatePCTable; Options.NoPrune |= !ClPruneBlocks; + Options.StackDepth |= ClStackDepth; + if (!Options.TracePCGuard && !Options.TracePC && + !Options.Inline8bitCounters && !Options.StackDepth) + Options.TracePCGuard = true; // TracePCGuard is default. 
return Options; } @@ -168,14 +201,19 @@ private: ArrayRef<GetElementPtrInst *> GepTraceTargets); void InjectTraceForSwitch(Function &F, ArrayRef<Instruction *> SwitchTraceTargets); - bool InjectCoverage(Function &F, ArrayRef<BasicBlock *> AllBlocks); + bool InjectCoverage(Function &F, ArrayRef<BasicBlock *> AllBlocks, + bool IsLeafFunc = true); GlobalVariable *CreateFunctionLocalArrayInSection(size_t NumElements, Function &F, Type *Ty, const char *Section); - void CreateFunctionLocalArrays(size_t NumGuards, Function &F); - void InjectCoverageAtBlock(Function &F, BasicBlock &BB, size_t Idx); - void CreateInitCallForSection(Module &M, const char *InitFunctionName, - Type *Ty, const std::string &Section); + GlobalVariable *CreatePCArray(Function &F, ArrayRef<BasicBlock *> AllBlocks); + void CreateFunctionLocalArrays(Function &F, ArrayRef<BasicBlock *> AllBlocks); + void InjectCoverageAtBlock(Function &F, BasicBlock &BB, size_t Idx, + bool IsLeafFunc = true); + Function *CreateInitCallsForSections(Module &M, const char *InitFunctionName, + Type *Ty, const char *Section); + std::pair<GlobalVariable *, GlobalVariable *> + CreateSecStartEnd(Module &M, const char *Section, Type *Ty); void SetNoSanitizeMetadata(Instruction *I) { I->setMetadata(I->getModule()->getMDKindID("nosanitize"), @@ -188,12 +226,14 @@ private: Function *SanCovTracePCIndir; Function *SanCovTracePC, *SanCovTracePCGuard; Function *SanCovTraceCmpFunction[4]; + Function *SanCovTraceConstCmpFunction[4]; Function *SanCovTraceDivFunction[2]; Function *SanCovTraceGepFunction; Function *SanCovTraceSwitchFunction; + GlobalVariable *SanCovLowestStack; InlineAsm *EmptyAsm; Type *IntptrTy, *IntptrPtrTy, *Int64Ty, *Int64PtrTy, *Int32Ty, *Int32PtrTy, - *Int8Ty, *Int8PtrTy; + *Int16Ty, *Int8Ty, *Int8PtrTy; Module *CurModule; Triple TargetTriple; LLVMContext *C; @@ -201,17 +241,17 @@ private: GlobalVariable *FunctionGuardArray; // for trace-pc-guard. GlobalVariable *Function8bitCounterArray; // for inline-8bit-counters. + GlobalVariable *FunctionPCsArray; // for pc-table. 
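
The four SanCovTraceConstCmpN callbacks registered above share the shape of the existing cmp hooks, and per the InjectTraceForCmp() change later in this diff the constant operand is passed as the first argument. A stub with the matching signature, as a runtime might define it; the body is a placeholder:

    #include <stdint.h>

    extern "C" void __sanitizer_cov_trace_const_cmp8(uint64_t ConstArg,
                                                     uint64_t VarArg) {
      // A fuzzer runtime would typically record the pair so the variable side
      // can be steered toward the constant; nothing is done here.
      (void)ConstArg;
      (void)VarArg;
    }
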
+ SmallVector<GlobalValue *, 20> GlobalsToAppendToUsed; SanitizerCoverageOptions Options; }; } // namespace -void SanitizerCoverageModule::CreateInitCallForSection( - Module &M, const char *InitFunctionName, Type *Ty, - const std::string &Section) { - IRBuilder<> IRB(M.getContext()); - Function *CtorFunc; +std::pair<GlobalVariable *, GlobalVariable *> +SanitizerCoverageModule::CreateSecStartEnd(Module &M, const char *Section, + Type *Ty) { GlobalVariable *SecStart = new GlobalVariable(M, Ty, false, GlobalVariable::ExternalLinkage, nullptr, getSectionStart(Section)); @@ -221,6 +261,18 @@ void SanitizerCoverageModule::CreateInitCallForSection( nullptr, getSectionEnd(Section)); SecEnd->setVisibility(GlobalValue::HiddenVisibility); + return std::make_pair(SecStart, SecEnd); +} + + +Function *SanitizerCoverageModule::CreateInitCallsForSections( + Module &M, const char *InitFunctionName, Type *Ty, + const char *Section) { + IRBuilder<> IRB(M.getContext()); + auto SecStartEnd = CreateSecStartEnd(M, Section, Ty); + auto SecStart = SecStartEnd.first; + auto SecEnd = SecStartEnd.second; + Function *CtorFunc; std::tie(CtorFunc, std::ignore) = createSanitizerCtorAndInitFunctions( M, SanCovModuleCtorName, InitFunctionName, {Ty, Ty}, {IRB.CreatePointerCast(SecStart, Ty), IRB.CreatePointerCast(SecEnd, Ty)}); @@ -232,6 +284,7 @@ void SanitizerCoverageModule::CreateInitCallForSection( } else { appendToGlobalCtors(M, CtorFunc, SanCtorAndDtorPriority); } + return CtorFunc; } bool SanitizerCoverageModule::runOnModule(Module &M) { @@ -243,6 +296,7 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { TargetTriple = Triple(M.getTargetTriple()); FunctionGuardArray = nullptr; Function8bitCounterArray = nullptr; + FunctionPCsArray = nullptr; IntptrTy = Type::getIntNTy(*C, DL->getPointerSizeInBits()); IntptrPtrTy = PointerType::getUnqual(IntptrTy); Type *VoidTy = Type::getVoidTy(*C); @@ -252,6 +306,7 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { Int8PtrTy = PointerType::getUnqual(IRB.getInt8Ty()); Int64Ty = IRB.getInt64Ty(); Int32Ty = IRB.getInt32Ty(); + Int16Ty = IRB.getInt16Ty(); Int8Ty = IRB.getInt8Ty(); SanCovTracePCIndir = checkSanitizerInterfaceFunction( @@ -269,6 +324,19 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { checkSanitizerInterfaceFunction(M.getOrInsertFunction( SanCovTraceCmp8, VoidTy, Int64Ty, Int64Ty)); + SanCovTraceConstCmpFunction[0] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + SanCovTraceConstCmp1, VoidTy, Int8Ty, Int8Ty)); + SanCovTraceConstCmpFunction[1] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + SanCovTraceConstCmp2, VoidTy, Int16Ty, Int16Ty)); + SanCovTraceConstCmpFunction[2] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + SanCovTraceConstCmp4, VoidTy, Int32Ty, Int32Ty)); + SanCovTraceConstCmpFunction[3] = + checkSanitizerInterfaceFunction(M.getOrInsertFunction( + SanCovTraceConstCmp8, VoidTy, Int64Ty, Int64Ty)); + SanCovTraceDivFunction[0] = checkSanitizerInterfaceFunction(M.getOrInsertFunction( SanCovTraceDiv4, VoidTy, IRB.getInt32Ty())); @@ -281,12 +349,23 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { SanCovTraceSwitchFunction = checkSanitizerInterfaceFunction(M.getOrInsertFunction( SanCovTraceSwitchName, VoidTy, Int64Ty, Int64PtrTy)); + + Constant *SanCovLowestStackConstant = + M.getOrInsertGlobal(SanCovLowestStackName, IntptrTy); + SanCovLowestStack = cast<GlobalVariable>(SanCovLowestStackConstant); + SanCovLowestStack->setThreadLocalMode( + GlobalValue::ThreadLocalMode::InitialExecTLSModel); + if 
(Options.StackDepth && !SanCovLowestStack->isDeclaration()) + SanCovLowestStack->setInitializer(Constant::getAllOnesValue(IntptrTy)); + // Make sure smaller parameters are zero-extended to i64 as required by the // x86_64 ABI. if (TargetTriple.getArch() == Triple::x86_64) { for (int i = 0; i < 3; i++) { SanCovTraceCmpFunction[i]->addParamAttr(0, Attribute::ZExt); SanCovTraceCmpFunction[i]->addParamAttr(1, Attribute::ZExt); + SanCovTraceConstCmpFunction[i]->addParamAttr(0, Attribute::ZExt); + SanCovTraceConstCmpFunction[i]->addParamAttr(1, Attribute::ZExt); } SanCovTraceDivFunction[0]->addParamAttr(0, Attribute::ZExt); } @@ -305,13 +384,27 @@ bool SanitizerCoverageModule::runOnModule(Module &M) { for (auto &F : M) runOnFunction(F); + Function *Ctor = nullptr; + if (FunctionGuardArray) - CreateInitCallForSection(M, SanCovTracePCGuardInitName, Int32PtrTy, - SanCovGuardsSectionName); + Ctor = CreateInitCallsForSections(M, SanCovTracePCGuardInitName, Int32PtrTy, + SanCovGuardsSectionName); if (Function8bitCounterArray) - CreateInitCallForSection(M, SanCov8bitCountersInitName, Int8PtrTy, - SanCovCountersSectionName); - + Ctor = CreateInitCallsForSections(M, SanCov8bitCountersInitName, Int8PtrTy, + SanCovCountersSectionName); + if (Ctor && Options.PCTable) { + auto SecStartEnd = CreateSecStartEnd(M, SanCovPCsSectionName, IntptrPtrTy); + Function *InitFunction = declareSanitizerInitFunction( + M, SanCovPCsInitName, {IntptrPtrTy, IntptrPtrTy}); + IRBuilder<> IRBCtor(Ctor->getEntryBlock().getTerminator()); + IRBCtor.CreateCall(InitFunction, + {IRB.CreatePointerCast(SecStartEnd.first, IntptrPtrTy), + IRB.CreatePointerCast(SecStartEnd.second, IntptrPtrTy)}); + } + // We don't reference these arrays directly in any of our runtime functions, + // so we need to prevent them from being dead stripped. + if (TargetTriple.isOSBinFormatMachO()) + appendToUsed(M, GlobalsToAppendToUsed); return true; } @@ -362,6 +455,10 @@ static bool shouldInstrumentBlock(const Function &F, const BasicBlock *BB, if (Options.NoPrune || &F.getEntryBlock() == BB) return true; + if (Options.CoverageType == SanitizerCoverageOptions::SCK_Function && + &F.getEntryBlock() != BB) + return false; + // Do not instrument full dominators, or full post-dominators with multiple // predecessors. return !isFullDominator(BB, DT) @@ -375,6 +472,9 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { return false; // Should not instrument sanitizer init functions. if (F.getName().startswith("__sanitizer_")) return false; // Don't instrument __sanitizer_* callbacks. + // Don't touch available_externally functions, their actual body is elewhere. + if (F.getLinkage() == GlobalValue::AvailableExternallyLinkage) + return false; // Don't instrument MSVC CRT configuration helpers. They may run before normal // initialization. 
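
The __sancov_lowest_stack handling above, together with the InjectCoverageAtBlock() change further down, amounts to the following source-level idiom at the top of every non-leaf function when -sanitizer-coverage-stack-depth is on. A sketch only: it assumes Clang/GCC builtins, a downward-growing stack, and that the runtime provides the thread-local variable.

    #include <stdint.h>

    extern "C" thread_local uintptr_t __sancov_lowest_stack; // defined by the runtime

    inline void sancovTrackStackDepth() {
      uintptr_t Frame = reinterpret_cast<uintptr_t>(__builtin_frame_address(0));
      if (Frame < __sancov_lowest_stack) // lower address == deeper call
        __sancov_lowest_stack = Frame;
    }
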
if (F.getName() == "__local_stdio_printf_options" || @@ -399,6 +499,7 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { &getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); const PostDominatorTree *PDT = &getAnalysis<PostDominatorTreeWrapperPass>(F).getPostDomTree(); + bool IsLeafFunc = true; for (auto &BB : F) { if (shouldInstrumentBlock(F, &BB, DT, PDT, Options)) @@ -423,10 +524,14 @@ bool SanitizerCoverageModule::runOnFunction(Function &F) { if (Options.TraceGep) if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(&Inst)) GepTraceTargets.push_back(GEP); - } + if (Options.StackDepth) + if (isa<InvokeInst>(Inst) || + (isa<CallInst>(Inst) && !isa<IntrinsicInst>(Inst))) + IsLeafFunc = false; + } } - InjectCoverage(F, BlocksToInstrument); + InjectCoverage(F, BlocksToInstrument, IsLeafFunc); InjectCoverageForIndirectCalls(F, IndirCalls); InjectTraceForCmp(F, CmpTraceTargets); InjectTraceForSwitch(F, SwitchTraceTargets); @@ -444,35 +549,65 @@ GlobalVariable *SanitizerCoverageModule::CreateFunctionLocalArrayInSection( if (auto Comdat = F.getComdat()) Array->setComdat(Comdat); Array->setSection(getSectionName(Section)); + Array->setAlignment(Ty->isPointerTy() ? DL->getPointerSize() + : Ty->getPrimitiveSizeInBits() / 8); return Array; } -void SanitizerCoverageModule::CreateFunctionLocalArrays(size_t NumGuards, - Function &F) { - if (Options.TracePCGuard) + +GlobalVariable * +SanitizerCoverageModule::CreatePCArray(Function &F, + ArrayRef<BasicBlock *> AllBlocks) { + size_t N = AllBlocks.size(); + assert(N); + SmallVector<Constant *, 32> PCs; + IRBuilder<> IRB(&*F.getEntryBlock().getFirstInsertionPt()); + for (size_t i = 0; i < N; i++) { + if (&F.getEntryBlock() == AllBlocks[i]) { + PCs.push_back((Constant *)IRB.CreatePointerCast(&F, IntptrPtrTy)); + PCs.push_back((Constant *)IRB.CreateIntToPtr( + ConstantInt::get(IntptrTy, 1), IntptrPtrTy)); + } else { + PCs.push_back((Constant *)IRB.CreatePointerCast( + BlockAddress::get(AllBlocks[i]), IntptrPtrTy)); + PCs.push_back((Constant *)IRB.CreateIntToPtr( + ConstantInt::get(IntptrTy, 0), IntptrPtrTy)); + } + } + auto *PCArray = CreateFunctionLocalArrayInSection(N * 2, F, IntptrPtrTy, + SanCovPCsSectionName); + PCArray->setInitializer( + ConstantArray::get(ArrayType::get(IntptrPtrTy, N * 2), PCs)); + PCArray->setConstant(true); + + return PCArray; +} + +void SanitizerCoverageModule::CreateFunctionLocalArrays( + Function &F, ArrayRef<BasicBlock *> AllBlocks) { + if (Options.TracePCGuard) { FunctionGuardArray = CreateFunctionLocalArrayInSection( - NumGuards, F, Int32Ty, SanCovGuardsSectionName); - if (Options.Inline8bitCounters) + AllBlocks.size(), F, Int32Ty, SanCovGuardsSectionName); + GlobalsToAppendToUsed.push_back(FunctionGuardArray); + } + if (Options.Inline8bitCounters) { Function8bitCounterArray = CreateFunctionLocalArrayInSection( - NumGuards, F, Int8Ty, SanCovCountersSectionName); + AllBlocks.size(), F, Int8Ty, SanCovCountersSectionName); + GlobalsToAppendToUsed.push_back(Function8bitCounterArray); + } + if (Options.PCTable) { + FunctionPCsArray = CreatePCArray(F, AllBlocks); + GlobalsToAppendToUsed.push_back(FunctionPCsArray); + } } bool SanitizerCoverageModule::InjectCoverage(Function &F, - ArrayRef<BasicBlock *> AllBlocks) { + ArrayRef<BasicBlock *> AllBlocks, + bool IsLeafFunc) { if (AllBlocks.empty()) return false; - switch (Options.CoverageType) { - case SanitizerCoverageOptions::SCK_None: - return false; - case SanitizerCoverageOptions::SCK_Function: - CreateFunctionLocalArrays(1, F); - InjectCoverageAtBlock(F, 
F.getEntryBlock(), 0); - return true; - default: { - CreateFunctionLocalArrays(AllBlocks.size(), F); - for (size_t i = 0, N = AllBlocks.size(); i < N; i++) - InjectCoverageAtBlock(F, *AllBlocks[i], i); - return true; - } - } + CreateFunctionLocalArrays(F, AllBlocks); + for (size_t i = 0, N = AllBlocks.size(); i < N; i++) + InjectCoverageAtBlock(F, *AllBlocks[i], i, IsLeafFunc); + return true; } // On every indirect call we call a run-time function @@ -585,16 +720,28 @@ void SanitizerCoverageModule::InjectTraceForCmp( TypeSize == 64 ? 3 : -1; if (CallbackIdx < 0) continue; // __sanitizer_cov_trace_cmp((type_size << 32) | predicate, A0, A1); + auto CallbackFunc = SanCovTraceCmpFunction[CallbackIdx]; + bool FirstIsConst = isa<ConstantInt>(A0); + bool SecondIsConst = isa<ConstantInt>(A1); + // If both are const, then we don't need such a comparison. + if (FirstIsConst && SecondIsConst) continue; + // If only one is const, then make it the first callback argument. + if (FirstIsConst || SecondIsConst) { + CallbackFunc = SanCovTraceConstCmpFunction[CallbackIdx]; + if (SecondIsConst) + std::swap(A0, A1); + } + auto Ty = Type::getIntNTy(*C, TypeSize); - IRB.CreateCall( - SanCovTraceCmpFunction[CallbackIdx], - {IRB.CreateIntCast(A0, Ty, true), IRB.CreateIntCast(A1, Ty, true)}); + IRB.CreateCall(CallbackFunc, {IRB.CreateIntCast(A0, Ty, true), + IRB.CreateIntCast(A1, Ty, true)}); } } } void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, - size_t Idx) { + size_t Idx, + bool IsLeafFunc) { BasicBlock::iterator IP = BB.getFirstInsertionPt(); bool IsEntryBB = &BB == &F.getEntryBlock(); DebugLoc EntryLoc; @@ -633,6 +780,21 @@ void SanitizerCoverageModule::InjectCoverageAtBlock(Function &F, BasicBlock &BB, SetNoSanitizeMetadata(Load); SetNoSanitizeMetadata(Store); } + if (Options.StackDepth && IsEntryBB && !IsLeafFunc) { + // Check stack depth. If it's the deepest so far, record it. + Function *GetFrameAddr = + Intrinsic::getDeclaration(F.getParent(), Intrinsic::frameaddress); + auto FrameAddrPtr = + IRB.CreateCall(GetFrameAddr, {Constant::getNullValue(Int32Ty)}); + auto FrameAddrInt = IRB.CreatePtrToInt(FrameAddrPtr, IntptrTy); + auto LowestStack = IRB.CreateLoad(SanCovLowestStack); + auto IsStackLower = IRB.CreateICmpULT(FrameAddrInt, LowestStack); + auto ThenTerm = SplitBlockAndInsertIfThen(IsStackLower, &*IP, false); + IRBuilder<> ThenIRB(ThenTerm); + auto Store = ThenIRB.CreateStore(FrameAddrInt, SanCovLowestStack); + SetNoSanitizeMetadata(LowestStack); + SetNoSanitizeMetadata(Store); + } } std::string diff --git a/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h b/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h index cb3b5757f8d0..ba4924c9cb2d 100644 --- a/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h +++ b/lib/Transforms/ObjCARC/ARCRuntimeEntryPoints.h @@ -1,4 +1,4 @@ -//===- ARCRuntimeEntryPoints.h - ObjC ARC Optimization --*- C++ -*---------===// +//===- ARCRuntimeEntryPoints.h - ObjC ARC Optimization ----------*- C++ -*-===// // // The LLVM Compiler Infrastructure // @@ -6,6 +6,7 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// /// \file /// This file contains a class ARCRuntimeEntryPoints for use in /// creating/managing references to entry points to the arc objective c runtime. @@ -16,15 +17,25 @@ /// WARNING: This file knows about how certain Objective-C library functions are /// used. 
Naive LLVM IR transformations which would otherwise be /// behavior-preserving may break these assumptions. -/// +// //===----------------------------------------------------------------------===// #ifndef LLVM_LIB_TRANSFORMS_OBJCARC_ARCRUNTIMEENTRYPOINTS_H #define LLVM_LIB_TRANSFORMS_OBJCARC_ARCRUNTIMEENTRYPOINTS_H -#include "ObjCARC.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Support/ErrorHandling.h" +#include <cassert> namespace llvm { + +class Constant; +class LLVMContext; + namespace objcarc { enum class ARCRuntimeEntryPointKind { @@ -43,16 +54,7 @@ enum class ARCRuntimeEntryPointKind { /// lazily to avoid cluttering up the Module with unused declarations. class ARCRuntimeEntryPoints { public: - ARCRuntimeEntryPoints() : TheModule(nullptr), - AutoreleaseRV(nullptr), - Release(nullptr), - Retain(nullptr), - RetainBlock(nullptr), - Autorelease(nullptr), - StoreStrong(nullptr), - RetainRV(nullptr), - RetainAutorelease(nullptr), - RetainAutoreleaseRV(nullptr) { } + ARCRuntimeEntryPoints() = default; void init(Module *M) { TheModule = M; @@ -100,26 +102,34 @@ public: private: /// Cached reference to the module which we will insert declarations into. - Module *TheModule; + Module *TheModule = nullptr; /// Declaration for ObjC runtime function objc_autoreleaseReturnValue. - Constant *AutoreleaseRV; + Constant *AutoreleaseRV = nullptr; + /// Declaration for ObjC runtime function objc_release. - Constant *Release; + Constant *Release = nullptr; + /// Declaration for ObjC runtime function objc_retain. - Constant *Retain; + Constant *Retain = nullptr; + /// Declaration for ObjC runtime function objc_retainBlock. - Constant *RetainBlock; + Constant *RetainBlock = nullptr; + /// Declaration for ObjC runtime function objc_autorelease. - Constant *Autorelease; + Constant *Autorelease = nullptr; + /// Declaration for objc_storeStrong(). - Constant *StoreStrong; + Constant *StoreStrong = nullptr; + /// Declaration for objc_retainAutoreleasedReturnValue(). - Constant *RetainRV; + Constant *RetainRV = nullptr; + /// Declaration for objc_retainAutorelease(). - Constant *RetainAutorelease; + Constant *RetainAutorelease = nullptr; + /// Declaration for objc_retainAutoreleaseReturnValue(). 
- Constant *RetainAutoreleaseRV; + Constant *RetainAutoreleaseRV = nullptr; Constant *getVoidRetI8XEntryPoint(Constant *&Decl, StringRef Name) { if (Decl) @@ -170,10 +180,10 @@ private: return Decl = TheModule->getOrInsertFunction(Name, Fty, Attr); } +}; -}; // class ARCRuntimeEntryPoints +} // end namespace objcarc -} // namespace objcarc -} // namespace llvm +} // end namespace llvm -#endif +#endif // LLVM_LIB_TRANSFORMS_OBJCARC_ARCRUNTIMEENTRYPOINTS_H diff --git a/lib/Transforms/ObjCARC/BlotMapVector.h b/lib/Transforms/ObjCARC/BlotMapVector.h index 9c5cf6f5f5ab..5518b49c4095 100644 --- a/lib/Transforms/ObjCARC/BlotMapVector.h +++ b/lib/Transforms/ObjCARC/BlotMapVector.h @@ -1,4 +1,4 @@ -//===- BlotMapVector.h - A MapVector with the blot operation -*- C++ -*----===// +//===- BlotMapVector.h - A MapVector with the blot operation ----*- C++ -*-===// // // The LLVM Compiler Infrastructure // @@ -7,30 +7,29 @@ // //===----------------------------------------------------------------------===// +#ifndef LLVM_LIB_TRANSFORMS_OBJCARC_BLOTMAPVECTOR_H +#define LLVM_LIB_TRANSFORMS_OBJCARC_BLOTMAPVECTOR_H + #include "llvm/ADT/DenseMap.h" -#include <algorithm> +#include <cassert> +#include <cstddef> +#include <utility> #include <vector> namespace llvm { + /// \brief An associative container with fast insertion-order (deterministic) /// iteration over its elements. Plus the special blot operation. template <class KeyT, class ValueT> class BlotMapVector { /// Map keys to indices in Vector. - typedef DenseMap<KeyT, size_t> MapTy; + using MapTy = DenseMap<KeyT, size_t>; MapTy Map; - typedef std::vector<std::pair<KeyT, ValueT>> VectorTy; /// Keys and values. + using VectorTy = std::vector<std::pair<KeyT, ValueT>>; VectorTy Vector; public: - typedef typename VectorTy::iterator iterator; - typedef typename VectorTy::const_iterator const_iterator; - iterator begin() { return Vector.begin(); } - iterator end() { return Vector.end(); } - const_iterator begin() const { return Vector.begin(); } - const_iterator end() const { return Vector.end(); } - #ifdef EXPENSIVE_CHECKS ~BlotMapVector() { assert(Vector.size() >= Map.size()); // May differ due to blotting. 
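
For readers new to BlotMapVector: iteration follows insertion order (hence deterministic), and the blot operation the class is named after clears an entry in place rather than erasing it, which is why the destructor above allows Vector.size() >= Map.size(). A tiny usage sketch, assuming the blot(Key) member the class is named after:

    #include "BlotMapVector.h"

    void blotExample() {
      llvm::BlotMapVector<int, const char *> MV;
      MV[1] = "one";
      MV[2] = "two";
      MV.blot(1);               // the slot for key 1 becomes a default pair; key 2 keeps its position
      for (const auto &KV : MV)
        (void)KV;               // visits the blotted slot too, still in insertion order
    }
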
@@ -46,6 +45,14 @@ public: } #endif + using iterator = typename VectorTy::iterator; + using const_iterator = typename VectorTy::const_iterator; + + iterator begin() { return Vector.begin(); } + iterator end() { return Vector.end(); } + const_iterator begin() const { return Vector.begin(); } + const_iterator end() const { return Vector.end(); } + ValueT &operator[](const KeyT &Arg) { std::pair<typename MapTy::iterator, bool> Pair = Map.insert(std::make_pair(Arg, size_t(0))); @@ -105,4 +112,7 @@ public: return Map.empty(); } }; -} // + +} // end namespace llvm + +#endif // LLVM_LIB_TRANSFORMS_OBJCARC_BLOTMAPVECTOR_H diff --git a/lib/Transforms/ObjCARC/ObjCARC.cpp b/lib/Transforms/ObjCARC/ObjCARC.cpp index 688dd12c408a..c30aaebd0f4d 100644 --- a/lib/Transforms/ObjCARC/ObjCARC.cpp +++ b/lib/Transforms/ObjCARC/ObjCARC.cpp @@ -14,7 +14,6 @@ //===----------------------------------------------------------------------===// #include "ObjCARC.h" -#include "llvm-c/Core.h" #include "llvm-c/Initialization.h" #include "llvm/InitializePasses.h" diff --git a/lib/Transforms/ObjCARC/ObjCARCContract.cpp b/lib/Transforms/ObjCARC/ObjCARCContract.cpp index e70e7591f6a7..c4e61218f3f3 100644 --- a/lib/Transforms/ObjCARC/ObjCARCContract.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCContract.cpp @@ -248,7 +248,7 @@ static StoreInst *findSafeStoreForStoreStrongContraction(LoadInst *Load, // Ok, now we know we have not seen a store yet. See if Inst can write to // our load location, if it can not, just ignore the instruction. - if (!(AA->getModRefInfo(Inst, Loc) & MRI_Mod)) + if (!isModSet(AA->getModRefInfo(Inst, Loc))) continue; Store = dyn_cast<StoreInst>(Inst); diff --git a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp index 8f3a33f66c7f..99ed6863c22e 100644 --- a/lib/Transforms/ObjCARC/ObjCARCOpts.cpp +++ b/lib/Transforms/ObjCARC/ObjCARCOpts.cpp @@ -6,6 +6,7 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// /// \file /// This file defines ObjC ARC optimizations. ARC stands for Automatic /// Reference Counting and is a system for managing reference counts for objects @@ -21,7 +22,7 @@ /// WARNING: This file knows about how certain Objective-C library functions are /// used. Naive LLVM IR transformations which would otherwise be /// behavior-preserving may break these assumptions. 
-/// +// //===----------------------------------------------------------------------===// #include "ARCRuntimeEntryPoints.h" @@ -31,16 +32,41 @@ #include "ProvenanceAnalysis.h" #include "PtrState.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/ObjCARCAliasAnalysis.h" +#include "llvm/Analysis/ObjCARCAnalysisUtils.h" +#include "llvm/Analysis/ObjCARCInstKind.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" -#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <iterator> +#include <utility> using namespace llvm; using namespace llvm::objcarc; @@ -147,14 +173,15 @@ STATISTIC(NumReleasesAfterOpt, #endif namespace { + /// \brief Per-BasicBlock state. class BBState { /// The number of unique control paths from the entry which can reach this /// block. - unsigned TopDownPathCount; + unsigned TopDownPathCount = 0; /// The number of unique control paths to exits from this block. - unsigned BottomUpPathCount; + unsigned BottomUpPathCount = 0; /// The top-down traversal uses this to record information known about a /// pointer at the bottom of each block. @@ -175,10 +202,10 @@ namespace { public: static const unsigned OverflowOccurredValue; - BBState() : TopDownPathCount(0), BottomUpPathCount(0) { } + BBState() = default; - typedef decltype(PerPtrTopDown)::iterator top_down_ptr_iterator; - typedef decltype(PerPtrTopDown)::const_iterator const_top_down_ptr_iterator; + using top_down_ptr_iterator = decltype(PerPtrTopDown)::iterator; + using const_top_down_ptr_iterator = decltype(PerPtrTopDown)::const_iterator; top_down_ptr_iterator top_down_ptr_begin() { return PerPtrTopDown.begin(); } top_down_ptr_iterator top_down_ptr_end() { return PerPtrTopDown.end(); } @@ -192,9 +219,9 @@ namespace { return !PerPtrTopDown.empty(); } - typedef decltype(PerPtrBottomUp)::iterator bottom_up_ptr_iterator; - typedef decltype( - PerPtrBottomUp)::const_iterator const_bottom_up_ptr_iterator; + using bottom_up_ptr_iterator = decltype(PerPtrBottomUp)::iterator; + using const_bottom_up_ptr_iterator = + decltype(PerPtrBottomUp)::const_iterator; bottom_up_ptr_iterator bottom_up_ptr_begin() { return PerPtrBottomUp.begin(); @@ -270,7 +297,8 @@ namespace { } // Specialized CFG utilities. 
- typedef SmallVectorImpl<BasicBlock *>::const_iterator edge_iterator; + using edge_iterator = SmallVectorImpl<BasicBlock *>::const_iterator; + edge_iterator pred_begin() const { return Preds.begin(); } edge_iterator pred_end() const { return Preds.end(); } edge_iterator succ_begin() const { return Succs.begin(); } @@ -282,13 +310,16 @@ namespace { bool isExit() const { return Succs.empty(); } }; - const unsigned BBState::OverflowOccurredValue = 0xffffffff; -} +} // end anonymous namespace + +const unsigned BBState::OverflowOccurredValue = 0xffffffff; namespace llvm { + raw_ostream &operator<<(raw_ostream &OS, BBState &BBState) LLVM_ATTRIBUTE_UNUSED; -} + +} // end namespace llvm void BBState::InitFromPred(const BBState &Other) { PerPtrTopDown = Other.PerPtrTopDown; @@ -391,7 +422,7 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, BBState &BBInfo) { // Dump the pointers we are tracking. OS << " TopDown State:\n"; if (!BBInfo.hasTopDownPtrs()) { - DEBUG(llvm::dbgs() << " NONE!\n"); + DEBUG(dbgs() << " NONE!\n"); } else { for (auto I = BBInfo.top_down_ptr_begin(), E = BBInfo.top_down_ptr_end(); I != E; ++I) { @@ -411,7 +442,7 @@ raw_ostream &llvm::operator<<(raw_ostream &OS, BBState &BBInfo) { OS << " BottomUp State:\n"; if (!BBInfo.hasBottomUpPtrs()) { - DEBUG(llvm::dbgs() << " NONE!\n"); + DEBUG(dbgs() << " NONE!\n"); } else { for (auto I = BBInfo.bottom_up_ptr_begin(), E = BBInfo.bottom_up_ptr_end(); I != E; ++I) { @@ -513,13 +544,16 @@ namespace { public: static char ID; + ObjCARCOpt() : FunctionPass(ID) { initializeObjCARCOptPass(*PassRegistry::getPassRegistry()); } }; -} + +} // end anonymous namespace char ObjCARCOpt::ID = 0; + INITIALIZE_PASS_BEGIN(ObjCARCOpt, "objc-arc", "ObjC ARC optimization", false, false) INITIALIZE_PASS_DEPENDENCY(ObjCARCAAWrapperPass) @@ -643,7 +677,6 @@ void ObjCARCOpt::OptimizeAutoreleaseRVCall(Function &F, Class = ARCInstKind::Autorelease; DEBUG(dbgs() << "New: " << *AutoreleaseRV << "\n"); - } /// Visit each call, one at a time, and make simplifications without doing any @@ -692,7 +725,7 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { new StoreInst(UndefValue::get(cast<PointerType>(Ty)->getElementType()), Constant::getNullValue(Ty), CI); - llvm::Value *NewValue = UndefValue::get(CI->getType()); + Value *NewValue = UndefValue::get(CI->getType()); DEBUG(dbgs() << "A null pointer-to-weak-pointer is undefined behavior." "\nOld = " << *CI << "\nNew = " << *NewValue << "\n"); CI->replaceAllUsesWith(NewValue); @@ -712,7 +745,7 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { Constant::getNullValue(Ty), CI); - llvm::Value *NewValue = UndefValue::get(CI->getType()); + Value *NewValue = UndefValue::get(CI->getType()); DEBUG(dbgs() << "A null pointer-to-weak-pointer is undefined behavior." "\nOld = " << *CI << "\nNew = " << *NewValue << "\n"); @@ -808,9 +841,14 @@ void ObjCARCOpt::OptimizeIndividualCalls(Function &F) { // If Arg is a PHI, and one or more incoming values to the // PHI are null, and the call is control-equivalent to the PHI, and there - // are no relevant side effects between the PHI and the call, the call - // could be pushed up to just those paths with non-null incoming values. - // For now, don't bother splitting critical edges for this. + // are no relevant side effects between the PHI and the call, and the call + // is not a release that doesn't have the clang.imprecise_release tag, the + // call could be pushed up to just those paths with non-null incoming + // values. For now, don't bother splitting critical edges for this. 
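
The guard added just below skips releases that lack the clang.imprecise_release tag, looked up through the pass's cached metadata kind IDs. The same test written with the plain string-based lookup, as a sketch rather than the patch's code:

    #include "llvm/IR/Instruction.h"

    static bool hasImpreciseReleaseTag(const llvm::Instruction *Inst) {
      // Same tag the hunk below checks via MDKindCache, looked up by name here.
      return Inst->getMetadata("clang.imprecise_release") != nullptr;
    }
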
+ if (Class == ARCInstKind::Release && + !Inst->getMetadata(MDKindCache.get(ARCMDKindID::ImpreciseRelease))) + continue; + SmallVector<std::pair<Instruction *, const Value *>, 4> Worklist; Worklist.push_back(std::make_pair(Inst, Arg)); do { @@ -1040,12 +1078,11 @@ ObjCARCOpt::CheckForCFGHazards(const BasicBlock *BB, continue; break; } - case S_CanRelease: { + case S_CanRelease: CheckForCanReleaseCFGHazard(SuccSSeq, SuccSRRIKnownSafe, S, SomeSuccHasSame, AllSuccsHaveSame, NotAllSeqEqualButKnownSafe); break; - } case S_Retain: case S_None: case S_Stop: @@ -1100,7 +1137,7 @@ bool ObjCARCOpt::VisitInstructionBottomUp( // Don't do retain+release tracking for ARCInstKind::RetainRV, because // it's better to let it remain as the first instruction after a call. if (Class != ARCInstKind::RetainRV) { - DEBUG(llvm::dbgs() << " Matching with: " << *Inst << "\n"); + DEBUG(dbgs() << " Matching with: " << *Inst << "\n"); Retains[Inst] = S.GetRRInfo(); } S.ClearSequenceProgress(); @@ -1142,7 +1179,6 @@ bool ObjCARCOpt::VisitInstructionBottomUp( bool ObjCARCOpt::VisitBottomUp(BasicBlock *BB, DenseMap<const BasicBlock *, BBState> &BBStates, BlotMapVector<Value *, RRInfo> &Retains) { - DEBUG(dbgs() << "\n== ObjCARCOpt::VisitBottomUp ==\n"); bool NestingDetected = false; @@ -1166,8 +1202,8 @@ bool ObjCARCOpt::VisitBottomUp(BasicBlock *BB, } } - DEBUG(llvm::dbgs() << "Before:\n" << BBStates[BB] << "\n" - << "Performing Dataflow:\n"); + DEBUG(dbgs() << "Before:\n" << BBStates[BB] << "\n" + << "Performing Dataflow:\n"); // Visit all the instructions, bottom-up. for (BasicBlock::iterator I = BB->end(), E = BB->begin(); I != E; --I) { @@ -1192,7 +1228,7 @@ bool ObjCARCOpt::VisitBottomUp(BasicBlock *BB, NestingDetected |= VisitInstructionBottomUp(II, BB, Retains, MyStates); } - DEBUG(llvm::dbgs() << "\nFinal State:\n" << BBStates[BB] << "\n"); + DEBUG(dbgs() << "\nFinal State:\n" << BBStates[BB] << "\n"); return NestingDetected; } @@ -1205,7 +1241,7 @@ ObjCARCOpt::VisitInstructionTopDown(Instruction *Inst, ARCInstKind Class = GetARCInstKind(Inst); const Value *Arg = nullptr; - DEBUG(llvm::dbgs() << " Class: " << Class << "\n"); + DEBUG(dbgs() << " Class: " << Class << "\n"); switch (Class) { case ARCInstKind::RetainBlock: @@ -1231,7 +1267,7 @@ ObjCARCOpt::VisitInstructionTopDown(Instruction *Inst, if (S.MatchWithRelease(MDKindCache, Inst)) { // If we succeed, copy S's RRInfo into the Release -> {Retain Set // Map}. Then we clear S. - DEBUG(llvm::dbgs() << " Matching with: " << *Inst << "\n"); + DEBUG(dbgs() << " Matching with: " << *Inst << "\n"); Releases[Inst] = S.GetRRInfo(); S.ClearSequenceProgress(); } @@ -1293,8 +1329,8 @@ ObjCARCOpt::VisitTopDown(BasicBlock *BB, } } - DEBUG(llvm::dbgs() << "Before:\n" << BBStates[BB] << "\n" - << "Performing Dataflow:\n"); + DEBUG(dbgs() << "Before:\n" << BBStates[BB] << "\n" + << "Performing Dataflow:\n"); // Visit all the instructions, top-down. 
for (Instruction &Inst : *BB) { @@ -1303,10 +1339,10 @@ ObjCARCOpt::VisitTopDown(BasicBlock *BB, NestingDetected |= VisitInstructionTopDown(&Inst, Releases, MyStates); } - DEBUG(llvm::dbgs() << "\nState Before Checking for CFG Hazards:\n" - << BBStates[BB] << "\n\n"); + DEBUG(dbgs() << "\nState Before Checking for CFG Hazards:\n" + << BBStates[BB] << "\n\n"); CheckForCFGHazards(BB, BBStates, MyStates); - DEBUG(llvm::dbgs() << "Final State:\n" << BBStates[BB] << "\n"); + DEBUG(dbgs() << "Final State:\n" << BBStates[BB] << "\n"); return NestingDetected; } @@ -1395,7 +1431,6 @@ bool ObjCARCOpt::Visit(Function &F, DenseMap<const BasicBlock *, BBState> &BBStates, BlotMapVector<Value *, RRInfo> &Retains, DenseMap<Value *, RRInfo> &Releases) { - // Use reverse-postorder traversals, because we magically know that loops // will be well behaved, i.e. they won't repeatedly call retain on a single // pointer without doing a release. We can't use the ReversePostOrderTraversal @@ -1409,12 +1444,12 @@ bool ObjCARCOpt::Visit(Function &F, // Use reverse-postorder on the reverse CFG for bottom-up. bool BottomUpNestingDetected = false; - for (BasicBlock *BB : reverse(ReverseCFGPostOrder)) + for (BasicBlock *BB : llvm::reverse(ReverseCFGPostOrder)) BottomUpNestingDetected |= VisitBottomUp(BB, BBStates, Retains); // Use reverse-postorder for top-down. bool TopDownNestingDetected = false; - for (BasicBlock *BB : reverse(PostOrder)) + for (BasicBlock *BB : llvm::reverse(PostOrder)) TopDownNestingDetected |= VisitTopDown(BB, BBStates, Releases); return TopDownNestingDetected && BottomUpNestingDetected; @@ -1471,7 +1506,6 @@ void ObjCARCOpt::MoveCalls(Value *Arg, RRInfo &RetainsToMove, DeadInsts.push_back(OrigRelease); DEBUG(dbgs() << "Deleting release: " << *OrigRelease << "\n"); } - } bool ObjCARCOpt::PairUpRetainsAndReleases( @@ -1519,7 +1553,6 @@ bool ObjCARCOpt::PairUpRetainsAndReleases( return false; if (ReleasesToMove.Calls.insert(NewRetainRelease).second) { - // If we overflow when we compute the path count, don't remove/move // anything. const BBState &NRRBBState = BBStates[NewRetainRelease->getParent()]; @@ -1646,8 +1679,9 @@ bool ObjCARCOpt::PairUpRetainsAndReleases( // At this point, we are not going to remove any RR pairs, but we still are // able to move RR pairs. If one of our pointers is afflicted with // CFGHazards, we cannot perform such code motion so exit early. - const bool WillPerformCodeMotion = RetainsToMove.ReverseInsertPts.size() || - ReleasesToMove.ReverseInsertPts.size(); + const bool WillPerformCodeMotion = + !RetainsToMove.ReverseInsertPts.empty() || + !ReleasesToMove.ReverseInsertPts.empty(); if (CFGHazardAfflicted && WillPerformCodeMotion) return false; } @@ -2027,7 +2061,8 @@ void ObjCARCOpt::OptimizeReturns(Function &F) { continue; CallInst *Retain = FindPredecessorRetainWithSafePath( - Arg, &BB, Autorelease, DependingInstructions, Visited, PA); + Arg, Autorelease->getParent(), Autorelease, DependingInstructions, + Visited, PA); DependingInstructions.clear(); Visited.clear(); @@ -2058,10 +2093,10 @@ void ObjCARCOpt::OptimizeReturns(Function &F) { #ifndef NDEBUG void ObjCARCOpt::GatherStatistics(Function &F, bool AfterOptimization) { - llvm::Statistic &NumRetains = - AfterOptimization? NumRetainsAfterOpt : NumRetainsBeforeOpt; - llvm::Statistic &NumReleases = - AfterOptimization? NumReleasesAfterOpt : NumReleasesBeforeOpt; + Statistic &NumRetains = + AfterOptimization ? NumRetainsAfterOpt : NumRetainsBeforeOpt; + Statistic &NumReleases = + AfterOptimization ? 
NumReleasesAfterOpt : NumReleasesBeforeOpt; for (inst_iterator I = inst_begin(&F), E = inst_end(&F); I != E; ) { Instruction *Inst = &*I++; diff --git a/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp b/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp index 62fc52f6d091..f89fc8eb62aa 100644 --- a/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp +++ b/lib/Transforms/ObjCARC/ProvenanceAnalysis.cpp @@ -6,6 +6,7 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// /// \file /// /// This file defines a special form of Alias Analysis called ``Provenance @@ -19,13 +20,21 @@ /// WARNING: This file knows about how certain Objective-C library functions are /// used. Naive LLVM IR transformations which would otherwise be /// behavior-preserving may break these assumptions. -/// +// //===----------------------------------------------------------------------===// #include "ProvenanceAnalysis.h" -#include "ObjCARC.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/ObjCARCAnalysisUtils.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include <utility> using namespace llvm; using namespace llvm::objcarc; diff --git a/lib/Transforms/ObjCARC/ProvenanceAnalysis.h b/lib/Transforms/ObjCARC/ProvenanceAnalysis.h index 1a12b659e5a3..5e676167a6a1 100644 --- a/lib/Transforms/ObjCARC/ProvenanceAnalysis.h +++ b/lib/Transforms/ObjCARC/ProvenanceAnalysis.h @@ -1,4 +1,4 @@ -//===- ProvenanceAnalysis.h - ObjC ARC Optimization ---*- C++ -*-----------===// +//===- ProvenanceAnalysis.h - ObjC ARC Optimization -------------*- C++ -*-===// // // The LLVM Compiler Infrastructure // @@ -6,6 +6,7 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// /// \file /// /// This file declares a special form of Alias Analysis called ``Provenance @@ -19,7 +20,7 @@ /// WARNING: This file knows about how certain Objective-C library functions are /// used. Naive LLVM IR transformations which would otherwise be /// behavior-preserving may break these assumptions. 
-/// +// //===----------------------------------------------------------------------===// #ifndef LLVM_LIB_TRANSFORMS_OBJCARC_PROVENANCEANALYSIS_H @@ -27,15 +28,15 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/Analysis/AliasAnalysis.h" +#include <utility> namespace llvm { - class Value; - class DataLayout; - class PHINode; - class SelectInst; -} -namespace llvm { +class DataLayout; +class PHINode; +class SelectInst; +class Value; + namespace objcarc { /// \brief This is similar to BasicAliasAnalysis, and it uses many of the same @@ -50,19 +51,19 @@ namespace objcarc { class ProvenanceAnalysis { AliasAnalysis *AA; - typedef std::pair<const Value *, const Value *> ValuePairTy; - typedef DenseMap<ValuePairTy, bool> CachedResultsTy; + using ValuePairTy = std::pair<const Value *, const Value *>; + using CachedResultsTy = DenseMap<ValuePairTy, bool>; + CachedResultsTy CachedResults; bool relatedCheck(const Value *A, const Value *B, const DataLayout &DL); bool relatedSelect(const SelectInst *A, const Value *B); bool relatedPHI(const PHINode *A, const Value *B); - void operator=(const ProvenanceAnalysis &) = delete; - ProvenanceAnalysis(const ProvenanceAnalysis &) = delete; - public: - ProvenanceAnalysis() {} + ProvenanceAnalysis() = default; + ProvenanceAnalysis(const ProvenanceAnalysis &) = delete; + ProvenanceAnalysis &operator=(const ProvenanceAnalysis &) = delete; void setAA(AliasAnalysis *aa) { AA = aa; } @@ -76,6 +77,7 @@ public: }; } // end namespace objcarc + } // end namespace llvm -#endif +#endif // LLVM_LIB_TRANSFORMS_OBJCARC_PROVENANCEANALYSIS_H diff --git a/lib/Transforms/ObjCARC/PtrState.cpp b/lib/Transforms/ObjCARC/PtrState.cpp index d13e941044f1..e1774b88fd35 100644 --- a/lib/Transforms/ObjCARC/PtrState.cpp +++ b/lib/Transforms/ObjCARC/PtrState.cpp @@ -1,4 +1,4 @@ -//===--- PtrState.cpp -----------------------------------------------------===// +//===- PtrState.cpp -------------------------------------------------------===// // // The LLVM Compiler Infrastructure // @@ -10,8 +10,20 @@ #include "PtrState.h" #include "DependencyAnalysis.h" #include "ObjCARC.h" +#include "llvm/Analysis/ObjCARCAnalysisUtils.h" +#include "llvm/Analysis/ObjCARCInstKind.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <iterator> +#include <utility> using namespace llvm; using namespace llvm::objcarc; @@ -250,10 +262,14 @@ void BottomUpPtrState::HandlePotentialUse(BasicBlock *BB, Instruction *Inst, // If this is an invoke instruction, we're scanning it as part of // one of its successor blocks, since we can't insert code after it // in its own block, and we don't want to split critical edges. - if (isa<InvokeInst>(Inst)) - InsertReverseInsertPt(&*BB->getFirstInsertionPt()); - else - InsertReverseInsertPt(&*++Inst->getIterator()); + BasicBlock::iterator InsertAfter; + if (isa<InvokeInst>(Inst)) { + const auto IP = BB->getFirstInsertionPt(); + InsertAfter = IP == BB->end() ? std::prev(BB->end()) : IP; + } else { + InsertAfter = std::next(Inst->getIterator()); + } + InsertReverseInsertPt(&*InsertAfter); }; // Check for possible direct uses. 
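The HandlePotentialUse hunk above turns the reverse-insertion-point choice into an explicit iterator computation: for an invoke (which is scanned from its successor block and leaves no room after itself in its own block) the point becomes the successor's first insertion point, with a guard for that iterator being end(). A minimal sketch of the same iterator logic, assuming the scanned block is non-empty; the helper name is illustrative and not part of the patch:

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Instructions.h"
#include "llvm/Support/Casting.h"
#include <iterator>

using namespace llvm;

// Compute the point at which new code may be inserted "after" Inst while
// scanning block BB bottom-up. Invokes terminate their own block, so the
// insertion point falls into BB instead; if BB's first insertion point is
// end() (e.g. the block holds only PHIs or an EH pad), back up to the last
// instruction rather than dereferencing end().
static BasicBlock::iterator reverseInsertPtAfter(Instruction *Inst,
                                                 BasicBlock *BB) {
  if (isa<InvokeInst>(Inst)) {
    BasicBlock::iterator IP = BB->getFirstInsertionPt();
    return IP == BB->end() ? std::prev(BB->end()) : IP;
  }
  return std::next(Inst->getIterator());
}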
diff --git a/lib/Transforms/ObjCARC/PtrState.h b/lib/Transforms/ObjCARC/PtrState.h index 87298fa59bfd..e1e95afcf76b 100644 --- a/lib/Transforms/ObjCARC/PtrState.h +++ b/lib/Transforms/ObjCARC/PtrState.h @@ -1,4 +1,4 @@ -//===--- PtrState.h - ARC State for a Ptr -------------------*- C++ -*-----===// +//===- PtrState.h - ARC State for a Ptr -------------------------*- C++ -*-===// // // The LLVM Compiler Infrastructure // @@ -19,12 +19,16 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/ObjCARCInstKind.h" -#include "llvm/IR/Instruction.h" -#include "llvm/IR/Value.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" +#include "llvm/Support/Compiler.h" namespace llvm { + +class BasicBlock; +class Instruction; +class MDNode; +class raw_ostream; +class Value; + namespace objcarc { class ARCMDKindCache; @@ -63,14 +67,14 @@ struct RRInfo { /// of any intervening side effects. /// /// KnownSafe is true when either of these conditions is satisfied. - bool KnownSafe; + bool KnownSafe = false; /// True of the objc_release calls are all marked with the "tail" keyword. - bool IsTailCallRelease; + bool IsTailCallRelease = false; /// If the Calls are objc_release calls and they all have a /// clang.imprecise_release tag, this is the metadata tag. - MDNode *ReleaseMetadata; + MDNode *ReleaseMetadata = nullptr; /// For a top-down sequence, the set of objc_retains or /// objc_retainBlocks. For bottom-up, the set of objc_releases. @@ -82,11 +86,9 @@ struct RRInfo { /// If this is true, we cannot perform code motion but can still remove /// retain/release pairs. - bool CFGHazardAfflicted; + bool CFGHazardAfflicted = false; - RRInfo() - : KnownSafe(false), IsTailCallRelease(false), ReleaseMetadata(nullptr), - CFGHazardAfflicted(false) {} + RRInfo() = default; void clear(); @@ -100,11 +102,11 @@ struct RRInfo { class PtrState { protected: /// True if the reference count is known to be incremented. - bool KnownPositiveRefCount; + bool KnownPositiveRefCount = false; /// True if we've seen an opportunity for partial RR elimination, such as /// pushing calls into a CFG triangle or into one side of a CFG diamond. - bool Partial; + bool Partial = false; /// The current position in the sequence. unsigned char Seq : 8; @@ -112,7 +114,7 @@ protected: /// Unidirectional information about the current sequence. RRInfo RRI; - PtrState() : KnownPositiveRefCount(false), Partial(false), Seq(S_None) {} + PtrState() : Seq(S_None) {} public: bool IsKnownSafe() const { return RRI.KnownSafe; } @@ -165,7 +167,7 @@ public: }; struct BottomUpPtrState : PtrState { - BottomUpPtrState() : PtrState() {} + BottomUpPtrState() = default; /// (Re-)Initialize this bottom up pointer returning true if we detected a /// pointer with nested releases. @@ -186,7 +188,7 @@ struct BottomUpPtrState : PtrState { }; struct TopDownPtrState : PtrState { - TopDownPtrState() : PtrState() {} + TopDownPtrState() = default; /// (Re-)Initialize this bottom up pointer returning true if we detected a /// pointer with nested releases. 
@@ -205,6 +207,7 @@ struct TopDownPtrState : PtrState { }; } // end namespace objcarc + } // end namespace llvm -#endif +#endif // LLVM_LIB_TRANSFORMS_OBJCARC_PTRSTATE_H diff --git a/lib/Transforms/Scalar/ADCE.cpp b/lib/Transforms/Scalar/ADCE.cpp index 5b467dc9fe12..1e683db50206 100644 --- a/lib/Transforms/Scalar/ADCE.cpp +++ b/lib/Transforms/Scalar/ADCE.cpp @@ -15,8 +15,10 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/ADCE.h" - +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/PostOrderIterator.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" @@ -27,13 +29,29 @@ #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Function.h" #include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" #include "llvm/ProfileData/InstrProf.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include <cassert> +#include <cstddef> +#include <utility> + using namespace llvm; #define DEBUG_TYPE "adce" @@ -52,10 +70,12 @@ static cl::opt<bool> RemoveLoops("adce-remove-loops", cl::init(false), cl::Hidden); namespace { + /// Information about Instructions struct InstInfoType { /// True if the associated instruction is live. bool Live = false; + /// Quick access to information for block containing associated Instruction. struct BlockInfoType *Block = nullptr; }; @@ -64,10 +84,13 @@ struct InstInfoType { struct BlockInfoType { /// True when this block contains a live instructions. bool Live = false; + /// True when this block ends in an unconditional branch. bool UnconditionalBranch = false; + /// True when this block is known to have live PHI nodes. bool HasLivePhiNodes = false; + /// Control dependence sources need to be live for this block. bool CFLive = false; @@ -75,8 +98,6 @@ struct BlockInfoType { /// holds the value &InstInfo[Terminator] InstInfoType *TerminatorLiveInfo = nullptr; - bool terminatorIsLive() const { return TerminatorLiveInfo->Live; } - /// Corresponding BasicBlock. BasicBlock *BB = nullptr; @@ -85,14 +106,21 @@ struct BlockInfoType { /// Post-order numbering of reverse control flow graph. unsigned PostOrder; + + bool terminatorIsLive() const { return TerminatorLiveInfo->Live; } }; class AggressiveDeadCodeElimination { Function &F; + + // ADCE does not use DominatorTree per se, but it updates it to preserve the + // analysis. + DominatorTree &DT; PostDominatorTree &PDT; /// Mapping of blocks to associated information, an element in BlockInfoVec. - DenseMap<BasicBlock *, BlockInfoType> BlockInfo; + /// Use MapVector to get deterministic iteration order. + MapVector<BasicBlock *, BlockInfoType> BlockInfo; bool isLive(BasicBlock *BB) { return BlockInfo[BB].Live; } /// Mapping of instructions to associated information. @@ -102,6 +130,7 @@ class AggressiveDeadCodeElimination { /// Instructions known to be live where we need to mark /// reaching definitions as live. 
SmallVector<Instruction *, 128> Worklist; + /// Debug info scopes around a live instruction. SmallPtrSet<const Metadata *, 32> AliveScopes; @@ -116,15 +145,19 @@ class AggressiveDeadCodeElimination { /// Set up auxiliary data structures for Instructions and BasicBlocks and /// initialize the Worklist to the set of must-be-live Instruscions. void initialize(); + /// Return true for operations which are always treated as live. bool isAlwaysLive(Instruction &I); + /// Return true for instrumentation instructions for value profiling. bool isInstrumentsConstant(Instruction &I); /// Propagate liveness to reaching definitions. void markLiveInstructions(); + /// Mark an instruction as live. void markLive(Instruction *I); + /// Mark a block as live. void markLive(BlockInfoType &BB); void markLive(BasicBlock *BB) { markLive(BlockInfo[BB]); } @@ -157,11 +190,14 @@ class AggressiveDeadCodeElimination { void makeUnconditional(BasicBlock *BB, BasicBlock *Target); public: - AggressiveDeadCodeElimination(Function &F, PostDominatorTree &PDT) - : F(F), PDT(PDT) {} + AggressiveDeadCodeElimination(Function &F, DominatorTree &DT, + PostDominatorTree &PDT) + : F(F), DT(DT), PDT(PDT) {} + bool performDeadCodeElimination(); }; -} + +} // end anonymous namespace bool AggressiveDeadCodeElimination::performDeadCodeElimination() { initialize(); @@ -175,7 +211,6 @@ static bool isUnconditionalBranch(TerminatorInst *Term) { } void AggressiveDeadCodeElimination::initialize() { - auto NumBlocks = F.size(); // We will have an entry in the map for each block so we grow the @@ -217,7 +252,8 @@ void AggressiveDeadCodeElimination::initialize() { // to recording which nodes have been visited we also record whether // a node is currently on the "stack" of active ancestors of the current // node. - typedef DenseMap<BasicBlock *, bool> StatusMap ; + using StatusMap = DenseMap<BasicBlock *, bool>; + class DFState : public StatusMap { public: std::pair<StatusMap::iterator, bool> insert(BasicBlock *BB) { @@ -253,27 +289,23 @@ void AggressiveDeadCodeElimination::initialize() { } } - // Mark blocks live if there is no path from the block to the - // return of the function or a successor for which this is true. - // This protects IDFCalculator which cannot handle such blocks. - for (auto &BBInfoPair : BlockInfo) { - auto &BBInfo = BBInfoPair.second; - if (BBInfo.terminatorIsLive()) - continue; - auto *BB = BBInfo.BB; - if (!PDT.getNode(BB)) { - DEBUG(dbgs() << "Not post-dominated by return: " << BB->getName() + // Mark blocks live if there is no path from the block to a + // return of the function. + // We do this by seeing which of the postdomtree root children exit the + // program, and for all others, mark the subtree live. + for (auto &PDTChild : children<DomTreeNode *>(PDT.getRootNode())) { + auto *BB = PDTChild->getBlock(); + auto &Info = BlockInfo[BB]; + // Real function return + if (isa<ReturnInst>(Info.Terminator)) { + DEBUG(dbgs() << "post-dom root child is a return: " << BB->getName() << '\n';); - markLive(BBInfo.Terminator); continue; } - for (auto *Succ : successors(BB)) - if (!PDT.getNode(Succ)) { - DEBUG(dbgs() << "Successor not post-dominated by return: " - << BB->getName() << '\n';); - markLive(BBInfo.Terminator); - break; - } + + // This child is something else, like an infinite loop. 
+ for (auto DFNode : depth_first(PDTChild)) + markLive(BlockInfo[DFNode->getBlock()].Terminator); } // Treat the entry block as always live @@ -318,7 +350,6 @@ bool AggressiveDeadCodeElimination::isInstrumentsConstant(Instruction &I) { } void AggressiveDeadCodeElimination::markLiveInstructions() { - // Propagate liveness backwards to operands. do { // Worklist holds newly discovered live instructions @@ -343,7 +374,6 @@ void AggressiveDeadCodeElimination::markLiveInstructions() { } void AggressiveDeadCodeElimination::markLive(Instruction *I) { - auto &Info = InstInfo[I]; if (Info.Live) return; @@ -430,7 +460,6 @@ void AggressiveDeadCodeElimination::markPhiLive(PHINode *PN) { } void AggressiveDeadCodeElimination::markLiveBranchesFromControlDependences() { - if (BlocksWithDeadTerminators.empty()) return; @@ -469,7 +498,6 @@ void AggressiveDeadCodeElimination::markLiveBranchesFromControlDependences() { // //===----------------------------------------------------------------------===// bool AggressiveDeadCodeElimination::removeDeadInstructions() { - // Updates control and dataflow around dead blocks updateDeadRegions(); @@ -527,7 +555,6 @@ bool AggressiveDeadCodeElimination::removeDeadInstructions() { // A dead region is the set of dead blocks with a common live post-dominator. void AggressiveDeadCodeElimination::updateDeadRegions() { - DEBUG({ dbgs() << "final dead terminator blocks: " << '\n'; for (auto *BB : BlocksWithDeadTerminators) @@ -561,21 +588,40 @@ void AggressiveDeadCodeElimination::updateDeadRegions() { } assert((PreferredSucc && PreferredSucc->PostOrder > 0) && "Failed to find safe successor for dead branch"); + + // Collect removed successors to update the (Post)DominatorTrees. + SmallPtrSet<BasicBlock *, 4> RemovedSuccessors; bool First = true; for (auto *Succ : successors(BB)) { - if (!First || Succ != PreferredSucc->BB) + if (!First || Succ != PreferredSucc->BB) { Succ->removePredecessor(BB); - else + RemovedSuccessors.insert(Succ); + } else First = false; } makeUnconditional(BB, PreferredSucc->BB); + + // Inform the dominators about the deleted CFG edges. + SmallVector<DominatorTree::UpdateType, 4> DeletedEdges; + for (auto *Succ : RemovedSuccessors) { + // It might have happened that the same successor appeared multiple times + // and the CFG edge wasn't really removed. 
+ if (Succ != PreferredSucc->BB) { + DEBUG(dbgs() << "ADCE: (Post)DomTree edge enqueued for deletion" + << BB->getName() << " -> " << Succ->getName() << "\n"); + DeletedEdges.push_back({DominatorTree::Delete, BB, Succ}); + } + } + + DT.applyUpdates(DeletedEdges); + PDT.applyUpdates(DeletedEdges); + NumBranchesRemoved += 1; } } // reverse top-sort order void AggressiveDeadCodeElimination::computeReversePostOrder() { - // This provides a post-order numbering of the reverse control flow graph // Note that it is incomplete in the presence of infinite loops but we don't // need numbers blocks which don't reach the end of the functions since @@ -613,6 +659,9 @@ void AggressiveDeadCodeElimination::makeUnconditional(BasicBlock *BB, InstInfo[NewTerm].Live = true; if (const DILocation *DL = PredTerm->getDebugLoc()) NewTerm->setDebugLoc(DL); + + InstInfo.erase(PredTerm); + PredTerm->eraseFromParent(); } //===----------------------------------------------------------------------===// @@ -621,19 +670,24 @@ void AggressiveDeadCodeElimination::makeUnconditional(BasicBlock *BB, // //===----------------------------------------------------------------------===// PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &FAM) { + auto &DT = FAM.getResult<DominatorTreeAnalysis>(F); auto &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F); - if (!AggressiveDeadCodeElimination(F, PDT).performDeadCodeElimination()) + if (!AggressiveDeadCodeElimination(F, DT, PDT).performDeadCodeElimination()) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserveSet<CFGAnalyses>(); PA.preserve<GlobalsAA>(); + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<PostDominatorTreeAnalysis>(); return PA; } namespace { + struct ADCELegacyPass : public FunctionPass { static char ID; // Pass identification, replacement for typeid + ADCELegacyPass() : FunctionPass(ID) { initializeADCELegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -641,22 +695,34 @@ struct ADCELegacyPass : public FunctionPass { bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; + + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); - return AggressiveDeadCodeElimination(F, PDT).performDeadCodeElimination(); + return AggressiveDeadCodeElimination(F, DT, PDT) + .performDeadCodeElimination(); } void getAnalysisUsage(AnalysisUsage &AU) const override { + // We require DominatorTree here only to update and thus preserve it. 
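The updateDeadRegions hunk above batches the deleted CFG edges and hands the same list to both the DominatorTree and the PostDominatorTree, which is what lets ADCE now preserve both analyses. A hedged sketch of that batching pattern; the helper name is made up, and it assumes the caller has already rewritten BB's terminator and collected the successors it dropped, since the update API expects the IR to reflect the change first:

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Analysis/PostDominators.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Dominators.h"

using namespace llvm;

// Report every edge BB -> Succ that was just removed from the IR to both
// trees in a single batch, so the two trees stay consistent with each
// other and with the CFG.
static void notifyTreesOfDeletedEdges(BasicBlock *BB,
                                      ArrayRef<BasicBlock *> RemovedSuccs,
                                      DominatorTree &DT,
                                      PostDominatorTree &PDT) {
  SmallVector<DominatorTree::UpdateType, 4> DeletedEdges;
  for (BasicBlock *Succ : RemovedSuccs)
    DeletedEdges.push_back({DominatorTree::Delete, BB, Succ});
  DT.applyUpdates(DeletedEdges);
  PDT.applyUpdates(DeletedEdges);
}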
+ AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<PostDominatorTreeWrapperPass>(); if (!RemoveControlFlowFlag) AU.setPreservesCFG(); + else { + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<PostDominatorTreeWrapperPass>(); + } AU.addPreserved<GlobalsAAWrapperPass>(); } }; -} + +} // end anonymous namespace char ADCELegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(ADCELegacyPass, "adce", "Aggressive Dead Code Elimination", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) INITIALIZE_PASS_END(ADCELegacyPass, "adce", "Aggressive Dead Code Elimination", false, false) diff --git a/lib/Transforms/Scalar/BDCE.cpp b/lib/Transforms/Scalar/BDCE.cpp index 2e5618686ec2..851efa000f65 100644 --- a/lib/Transforms/Scalar/BDCE.cpp +++ b/lib/Transforms/Scalar/BDCE.cpp @@ -20,11 +20,8 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/IR/CFG.h" #include "llvm/IR/InstIterator.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Operator.h" #include "llvm/Pass.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -48,8 +45,18 @@ static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { // If all bits of a user are demanded, then we know that nothing below that // in the def-use chain needs to be changed. auto *J = dyn_cast<Instruction>(JU); - if (J && !DB.getDemandedBits(J).isAllOnesValue()) + if (J && J->getType()->isSized() && + !DB.getDemandedBits(J).isAllOnesValue()) WorkList.push_back(J); + + // Note that we need to check for unsized types above before asking for + // demanded bits. Normally, the only way to reach an instruction with an + // unsized type is via an instruction that has side effects (or otherwise + // will demand its input bits). However, if we have a readnone function + // that returns an unsized type (e.g., void), we must avoid asking for the + // demanded bits of the function call's return value. A void-returning + // readnone function is always dead (and so we can stop walking the use/def + // chain here), but the check is necessary to avoid asserting. } // DFS through subsequent users while tracking visits to avoid cycles. @@ -70,7 +77,8 @@ static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { // If all bits of a user are demanded, then we know that nothing below // that in the def-use chain needs to be changed. 
auto *K = dyn_cast<Instruction>(KU); - if (K && !Visited.count(K) && !DB.getDemandedBits(K).isAllOnesValue()) + if (K && !Visited.count(K) && K->getType()->isSized() && + !DB.getDemandedBits(K).isAllOnesValue()) WorkList.push_back(K); } } diff --git a/lib/Transforms/Scalar/CMakeLists.txt b/lib/Transforms/Scalar/CMakeLists.txt index 457c9427ab9a..0562d3882f8b 100644 --- a/lib/Transforms/Scalar/CMakeLists.txt +++ b/lib/Transforms/Scalar/CMakeLists.txt @@ -2,11 +2,13 @@ add_llvm_library(LLVMScalarOpts ADCE.cpp AlignmentFromAssumptions.cpp BDCE.cpp + CallSiteSplitting.cpp ConstantHoisting.cpp ConstantProp.cpp CorrelatedValuePropagation.cpp DCE.cpp DeadStoreElimination.cpp + DivRemPairs.cpp EarlyCSE.cpp FlattenCFGPass.cpp Float2Int.cpp @@ -42,6 +44,7 @@ add_llvm_library(LLVMScalarOpts LowerExpectIntrinsic.cpp LowerGuardIntrinsic.cpp MemCpyOptimizer.cpp + MergeICmps.cpp MergedLoadStoreMotion.cpp NaryReassociate.cpp NewGVN.cpp @@ -59,6 +62,7 @@ add_llvm_library(LLVMScalarOpts SimplifyCFGPass.cpp Sink.cpp SpeculativeExecution.cpp + SpeculateAroundPHIs.cpp StraightLineStrengthReduce.cpp StructurizeCFG.cpp TailRecursionElimination.cpp diff --git a/lib/Transforms/Scalar/CallSiteSplitting.cpp b/lib/Transforms/Scalar/CallSiteSplitting.cpp new file mode 100644 index 000000000000..d8c408035038 --- /dev/null +++ b/lib/Transforms/Scalar/CallSiteSplitting.cpp @@ -0,0 +1,428 @@ +//===- CallSiteSplitting.cpp ----------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a transformation that tries to split a call-site to pass +// more constrained arguments if its argument is predicated in the control flow +// so that we can expose better context to the later passes (e.g, inliner, jump +// threading, or IPA-CP based function cloning, etc.). +// As of now we support two cases : +// +// 1) If a call site is dominated by an OR condition and if any of its arguments +// are predicated on this OR condition, try to split the condition with more +// constrained arguments. For example, in the code below, we try to split the +// call site since we can predicate the argument(ptr) based on the OR condition. 
+// +// Split from : +// if (!ptr || c) +// callee(ptr); +// to : +// if (!ptr) +// callee(null) // set the known constant value +// else if (c) +// callee(nonnull ptr) // set non-null attribute in the argument +// +// 2) We can also split a call-site based on constant incoming values of a PHI +// For example, +// from : +// Header: +// %c = icmp eq i32 %i1, %i2 +// br i1 %c, label %Tail, label %TBB +// TBB: +// br label Tail% +// Tail: +// %p = phi i32 [ 0, %Header], [ 1, %TBB] +// call void @bar(i32 %p) +// to +// Header: +// %c = icmp eq i32 %i1, %i2 +// br i1 %c, label %Tail-split0, label %TBB +// TBB: +// br label %Tail-split1 +// Tail-split0: +// call void @bar(i32 0) +// br label %Tail +// Tail-split1: +// call void @bar(i32 1) +// br label %Tail +// Tail: +// %p = phi i32 [ 0, %Tail-split0 ], [ 1, %Tail-split1 ] +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/CallSiteSplitting.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" + +using namespace llvm; +using namespace PatternMatch; + +#define DEBUG_TYPE "callsite-splitting" + +STATISTIC(NumCallSiteSplit, "Number of call-site split"); + +static void addNonNullAttribute(Instruction *CallI, Instruction *NewCallI, + Value *Op) { + CallSite CS(NewCallI); + unsigned ArgNo = 0; + for (auto &I : CS.args()) { + if (&*I == Op) + CS.addParamAttr(ArgNo, Attribute::NonNull); + ++ArgNo; + } +} + +static void setConstantInArgument(Instruction *CallI, Instruction *NewCallI, + Value *Op, Constant *ConstValue) { + CallSite CS(NewCallI); + unsigned ArgNo = 0; + for (auto &I : CS.args()) { + if (&*I == Op) + CS.setArgument(ArgNo, ConstValue); + ++ArgNo; + } +} + +static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) { + assert(isa<Constant>(Cmp->getOperand(1)) && "Expected a constant operand."); + Value *Op0 = Cmp->getOperand(0); + unsigned ArgNo = 0; + for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E; + ++I, ++ArgNo) { + // Don't consider constant or arguments that are already known non-null. + if (isa<Constant>(*I) || CS.paramHasAttr(ArgNo, Attribute::NonNull)) + continue; + + if (*I == Op0) + return true; + } + return false; +} + +/// If From has a conditional jump to To, add the condition to Conditions, +/// if it is relevant to any argument at CS. +static void +recordCondition(const CallSite &CS, BasicBlock *From, BasicBlock *To, + SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) { + auto *BI = dyn_cast<BranchInst>(From->getTerminator()); + if (!BI || !BI->isConditional()) + return; + + CmpInst::Predicate Pred; + Value *Cond = BI->getCondition(); + if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant()))) + return; + + ICmpInst *Cmp = cast<ICmpInst>(Cond); + if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE) + if (isCondRelevantToAnyCallArgument(Cmp, CS)) + Conditions.push_back({Cmp, From->getTerminator()->getSuccessor(0) == To + ? Pred + : Cmp->getInversePredicate()}); +} + +/// Record ICmp conditions relevant to any argument in CS following Pred's +/// single successors. If there are conflicting conditions along a path, like +/// x == 1 and x == 0, the first condition will be used. 
+static void +recordConditions(const CallSite &CS, BasicBlock *Pred, + SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) { + recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions); + BasicBlock *From = Pred; + BasicBlock *To = Pred; + SmallPtrSet<BasicBlock *, 4> Visited = {From}; + while (!Visited.count(From->getSinglePredecessor()) && + (From = From->getSinglePredecessor())) { + recordCondition(CS, From, To, Conditions); + To = From; + } +} + +static Instruction * +addConditions(CallSite &CS, + SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) { + if (Conditions.empty()) + return nullptr; + + Instruction *NewCI = CS.getInstruction()->clone(); + for (auto &Cond : Conditions) { + Value *Arg = Cond.first->getOperand(0); + Constant *ConstVal = cast<Constant>(Cond.first->getOperand(1)); + if (Cond.second == ICmpInst::ICMP_EQ) + setConstantInArgument(CS.getInstruction(), NewCI, Arg, ConstVal); + else if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) { + assert(Cond.second == ICmpInst::ICMP_NE); + addNonNullAttribute(CS.getInstruction(), NewCI, Arg); + } + } + return NewCI; +} + +static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) { + SmallVector<BasicBlock *, 2> Preds(predecessors((BB))); + assert(Preds.size() == 2 && "Expected exactly 2 predecessors!"); + return Preds; +} + +static bool canSplitCallSite(CallSite CS) { + // FIXME: As of now we handle only CallInst. InvokeInst could be handled + // without too much effort. + Instruction *Instr = CS.getInstruction(); + if (!isa<CallInst>(Instr)) + return false; + + // Allow splitting a call-site only when there is no instruction before the + // call-site in the basic block. Based on this constraint, we only clone the + // call instruction, and we do not move a call-site across any other + // instruction. + BasicBlock *CallSiteBB = Instr->getParent(); + if (Instr != CallSiteBB->getFirstNonPHIOrDbg()) + return false; + + // Need 2 predecessors and cannot split an edge from an IndirectBrInst. + SmallVector<BasicBlock *, 2> Preds(predecessors(CallSiteBB)); + if (Preds.size() != 2 || isa<IndirectBrInst>(Preds[0]->getTerminator()) || + isa<IndirectBrInst>(Preds[1]->getTerminator())) + return false; + + return CallSiteBB->canSplitPredecessors(); +} + +/// Return true if the CS is split into its new predecessors which are directly +/// hooked to each of its orignial predecessors pointed by PredBB1 and PredBB2. +/// In OR predicated case, PredBB1 will point the header, and PredBB2 will point +/// to the second compare block. CallInst1 and CallInst2 will be the new +/// call-sites placed in the new predecessors split for PredBB1 and PredBB2, +/// repectively. Therefore, CallInst1 will be the call-site placed +/// between Header and Tail, and CallInst2 will be the call-site between TBB and +/// Tail. 
For example, in the IR below with an OR condition, the call-site can +/// be split +/// +/// from : +/// +/// Header: +/// %c = icmp eq i32* %a, null +/// br i1 %c %Tail, %TBB +/// TBB: +/// %c2 = icmp eq i32* %b, null +/// br i1 %c %Tail, %End +/// Tail: +/// %ca = call i1 @callee (i32* %a, i32* %b) +/// +/// to : +/// +/// Header: // PredBB1 is Header +/// %c = icmp eq i32* %a, null +/// br i1 %c %Tail-split1, %TBB +/// TBB: // PredBB2 is TBB +/// %c2 = icmp eq i32* %b, null +/// br i1 %c %Tail-split2, %End +/// Tail-split1: +/// %ca1 = call @callee (i32* null, i32* %b) // CallInst1 +/// br %Tail +/// Tail-split2: +/// %ca2 = call @callee (i32* nonnull %a, i32* null) // CallInst2 +/// br %Tail +/// Tail: +/// %p = phi i1 [%ca1, %Tail-split1],[%ca2, %Tail-split2] +/// +/// Note that for an OR predicated case, CallInst1 and CallInst2 should be +/// created with more constrained arguments in +/// createCallSitesOnOrPredicatedArgument(). +static void splitCallSite(CallSite CS, BasicBlock *PredBB1, BasicBlock *PredBB2, + Instruction *CallInst1, Instruction *CallInst2) { + Instruction *Instr = CS.getInstruction(); + BasicBlock *TailBB = Instr->getParent(); + assert(Instr == (TailBB->getFirstNonPHIOrDbg()) && "Unexpected call-site"); + + BasicBlock *SplitBlock1 = + SplitBlockPredecessors(TailBB, PredBB1, ".predBB1.split"); + BasicBlock *SplitBlock2 = + SplitBlockPredecessors(TailBB, PredBB2, ".predBB2.split"); + + assert((SplitBlock1 && SplitBlock2) && "Unexpected new basic block split."); + + if (!CallInst1) + CallInst1 = Instr->clone(); + if (!CallInst2) + CallInst2 = Instr->clone(); + + CallInst1->insertBefore(&*SplitBlock1->getFirstInsertionPt()); + CallInst2->insertBefore(&*SplitBlock2->getFirstInsertionPt()); + + CallSite CS1(CallInst1); + CallSite CS2(CallInst2); + + // Handle PHIs used as arguments in the call-site. + for (auto &PI : *TailBB) { + PHINode *PN = dyn_cast<PHINode>(&PI); + if (!PN) + break; + unsigned ArgNo = 0; + for (auto &CI : CS.args()) { + if (&*CI == PN) { + CS1.setArgument(ArgNo, PN->getIncomingValueForBlock(SplitBlock1)); + CS2.setArgument(ArgNo, PN->getIncomingValueForBlock(SplitBlock2)); + } + ++ArgNo; + } + } + + // Replace users of the original call with a PHI mering call-sites split. + if (Instr->getNumUses()) { + PHINode *PN = PHINode::Create(Instr->getType(), 2, "phi.call", + TailBB->getFirstNonPHI()); + PN->addIncoming(CallInst1, SplitBlock1); + PN->addIncoming(CallInst2, SplitBlock2); + Instr->replaceAllUsesWith(PN); + } + DEBUG(dbgs() << "split call-site : " << *Instr << " into \n"); + DEBUG(dbgs() << " " << *CallInst1 << " in " << SplitBlock1->getName() + << "\n"); + DEBUG(dbgs() << " " << *CallInst2 << " in " << SplitBlock2->getName() + << "\n"); + Instr->eraseFromParent(); + NumCallSiteSplit++; +} + +// Return true if the call-site has an argument which is a PHI with only +// constant incoming values. 
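Besides the legacy wrapper registered later in this file under the name "callsite-splitting", the patch adds a new-pass-manager CallSiteSplittingPass declared in llvm/Transforms/Scalar/CallSiteSplitting.h. A rough sketch of scheduling it in a standalone new-PM pipeline; everything around the single addPass call is generic PassBuilder boilerplate assumed for context, not part of this patch:

#include "llvm/Analysis/CGSCCPassManager.h"
#include "llvm/Analysis/LoopAnalysisManager.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/PassManager.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Transforms/Scalar/CallSiteSplitting.h"

using namespace llvm;

// Run call-site splitting over every function in M using the new pass
// manager infrastructure.
static void runCallSiteSplitting(Module &M) {
  PassBuilder PB;
  LoopAnalysisManager LAM;
  FunctionAnalysisManager FAM;
  CGSCCAnalysisManager CGAM;
  ModuleAnalysisManager MAM;
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  FunctionPassManager FPM;
  FPM.addPass(CallSiteSplittingPass());

  ModulePassManager MPM;
  MPM.addPass(createModuleToFunctionPassAdaptor(std::move(FPM)));
  MPM.run(M, MAM);
}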
+static bool isPredicatedOnPHI(CallSite CS) { + Instruction *Instr = CS.getInstruction(); + BasicBlock *Parent = Instr->getParent(); + if (Instr != Parent->getFirstNonPHIOrDbg()) + return false; + + for (auto &BI : *Parent) { + if (PHINode *PN = dyn_cast<PHINode>(&BI)) { + for (auto &I : CS.args()) + if (&*I == PN) { + assert(PN->getNumIncomingValues() == 2 && + "Unexpected number of incoming values"); + if (PN->getIncomingBlock(0) == PN->getIncomingBlock(1)) + return false; + if (PN->getIncomingValue(0) == PN->getIncomingValue(1)) + continue; + if (isa<Constant>(PN->getIncomingValue(0)) && + isa<Constant>(PN->getIncomingValue(1))) + return true; + } + } + break; + } + return false; +} + +static bool tryToSplitOnPHIPredicatedArgument(CallSite CS) { + if (!isPredicatedOnPHI(CS)) + return false; + + auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); + splitCallSite(CS, Preds[0], Preds[1], nullptr, nullptr); + return true; +} +// Check if one of the predecessors is a single predecessors of the other. +// This is a requirement for control flow modeling an OR. HeaderBB points to +// the single predecessor and OrBB points to other node. HeaderBB potentially +// contains the first compare of the OR and OrBB the second. +static bool isOrHeader(BasicBlock *HeaderBB, BasicBlock *OrBB) { + return OrBB->getSinglePredecessor() == HeaderBB && + HeaderBB->getTerminator()->getNumSuccessors() == 2; +} + +static bool tryToSplitOnOrPredicatedArgument(CallSite CS) { + auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); + if (!isOrHeader(Preds[0], Preds[1]) && !isOrHeader(Preds[1], Preds[0])) + return false; + + SmallVector<std::pair<ICmpInst *, unsigned>, 2> C1, C2; + recordConditions(CS, Preds[0], C1); + recordConditions(CS, Preds[1], C2); + + Instruction *CallInst1 = addConditions(CS, C1); + Instruction *CallInst2 = addConditions(CS, C2); + if (!CallInst1 && !CallInst2) + return false; + + splitCallSite(CS, Preds[1], Preds[0], CallInst2, CallInst1); + return true; +} + +static bool tryToSplitCallSite(CallSite CS) { + if (!CS.arg_size() || !canSplitCallSite(CS)) + return false; + return tryToSplitOnOrPredicatedArgument(CS) || + tryToSplitOnPHIPredicatedArgument(CS); +} + +static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI) { + bool Changed = false; + for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE;) { + BasicBlock &BB = *BI++; + for (BasicBlock::iterator II = BB.begin(), IE = BB.end(); II != IE;) { + Instruction *I = &*II++; + CallSite CS(cast<Value>(I)); + if (!CS || isa<IntrinsicInst>(I) || isInstructionTriviallyDead(I, &TLI)) + continue; + + Function *Callee = CS.getCalledFunction(); + if (!Callee || Callee->isDeclaration()) + continue; + Changed |= tryToSplitCallSite(CS); + } + } + return Changed; +} + +namespace { +struct CallSiteSplittingLegacyPass : public FunctionPass { + static char ID; + CallSiteSplittingLegacyPass() : FunctionPass(ID) { + initializeCallSiteSplittingLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + FunctionPass::getAnalysisUsage(AU); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + return doCallSiteSplitting(F, TLI); + } +}; +} // namespace + +char CallSiteSplittingLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(CallSiteSplittingLegacyPass, "callsite-splitting", + "Call-site splitting", false, 
false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(CallSiteSplittingLegacyPass, "callsite-splitting", + "Call-site splitting", false, false) +FunctionPass *llvm::createCallSiteSplittingPass() { + return new CallSiteSplittingLegacyPass(); +} + +PreservedAnalyses CallSiteSplittingPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + + if (!doCallSiteSplitting(F, TLI)) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + return PA; +} diff --git a/lib/Transforms/Scalar/ConstantHoisting.cpp b/lib/Transforms/Scalar/ConstantHoisting.cpp index 122c9314e022..e4b08c5ed305 100644 --- a/lib/Transforms/Scalar/ConstantHoisting.cpp +++ b/lib/Transforms/Scalar/ConstantHoisting.cpp @@ -34,18 +34,39 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/ConstantHoisting.h" -#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/BlockFrequency.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <iterator> #include <tuple> +#include <utility> using namespace llvm; using namespace consthoist; @@ -62,10 +83,12 @@ static cl::opt<bool> ConstHoistWithBlockFrequency( "without hoisting.")); namespace { + /// \brief The constant hoisting pass. class ConstantHoistingLegacyPass : public FunctionPass { public: static char ID; // Pass identification, replacement for typeid + ConstantHoistingLegacyPass() : FunctionPass(ID) { initializeConstantHoistingLegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -87,9 +110,11 @@ public: private: ConstantHoistingPass Impl; }; -} + +} // end anonymous namespace char ConstantHoistingLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(ConstantHoistingLegacyPass, "consthoist", "Constant Hoisting", false, false) INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) @@ -128,7 +153,6 @@ bool ConstantHoistingLegacyPass::runOnFunction(Function &Fn) { return MadeChange; } - /// \brief Find the constant materialization insertion point. Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst, unsigned Idx) const { @@ -217,8 +241,9 @@ static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI, } // Visit Orders in bottom-up order. - typedef std::pair<SmallPtrSet<BasicBlock *, 16>, BlockFrequency> - InsertPtsCostPair; + using InsertPtsCostPair = + std::pair<SmallPtrSet<BasicBlock *, 16>, BlockFrequency>; + // InsertPtsMap is a map from a BB to the best insertion points for the // subtree of BB (subtree not including the BB itself). 
DenseMap<BasicBlock *, InsertPtsCostPair> InsertPtsMap; @@ -310,7 +335,6 @@ SmallPtrSet<Instruction *, 8> ConstantHoistingPass::findConstantInsertionPoint( return InsertPts; } - /// \brief Record constant integer ConstInt for instruction Inst at operand /// index Idx. /// @@ -351,7 +375,6 @@ void ConstantHoistingPass::collectConstantCandidates( } } - /// \brief Check the operand for instruction Inst at index Idx. void ConstantHoistingPass::collectConstantCandidates( ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx) { @@ -393,7 +416,6 @@ void ConstantHoistingPass::collectConstantCandidates( } } - /// \brief Scan the instruction for expensive integer constants and record them /// in the constant candidate vector. void ConstantHoistingPass::collectConstantCandidates( @@ -427,9 +449,8 @@ void ConstantHoistingPass::collectConstantCandidates(Function &Fn) { // bit widths (APInt Operator- does not like that). If the value cannot be // represented in uint64 we return an "empty" APInt. This is then interpreted // as the value is not in range. -static llvm::Optional<APInt> calculateOffsetDiff(const APInt &V1, - const APInt &V2) { - llvm::Optional<APInt> Res = None; +static Optional<APInt> calculateOffsetDiff(const APInt &V1, const APInt &V2) { + Optional<APInt> Res = None; unsigned BW = V1.getBitWidth() > V2.getBitWidth() ? V1.getBitWidth() : V2.getBitWidth(); uint64_t LimVal1 = V1.getLimitedValue(); @@ -496,9 +517,9 @@ ConstantHoistingPass::maximizeConstantsInRange(ConstCandVecType::iterator S, DEBUG(dbgs() << "Cost: " << Cost << "\n"); for (auto C2 = S; C2 != E; ++C2) { - llvm::Optional<APInt> Diff = calculateOffsetDiff( - C2->ConstInt->getValue(), - ConstCand->ConstInt->getValue()); + Optional<APInt> Diff = calculateOffsetDiff( + C2->ConstInt->getValue(), + ConstCand->ConstInt->getValue()); if (Diff) { const int ImmCosts = TTI->getIntImmCodeSizeCost(Opcode, OpndIdx, Diff.getValue(), Ty); @@ -696,6 +717,9 @@ bool ConstantHoistingPass::emitBaseConstants() { IntegerType *Ty = ConstInfo.BaseConstant->getType(); Instruction *Base = new BitCastInst(ConstInfo.BaseConstant, Ty, "const", IP); + + Base->setDebugLoc(IP->getDebugLoc()); + DEBUG(dbgs() << "Hoist constant (" << *ConstInfo.BaseConstant << ") to BB " << IP->getParent()->getName() << '\n' << *Base << '\n'); @@ -714,6 +738,8 @@ bool ConstantHoistingPass::emitBaseConstants() { emitBaseConstants(Base, RCI.Offset, U); ReBasesNum++; } + + Base->setDebugLoc(DILocation::getMergedLocation(Base->getDebugLoc(), U.Inst->getDebugLoc())); } } UsesNum = Uses; @@ -722,7 +748,6 @@ bool ConstantHoistingPass::emitBaseConstants() { assert(!Base->use_empty() && "The use list is empty!?"); assert(isa<Instruction>(Base->user_back()) && "All uses should be instructions."); - Base->setDebugLoc(cast<Instruction>(Base->user_back())->getDebugLoc()); } (void)UsesNum; (void)ReBasesNum; diff --git a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp index 28157783daa7..8f468ebf8949 100644 --- a/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp +++ b/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -12,22 +12,41 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/GlobalsModRef.h" #include 
"llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyValueInfo.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/Module.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" +#include <cassert> +#include <utility> + using namespace llvm; #define DEBUG_TYPE "correlated-value-propagation" @@ -41,13 +60,16 @@ STATISTIC(NumDeadCases, "Number of switch cases removed"); STATISTIC(NumSDivs, "Number of sdiv converted to udiv"); STATISTIC(NumAShrs, "Number of ashr converted to lshr"); STATISTIC(NumSRems, "Number of srem converted to urem"); +STATISTIC(NumOverflows, "Number of overflow checks removed"); static cl::opt<bool> DontProcessAdds("cvp-dont-process-adds", cl::init(true)); namespace { + class CorrelatedValuePropagation : public FunctionPass { public: static char ID; + CorrelatedValuePropagation(): FunctionPass(ID) { initializeCorrelatedValuePropagationPass(*PassRegistry::getPassRegistry()); } @@ -59,9 +81,11 @@ namespace { AU.addPreserved<GlobalsAAWrapperPass>(); } }; -} + +} // end anonymous namespace char CorrelatedValuePropagation::ID = 0; + INITIALIZE_PASS_BEGIN(CorrelatedValuePropagation, "correlated-propagation", "Value Propagation", false, false) INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) @@ -302,11 +326,72 @@ static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) { return Changed; } +// See if we can prove that the given overflow intrinsic will not overflow. +static bool willNotOverflow(IntrinsicInst *II, LazyValueInfo *LVI) { + using OBO = OverflowingBinaryOperator; + auto NoWrap = [&] (Instruction::BinaryOps BinOp, unsigned NoWrapKind) { + Value *RHS = II->getOperand(1); + ConstantRange RRange = LVI->getConstantRange(RHS, II->getParent(), II); + ConstantRange NWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + BinOp, RRange, NoWrapKind); + // As an optimization, do not compute LRange if we do not need it. 
+ if (NWRegion.isEmptySet()) + return false; + Value *LHS = II->getOperand(0); + ConstantRange LRange = LVI->getConstantRange(LHS, II->getParent(), II); + return NWRegion.contains(LRange); + }; + switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::uadd_with_overflow: + return NoWrap(Instruction::Add, OBO::NoUnsignedWrap); + case Intrinsic::sadd_with_overflow: + return NoWrap(Instruction::Add, OBO::NoSignedWrap); + case Intrinsic::usub_with_overflow: + return NoWrap(Instruction::Sub, OBO::NoUnsignedWrap); + case Intrinsic::ssub_with_overflow: + return NoWrap(Instruction::Sub, OBO::NoSignedWrap); + } + return false; +} + +static void processOverflowIntrinsic(IntrinsicInst *II) { + Value *NewOp = nullptr; + switch (II->getIntrinsicID()) { + default: + llvm_unreachable("Unexpected instruction."); + case Intrinsic::uadd_with_overflow: + case Intrinsic::sadd_with_overflow: + NewOp = BinaryOperator::CreateAdd(II->getOperand(0), II->getOperand(1), + II->getName(), II); + break; + case Intrinsic::usub_with_overflow: + case Intrinsic::ssub_with_overflow: + NewOp = BinaryOperator::CreateSub(II->getOperand(0), II->getOperand(1), + II->getName(), II); + break; + } + ++NumOverflows; + IRBuilder<> B(II); + Value *NewI = B.CreateInsertValue(UndefValue::get(II->getType()), NewOp, 0); + NewI = B.CreateInsertValue(NewI, ConstantInt::getFalse(II->getContext()), 1); + II->replaceAllUsesWith(NewI); + II->eraseFromParent(); +} + /// Infer nonnull attributes for the arguments at the specified callsite. static bool processCallSite(CallSite CS, LazyValueInfo *LVI) { SmallVector<unsigned, 4> ArgNos; unsigned ArgNo = 0; + if (auto *II = dyn_cast<IntrinsicInst>(CS.getInstruction())) { + if (willNotOverflow(II, LVI)) { + processOverflowIntrinsic(II); + return true; + } + } + for (Value *V : CS.args()) { PointerType *Type = dyn_cast<PointerType>(V->getType()); // Try to mark pointer typed parameters as non-null. We skip the @@ -335,18 +420,6 @@ static bool processCallSite(CallSite CS, LazyValueInfo *LVI) { return true; } -// Helper function to rewrite srem and sdiv. As a policy choice, we choose not -// to waste compile time on anything where the operands are local defs. While -// LVI can sometimes reason about such cases, it's not its primary purpose. -static bool hasLocalDefs(BinaryOperator *SDI) { - for (Value *O : SDI->operands()) { - auto *I = dyn_cast<Instruction>(O); - if (I && I->getParent() == SDI->getParent()) - return true; - } - return false; -} - static bool hasPositiveOperands(BinaryOperator *SDI, LazyValueInfo *LVI) { Constant *Zero = ConstantInt::get(SDI->getType(), 0); for (Value *O : SDI->operands()) { @@ -358,7 +431,7 @@ static bool hasPositiveOperands(BinaryOperator *SDI, LazyValueInfo *LVI) { } static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) { - if (SDI->getType()->isVectorTy() || hasLocalDefs(SDI) || + if (SDI->getType()->isVectorTy() || !hasPositiveOperands(SDI, LVI)) return false; @@ -376,7 +449,7 @@ static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) { /// conditions, this can sometimes prove conditions instcombine can't by /// exploiting range information. 
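The willNotOverflow helper above asks LazyValueInfo for the ranges of both operands and then checks that the left range lies inside the guaranteed-no-wrap region induced by the right range. The same ConstantRange query can be exercised in isolation; a small sketch with hard-coded ranges instead of LVI results (the function name and the example ranges are made up for illustration):

#include "llvm/ADT/APInt.h"
#include "llvm/IR/ConstantRange.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Operator.h"

using namespace llvm;

// Conservatively decide whether x + y can wrap (unsigned) when x is drawn
// from LRange and y from RRange. makeGuaranteedNoWrapRegion returns a set
// of left-hand values that provably cannot wrap for any y in RRange.
static bool uaddNeverOverflows(const ConstantRange &LRange,
                               const ConstantRange &RRange) {
  ConstantRange NWRegion = ConstantRange::makeGuaranteedNoWrapRegion(
      Instruction::Add, RRange, OverflowingBinaryOperator::NoUnsignedWrap);
  return !NWRegion.isEmptySet() && NWRegion.contains(LRange);
}

// For example, i8 values in [0, 100) added to values in [0, 50) reach at
// most 99 + 49 = 148 < 256, so this should return true:
//   uaddNeverOverflows(ConstantRange(APInt(8, 0), APInt(8, 100)),
//                      ConstantRange(APInt(8, 0), APInt(8, 50)));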
static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) { - if (SDI->getType()->isVectorTy() || hasLocalDefs(SDI) || + if (SDI->getType()->isVectorTy() || !hasPositiveOperands(SDI, LVI)) return false; @@ -391,7 +464,7 @@ static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) { } static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) { - if (SDI->getType()->isVectorTy() || hasLocalDefs(SDI)) + if (SDI->getType()->isVectorTy()) return false; Constant *Zero = ConstantInt::get(SDI->getType(), 0); @@ -410,12 +483,12 @@ static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) { } static bool processAdd(BinaryOperator *AddOp, LazyValueInfo *LVI) { - typedef OverflowingBinaryOperator OBO; + using OBO = OverflowingBinaryOperator; if (DontProcessAdds) return false; - if (AddOp->getType()->isVectorTy() || hasLocalDefs(AddOp)) + if (AddOp->getType()->isVectorTy()) return false; bool NSW = AddOp->hasNoSignedWrap(); @@ -492,7 +565,7 @@ static bool runImpl(Function &F, LazyValueInfo *LVI, const SimplifyQuery &SQ) { // blocks before querying later blocks (which require us to analyze early // blocks). Eagerly simplifying shallow blocks means there is strictly less // work to do for deep blocks. This also means we don't visit unreachable - // blocks. + // blocks. for (BasicBlock *BB : depth_first(&F.getEntryBlock())) { bool BBChanged = false; for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { diff --git a/lib/Transforms/Scalar/DeadStoreElimination.cpp b/lib/Transforms/Scalar/DeadStoreElimination.cpp index 1ec38e56aa4c..e703014bb0e6 100644 --- a/lib/Transforms/Scalar/DeadStoreElimination.cpp +++ b/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -16,31 +16,55 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/DeadStoreElimination.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/CaptureTracking.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" -#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <cstddef> +#include <iterator> #include <map> +#include <utility> + using namespace llvm; #define 
DEBUG_TYPE "dse" @@ -49,18 +73,23 @@ STATISTIC(NumRedundantStores, "Number of redundant stores deleted"); STATISTIC(NumFastStores, "Number of stores deleted"); STATISTIC(NumFastOther , "Number of other instrs removed"); STATISTIC(NumCompletePartials, "Number of stores dead by later partials"); +STATISTIC(NumModifiedStores, "Number of stores modified"); static cl::opt<bool> EnablePartialOverwriteTracking("enable-dse-partial-overwrite-tracking", cl::init(true), cl::Hidden, cl::desc("Enable partial-overwrite tracking in DSE")); +static cl::opt<bool> +EnablePartialStoreMerging("enable-dse-partial-store-merging", + cl::init(true), cl::Hidden, + cl::desc("Enable partial store merging in DSE")); //===----------------------------------------------------------------------===// // Helper functions //===----------------------------------------------------------------------===// -typedef std::map<int64_t, int64_t> OverlapIntervalsTy; -typedef DenseMap<Instruction *, OverlapIntervalsTy> InstOverlapIntervalsTy; +using OverlapIntervalsTy = std::map<int64_t, int64_t>; +using InstOverlapIntervalsTy = DenseMap<Instruction *, OverlapIntervalsTy>; /// Delete this instruction. Before we do, go through and zero out all the /// operands of this instruction. If any of them become dead, delete them and @@ -209,7 +238,6 @@ static bool isRemovable(Instruction *I) { case Intrinsic::init_trampoline: // Always safe to remove init_trampoline. return true; - case Intrinsic::memset: case Intrinsic::memmove: case Intrinsic::memcpy: @@ -224,7 +252,6 @@ static bool isRemovable(Instruction *I) { return false; } - /// Returns true if the end of this instruction can be safely shortened in /// length. static bool isShortenableAtTheEnd(Instruction *I) { @@ -287,14 +314,24 @@ static uint64_t getPointerSize(const Value *V, const DataLayout &DL, } namespace { -enum OverwriteResult { OW_Begin, OW_Complete, OW_End, OW_Unknown }; -} + +enum OverwriteResult { + OW_Begin, + OW_Complete, + OW_End, + OW_PartialEarlierWithFullLater, + OW_Unknown +}; + +} // end anonymous namespace /// Return 'OW_Complete' if a store to the 'Later' location completely /// overwrites a store to the 'Earlier' location, 'OW_End' if the end of the /// 'Earlier' location is completely overwritten by 'Later', 'OW_Begin' if the -/// beginning of the 'Earlier' location is overwritten by 'Later', or -/// 'OW_Unknown' if nothing can be determined. +/// beginning of the 'Earlier' location is overwritten by 'Later'. +/// 'OW_PartialEarlierWithFullLater' means that an earlier (big) store was +/// overwritten by a latter (smaller) store which doesn't write outside the big +/// store's memory locations. Returns 'OW_Unknown' if nothing can be determined. static OverwriteResult isOverwrite(const MemoryLocation &Later, const MemoryLocation &Earlier, const DataLayout &DL, @@ -427,6 +464,19 @@ static OverwriteResult isOverwrite(const MemoryLocation &Later, } } + // Check for an earlier store which writes to all the memory locations that + // the later store writes to. + if (EnablePartialStoreMerging && LaterOff >= EarlierOff && + int64_t(EarlierOff + Earlier.Size) > LaterOff && + uint64_t(LaterOff - EarlierOff) + Later.Size <= Earlier.Size) { + DEBUG(dbgs() << "DSE: Partial overwrite an earlier load [" << EarlierOff + << ", " << int64_t(EarlierOff + Earlier.Size) + << ") by a later store [" << LaterOff << ", " + << int64_t(LaterOff + Later.Size) << ")\n"); + // TODO: Maybe come up with a better name? 
+ return OW_PartialEarlierWithFullLater; + } + // Another interesting case is if the later store overwrites the end of the // earlier store. // @@ -544,11 +594,9 @@ static bool memoryIsNotModifiedBetween(Instruction *FirstI, } for (; BI != EI; ++BI) { Instruction *I = &*BI; - if (I->mayWriteToMemory() && I != SecondI) { - auto Res = AA->getModRefInfo(I, MemLoc); - if (Res & MRI_Mod) + if (I->mayWriteToMemory() && I != SecondI) + if (isModSet(AA->getModRefInfo(I, MemLoc))) return false; - } } if (B != FirstBB) { assert(B != &FirstBB->getParent()->getEntryBlock() && @@ -772,9 +820,7 @@ static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, // the call is live. DeadStackObjects.remove_if([&](Value *I) { // See if the call site touches the value. - ModRefInfo A = AA->getModRefInfo(CS, I, getPointerSize(I, DL, *TLI)); - - return A == MRI_ModRef || A == MRI_Ref; + return isRefSet(AA->getModRefInfo(CS, I, getPointerSize(I, DL, *TLI))); }); // If all of the allocas were clobbered by the call then we're not going @@ -840,7 +886,7 @@ static bool tryToShorten(Instruction *EarlierWrite, int64_t &EarlierOffset, if (!IsOverwriteEnd) LaterOffset = int64_t(LaterOffset + LaterSize); - if (!(llvm::isPowerOf2_64(LaterOffset) && EarlierWriteAlign <= LaterOffset) && + if (!(isPowerOf2_64(LaterOffset) && EarlierWriteAlign <= LaterOffset) && !((EarlierWriteAlign != 0) && LaterOffset % EarlierWriteAlign == 0)) return false; @@ -1094,6 +1140,8 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, // If we find a write that is a) removable (i.e., non-volatile), b) is // completely obliterated by the store to 'Loc', and c) which we know that // 'Inst' doesn't load from, then we can remove it. + // Also try to merge two stores if a later one only touches memory written + // to by the earlier one. if (isRemovable(DepWrite) && !isPossibleSelfRead(Inst, Loc, DepWrite, *TLI, *AA)) { int64_t InstWriteOffset, DepWriteOffset; @@ -1123,6 +1171,72 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, bool IsOverwriteEnd = (OR == OW_End); MadeChange |= tryToShorten(DepWrite, DepWriteOffset, EarlierSize, InstWriteOffset, LaterSize, IsOverwriteEnd); + } else if (EnablePartialStoreMerging && + OR == OW_PartialEarlierWithFullLater) { + auto *Earlier = dyn_cast<StoreInst>(DepWrite); + auto *Later = dyn_cast<StoreInst>(Inst); + if (Earlier && isa<ConstantInt>(Earlier->getValueOperand()) && + Later && isa<ConstantInt>(Later->getValueOperand())) { + // If the store we find is: + // a) partially overwritten by the store to 'Loc' + // b) the later store is fully contained in the earlier one and + // c) they both have a constant value + // Merge the two stores, replacing the earlier store's value with a + // merge of both values. + // TODO: Deal with other constant types (vectors, etc), and probably + // some mem intrinsics (if needed) + + APInt EarlierValue = + cast<ConstantInt>(Earlier->getValueOperand())->getValue(); + APInt LaterValue = + cast<ConstantInt>(Later->getValueOperand())->getValue(); + unsigned LaterBits = LaterValue.getBitWidth(); + assert(EarlierValue.getBitWidth() > LaterValue.getBitWidth()); + LaterValue = LaterValue.zext(EarlierValue.getBitWidth()); + + // Offset of the smaller store inside the larger store + unsigned BitOffsetDiff = (InstWriteOffset - DepWriteOffset) * 8; + unsigned LShiftAmount = + DL.isBigEndian() + ? 
EarlierValue.getBitWidth() - BitOffsetDiff - LaterBits + : BitOffsetDiff; + APInt Mask = + APInt::getBitsSet(EarlierValue.getBitWidth(), LShiftAmount, + LShiftAmount + LaterBits); + // Clear the bits we'll be replacing, then OR with the smaller + // store, shifted appropriately. + APInt Merged = + (EarlierValue & ~Mask) | (LaterValue << LShiftAmount); + DEBUG(dbgs() << "DSE: Merge Stores:\n Earlier: " << *DepWrite + << "\n Later: " << *Inst + << "\n Merged Value: " << Merged << '\n'); + + auto *SI = new StoreInst( + ConstantInt::get(Earlier->getValueOperand()->getType(), Merged), + Earlier->getPointerOperand(), false, Earlier->getAlignment(), + Earlier->getOrdering(), Earlier->getSyncScopeID(), DepWrite); + + unsigned MDToKeep[] = {LLVMContext::MD_dbg, LLVMContext::MD_tbaa, + LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, + LLVMContext::MD_nontemporal}; + SI->copyMetadata(*DepWrite, MDToKeep); + ++NumModifiedStores; + + // Remove earlier, wider, store + size_t Idx = InstrOrdering.lookup(DepWrite); + InstrOrdering.erase(DepWrite); + InstrOrdering.insert(std::make_pair(SI, Idx)); + + // Delete the old stores and now-dead instructions that feed them. + deleteDeadInstruction(Inst, &BBI, *MD, *TLI, IOL, &InstrOrdering); + deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, + &InstrOrdering); + MadeChange = true; + + // We erased DepWrite and Inst (Loc); start over. + break; + } } } @@ -1137,7 +1251,7 @@ static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, if (DepWrite == &BB.front()) break; // Can't look past this instruction if it might read 'Loc'. - if (AA->getModRefInfo(DepWrite, Loc) & MRI_Ref) + if (isRefSet(AA->getModRefInfo(DepWrite, Loc))) break; InstDep = MD->getPointerDependencyFrom(Loc, /*isLoad=*/ false, @@ -1190,9 +1304,12 @@ PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) { } namespace { + /// A legacy pass for the legacy pass manager that wraps \c DSEPass. class DSELegacyPass : public FunctionPass { public: + static char ID; // Pass identification, replacement for typeid + DSELegacyPass() : FunctionPass(ID) { initializeDSELegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -1221,12 +1338,12 @@ public: AU.addPreserved<GlobalsAAWrapperPass>(); AU.addPreserved<MemoryDependenceWrapperPass>(); } - - static char ID; // Pass identification, replacement for typeid }; + } // end anonymous namespace char DSELegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(DSELegacyPass, "dse", "Dead Store Elimination", false, false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) diff --git a/lib/Transforms/Scalar/DivRemPairs.cpp b/lib/Transforms/Scalar/DivRemPairs.cpp new file mode 100644 index 000000000000..e383af89a384 --- /dev/null +++ b/lib/Transforms/Scalar/DivRemPairs.cpp @@ -0,0 +1,206 @@ +//===- DivRemPairs.cpp - Hoist/decompose division and remainder -*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass hoists and/or decomposes integer division and remainder +// instructions to enable CFG improvements and better codegen. 
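The store-merging arithmetic in the DSE hunk above (shift the later constant into position, clear the overlapped bits with a mask, then OR the values together) can be checked with plain fixed-width integers; a little-endian example with made-up constants, not taken from the patch:

#include <cassert>
#include <cstdint>

int main() {
  // Earlier: store i32 0x11223344 at byte offset 0.
  // Later:   store i8  0xAA       at byte offset 1 (little-endian layout).
  uint32_t Earlier = 0x11223344;
  uint32_t Later = 0xAA;
  unsigned LaterBits = 8;
  unsigned BitOffsetDiff = (1 /*InstWriteOffset*/ - 0 /*DepWriteOffset*/) * 8;
  unsigned LShiftAmount = BitOffsetDiff; // big-endian uses width - offset - bits
  uint32_t Mask = ((1u << LaterBits) - 1) << LShiftAmount;
  uint32_t Merged = (Earlier & ~Mask) | (Later << LShiftAmount);
  assert(Merged == 0x1122AA44); // byte 1 replaced, the rest untouched
}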
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/DivRemPairs.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BypassSlowDivision.h" +using namespace llvm; + +#define DEBUG_TYPE "div-rem-pairs" +STATISTIC(NumPairs, "Number of div/rem pairs"); +STATISTIC(NumHoisted, "Number of instructions hoisted"); +STATISTIC(NumDecomposed, "Number of instructions decomposed"); + +/// Find matching pairs of integer div/rem ops (they have the same numerator, +/// denominator, and signedness). If they exist in different basic blocks, bring +/// them together by hoisting or replace the common division operation that is +/// implicit in the remainder: +/// X % Y <--> X - ((X / Y) * Y). +/// +/// We can largely ignore the normal safety and cost constraints on speculation +/// of these ops when we find a matching pair. This is because we are already +/// guaranteed that any exceptions and most cost are already incurred by the +/// first member of the pair. +/// +/// Note: This transform could be an oddball enhancement to EarlyCSE, GVN, or +/// SimplifyCFG, but it's split off on its own because it's different enough +/// that it doesn't quite match the stated objectives of those passes. +static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, + const DominatorTree &DT) { + bool Changed = false; + + // Insert all divide and remainder instructions into maps keyed by their + // operands and opcode (signed or unsigned). + DenseMap<DivRemMapKey, Instruction *> DivMap, RemMap; + for (auto &BB : F) { + for (auto &I : BB) { + if (I.getOpcode() == Instruction::SDiv) + DivMap[DivRemMapKey(true, I.getOperand(0), I.getOperand(1))] = &I; + else if (I.getOpcode() == Instruction::UDiv) + DivMap[DivRemMapKey(false, I.getOperand(0), I.getOperand(1))] = &I; + else if (I.getOpcode() == Instruction::SRem) + RemMap[DivRemMapKey(true, I.getOperand(0), I.getOperand(1))] = &I; + else if (I.getOpcode() == Instruction::URem) + RemMap[DivRemMapKey(false, I.getOperand(0), I.getOperand(1))] = &I; + } + } + + // We can iterate over either map because we are only looking for matched + // pairs. Choose remainders for efficiency because they are usually even more + // rare than division. + for (auto &RemPair : RemMap) { + // Find the matching division instruction from the division map. + Instruction *DivInst = DivMap[RemPair.getFirst()]; + if (!DivInst) + continue; + + // We have a matching pair of div/rem instructions. If one dominates the + // other, hoist and/or replace one. + NumPairs++; + Instruction *RemInst = RemPair.getSecond(); + bool IsSigned = DivInst->getOpcode() == Instruction::SDiv; + bool HasDivRemOp = TTI.hasDivRemOp(DivInst->getType(), IsSigned); + + // If the target supports div+rem and the instructions are in the same block + // already, there's nothing to do. The backend should handle this. If the + // target does not support div+rem, then we will decompose the rem. + if (HasDivRemOp && RemInst->getParent() == DivInst->getParent()) + continue; + + bool DivDominates = DT.dominates(DivInst, RemInst); + if (!DivDominates && !DT.dominates(RemInst, DivInst)) + continue; + + if (HasDivRemOp) { + // The target has a single div/rem operation. 
Hoist the lower instruction + // to make the matched pair visible to the backend. + if (DivDominates) + RemInst->moveAfter(DivInst); + else + DivInst->moveAfter(RemInst); + NumHoisted++; + } else { + // The target does not have a single div/rem operation. Decompose the + // remainder calculation as: + // X % Y --> X - ((X / Y) * Y). + Value *X = RemInst->getOperand(0); + Value *Y = RemInst->getOperand(1); + Instruction *Mul = BinaryOperator::CreateMul(DivInst, Y); + Instruction *Sub = BinaryOperator::CreateSub(X, Mul); + + // If the remainder dominates, then hoist the division up to that block: + // + // bb1: + // %rem = srem %x, %y + // bb2: + // %div = sdiv %x, %y + // --> + // bb1: + // %div = sdiv %x, %y + // %mul = mul %div, %y + // %rem = sub %x, %mul + // + // If the division dominates, it's already in the right place. The mul+sub + // will be in a different block because we don't assume that they are + // cheap to speculatively execute: + // + // bb1: + // %div = sdiv %x, %y + // bb2: + // %rem = srem %x, %y + // --> + // bb1: + // %div = sdiv %x, %y + // bb2: + // %mul = mul %div, %y + // %rem = sub %x, %mul + // + // If the div and rem are in the same block, we do the same transform, + // but any code movement would be within the same block. + + if (!DivDominates) + DivInst->moveBefore(RemInst); + Mul->insertAfter(RemInst); + Sub->insertAfter(Mul); + + // Now kill the explicit remainder. We have replaced it with: + // (sub X, (mul (div X, Y), Y) + RemInst->replaceAllUsesWith(Sub); + RemInst->eraseFromParent(); + NumDecomposed++; + } + Changed = true; + } + + return Changed; +} + +// Pass manager boilerplate below here. + +namespace { +struct DivRemPairsLegacyPass : public FunctionPass { + static char ID; + DivRemPairsLegacyPass() : FunctionPass(ID) { + initializeDivRemPairsLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.setPreservesCFG(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + FunctionPass::getAnalysisUsage(AU); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + return optimizeDivRem(F, TTI, DT); + } +}; +} + +char DivRemPairsLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(DivRemPairsLegacyPass, "div-rem-pairs", + "Hoist/decompose integer division and remainder", false, + false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(DivRemPairsLegacyPass, "div-rem-pairs", + "Hoist/decompose integer division and remainder", false, + false) +FunctionPass *llvm::createDivRemPairsPass() { + return new DivRemPairsLegacyPass(); +} + +PreservedAnalyses DivRemPairsPass::run(Function &F, + FunctionAnalysisManager &FAM) { + TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); + DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); + if (!optimizeDivRem(F, TTI, DT)) + return PreservedAnalyses::all(); + // TODO: This pass just hoists/replaces math ops - all analyses are preserved? 
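When the target has no combined div/rem operation, the pass rewrites the remainder using the identity quoted in the comments, X % Y == X - ((X / Y) * Y). A quick standalone check of that identity in plain C++ (illustrative only; the pass itself emits mul and sub instructions):

#include <cassert>

template <typename T> static T decomposedRem(T X, T Y) {
  T Div = X / Y;      // the division that already exists
  return X - Div * Y; // mul + sub replace the explicit remainder
}

int main() {
  assert(decomposedRem(17, 5) == 17 % 5);
  assert(decomposedRem(-17, 5) == -17 % 5); // signed: both are -2
  assert(decomposedRem(17u, 5u) == 17u % 5u);
}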
+ PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + PA.preserve<GlobalsAA>(); + return PA; +} diff --git a/lib/Transforms/Scalar/EarlyCSE.cpp b/lib/Transforms/Scalar/EarlyCSE.cpp index c5c9b2c185d6..5798e1c4ee99 100644 --- a/lib/Transforms/Scalar/EarlyCSE.cpp +++ b/lib/Transforms/Scalar/EarlyCSE.cpp @@ -13,9 +13,12 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/EarlyCSE.h" +#include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/ScopedHashTable.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/GlobalsModRef.h" @@ -24,18 +27,37 @@ #include "llvm/Analysis/MemorySSAUpdater.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/RecyclingAllocator.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" +#include <cassert> #include <deque> +#include <memory> +#include <utility> + using namespace llvm; using namespace llvm::PatternMatch; @@ -53,6 +75,7 @@ STATISTIC(NumDSE, "Number of trivial dead stores removed"); //===----------------------------------------------------------------------===// namespace { + /// \brief Struct representing the available values in the scoped hash table. struct SimpleValue { Instruction *Inst; @@ -77,20 +100,25 @@ struct SimpleValue { isa<ExtractValueInst>(Inst) || isa<InsertValueInst>(Inst); } }; -} + +} // end anonymous namespace namespace llvm { + template <> struct DenseMapInfo<SimpleValue> { static inline SimpleValue getEmptyKey() { return DenseMapInfo<Instruction *>::getEmptyKey(); } + static inline SimpleValue getTombstoneKey() { return DenseMapInfo<Instruction *>::getTombstoneKey(); } + static unsigned getHashValue(SimpleValue Val); static bool isEqual(SimpleValue LHS, SimpleValue RHS); }; -} + +} // end namespace llvm unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) { Instruction *Inst = Val.Inst; @@ -115,6 +143,21 @@ unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) { return hash_combine(Inst->getOpcode(), Pred, LHS, RHS); } + // Hash min/max/abs (cmp + select) to allow for commuted operands. + // Min/max may also have non-canonical compare predicate (eg, the compare for + // smin may use 'sgt' rather than 'slt'), and non-canonical operands in the + // compare. + Value *A, *B; + SelectPatternFlavor SPF = matchSelectPattern(Inst, A, B).Flavor; + // TODO: We should also detect FP min/max. 
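The EarlyCSE hashing change being introduced here puts the two min/max operands into a fixed order before combining them, so that commuted forms of the same select pattern land in the same hash bucket. A small sketch of that idea with plain pointers and std::hash; hashCommutativePair is a made-up name, not the LLVM API:

#include <cassert>
#include <cstddef>
#include <functional>
#include <utility>

// Sketch: hash an unordered pair of operands by putting them in a fixed
// order first, so min(a, b) and min(b, a) produce the same key.
static std::size_t hashCommutativePair(const void *A, const void *B) {
  if (std::less<const void *>()(B, A))
    std::swap(A, B);
  return std::hash<const void *>()(A) * 31u ^ std::hash<const void *>()(B);
}

int main() {
  int X = 0, Y = 0;
  assert(hashCommutativePair(&X, &Y) == hashCommutativePair(&Y, &X));
}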
+ if (SPF == SPF_SMIN || SPF == SPF_SMAX || + SPF == SPF_UMIN || SPF == SPF_UMAX || + SPF == SPF_ABS || SPF == SPF_NABS) { + if (A > B) + std::swap(A, B); + return hash_combine(Inst->getOpcode(), SPF, A, B); + } + if (CastInst *CI = dyn_cast<CastInst>(Inst)) return hash_combine(CI->getOpcode(), CI->getType(), CI->getOperand(0)); @@ -173,6 +216,20 @@ bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) { LHSCmp->getSwappedPredicate() == RHSCmp->getPredicate(); } + // Min/max/abs can occur with commuted operands, non-canonical predicates, + // and/or non-canonical operands. + Value *LHSA, *LHSB; + SelectPatternFlavor LSPF = matchSelectPattern(LHSI, LHSA, LHSB).Flavor; + // TODO: We should also detect FP min/max. + if (LSPF == SPF_SMIN || LSPF == SPF_SMAX || + LSPF == SPF_UMIN || LSPF == SPF_UMAX || + LSPF == SPF_ABS || LSPF == SPF_NABS) { + Value *RHSA, *RHSB; + SelectPatternFlavor RSPF = matchSelectPattern(RHSI, RHSA, RHSB).Flavor; + return (LSPF == RSPF && ((LHSA == RHSA && LHSB == RHSB) || + (LHSA == RHSB && LHSB == RHSA))); + } + return false; } @@ -181,6 +238,7 @@ bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) { //===----------------------------------------------------------------------===// namespace { + /// \brief Struct representing the available call values in the scoped hash /// table. struct CallValue { @@ -206,20 +264,25 @@ struct CallValue { return true; } }; -} + +} // end anonymous namespace namespace llvm { + template <> struct DenseMapInfo<CallValue> { static inline CallValue getEmptyKey() { return DenseMapInfo<Instruction *>::getEmptyKey(); } + static inline CallValue getTombstoneKey() { return DenseMapInfo<Instruction *>::getTombstoneKey(); } + static unsigned getHashValue(CallValue Val); static bool isEqual(CallValue LHS, CallValue RHS); }; -} + +} // end namespace llvm unsigned DenseMapInfo<CallValue>::getHashValue(CallValue Val) { Instruction *Inst = Val.Inst; @@ -241,6 +304,7 @@ bool DenseMapInfo<CallValue>::isEqual(CallValue LHS, CallValue RHS) { //===----------------------------------------------------------------------===// namespace { + /// \brief A simple and fast domtree-based CSE pass. /// /// This pass does a simple depth-first walk over the dominator tree, @@ -257,10 +321,13 @@ public: const SimplifyQuery SQ; MemorySSA *MSSA; std::unique_ptr<MemorySSAUpdater> MSSAUpdater; - typedef RecyclingAllocator< - BumpPtrAllocator, ScopedHashTableVal<SimpleValue, Value *>> AllocatorTy; - typedef ScopedHashTable<SimpleValue, Value *, DenseMapInfo<SimpleValue>, - AllocatorTy> ScopedHTType; + + using AllocatorTy = + RecyclingAllocator<BumpPtrAllocator, + ScopedHashTableVal<SimpleValue, Value *>>; + using ScopedHTType = + ScopedHashTable<SimpleValue, Value *, DenseMapInfo<SimpleValue>, + AllocatorTy>; /// \brief A scoped hash table of the current values of all of our simple /// scalar expressions. @@ -285,44 +352,45 @@ public: /// present the table; it is the responsibility of the consumer to inspect /// the atomicity/volatility if needed. 
struct LoadValue { - Instruction *DefInst; - unsigned Generation; - int MatchingId; - bool IsAtomic; - bool IsInvariant; - LoadValue() - : DefInst(nullptr), Generation(0), MatchingId(-1), IsAtomic(false), - IsInvariant(false) {} + Instruction *DefInst = nullptr; + unsigned Generation = 0; + int MatchingId = -1; + bool IsAtomic = false; + bool IsInvariant = false; + + LoadValue() = default; LoadValue(Instruction *Inst, unsigned Generation, unsigned MatchingId, bool IsAtomic, bool IsInvariant) : DefInst(Inst), Generation(Generation), MatchingId(MatchingId), IsAtomic(IsAtomic), IsInvariant(IsInvariant) {} }; - typedef RecyclingAllocator<BumpPtrAllocator, - ScopedHashTableVal<Value *, LoadValue>> - LoadMapAllocator; - typedef ScopedHashTable<Value *, LoadValue, DenseMapInfo<Value *>, - LoadMapAllocator> LoadHTType; + + using LoadMapAllocator = + RecyclingAllocator<BumpPtrAllocator, + ScopedHashTableVal<Value *, LoadValue>>; + using LoadHTType = + ScopedHashTable<Value *, LoadValue, DenseMapInfo<Value *>, + LoadMapAllocator>; + LoadHTType AvailableLoads; /// \brief A scoped hash table of the current values of read-only call /// values. /// /// It uses the same generation count as loads. - typedef ScopedHashTable<CallValue, std::pair<Instruction *, unsigned>> - CallHTType; + using CallHTType = + ScopedHashTable<CallValue, std::pair<Instruction *, unsigned>>; CallHTType AvailableCalls; /// \brief This is the current generation of the memory value. - unsigned CurrentGeneration; + unsigned CurrentGeneration = 0; /// \brief Set up the EarlyCSE runner for a particular function. EarlyCSE(const DataLayout &DL, const TargetLibraryInfo &TLI, const TargetTransformInfo &TTI, DominatorTree &DT, AssumptionCache &AC, MemorySSA *MSSA) : TLI(TLI), TTI(TTI), DT(DT), AC(AC), SQ(DL, &TLI, &DT, &AC), MSSA(MSSA), - MSSAUpdater(make_unique<MemorySSAUpdater>(MSSA)), CurrentGeneration(0) { - } + MSSAUpdater(llvm::make_unique<MemorySSAUpdater>(MSSA)) {} bool run(); @@ -336,11 +404,10 @@ private: CallHTType &AvailableCalls) : Scope(AvailableValues), LoadScope(AvailableLoads), CallScope(AvailableCalls) {} - - private: NodeScope(const NodeScope &) = delete; - void operator=(const NodeScope &) = delete; + NodeScope &operator=(const NodeScope &) = delete; + private: ScopedHTType::ScopeTy Scope; LoadHTType::ScopeTy LoadScope; CallHTType::ScopeTy CallScope; @@ -356,8 +423,10 @@ private: CallHTType &AvailableCalls, unsigned cg, DomTreeNode *n, DomTreeNode::iterator child, DomTreeNode::iterator end) : CurrentGeneration(cg), ChildGeneration(cg), Node(n), ChildIter(child), - EndIter(end), Scopes(AvailableValues, AvailableLoads, AvailableCalls), - Processed(false) {} + EndIter(end), Scopes(AvailableValues, AvailableLoads, AvailableCalls) + {} + StackNode(const StackNode &) = delete; + StackNode &operator=(const StackNode &) = delete; // Accessors. unsigned currentGeneration() { return CurrentGeneration; } @@ -365,27 +434,25 @@ private: void childGeneration(unsigned generation) { ChildGeneration = generation; } DomTreeNode *node() { return Node; } DomTreeNode::iterator childIter() { return ChildIter; } + DomTreeNode *nextChild() { DomTreeNode *child = *ChildIter; ++ChildIter; return child; } + DomTreeNode::iterator end() { return EndIter; } bool isProcessed() { return Processed; } void process() { Processed = true; } private: - StackNode(const StackNode &) = delete; - void operator=(const StackNode &) = delete; - - // Members. 
unsigned CurrentGeneration; unsigned ChildGeneration; DomTreeNode *Node; DomTreeNode::iterator ChildIter; DomTreeNode::iterator EndIter; NodeScope Scopes; - bool Processed; + bool Processed = false; }; /// \brief Wrapper class to handle memory instructions, including loads, @@ -393,24 +460,28 @@ private: class ParseMemoryInst { public: ParseMemoryInst(Instruction *Inst, const TargetTransformInfo &TTI) - : IsTargetMemInst(false), Inst(Inst) { + : Inst(Inst) { if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) if (TTI.getTgtMemIntrinsic(II, Info)) IsTargetMemInst = true; } + bool isLoad() const { if (IsTargetMemInst) return Info.ReadMem; return isa<LoadInst>(Inst); } + bool isStore() const { if (IsTargetMemInst) return Info.WriteMem; return isa<StoreInst>(Inst); } + bool isAtomic() const { if (IsTargetMemInst) return Info.Ordering != AtomicOrdering::NotAtomic; return Inst->isAtomic(); } + bool isUnordered() const { if (IsTargetMemInst) return Info.isUnordered(); @@ -447,6 +518,7 @@ private: return (getPointerOperand() == Inst.getPointerOperand() && getMatchingId() == Inst.getMatchingId()); } + bool isValid() const { return getPointerOperand() != nullptr; } // For regular (non-intrinsic) loads/stores, this is set to -1. For @@ -457,6 +529,7 @@ private: if (IsTargetMemInst) return Info.MatchingId; return -1; } + Value *getPointerOperand() const { if (IsTargetMemInst) return Info.PtrVal; if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { @@ -466,17 +539,19 @@ private: } return nullptr; } + bool mayReadFromMemory() const { if (IsTargetMemInst) return Info.ReadMem; return Inst->mayReadFromMemory(); } + bool mayWriteToMemory() const { if (IsTargetMemInst) return Info.WriteMem; return Inst->mayWriteToMemory(); } private: - bool IsTargetMemInst; + bool IsTargetMemInst = false; MemIntrinsicInfo Info; Instruction *Inst; }; @@ -524,8 +599,8 @@ private: for (MemoryPhi *MP : PhisToCheck) { MemoryAccess *FirstIn = MP->getIncomingValue(0); - if (all_of(MP->incoming_values(), - [=](Use &In) { return In == FirstIn; })) + if (llvm::all_of(MP->incoming_values(), + [=](Use &In) { return In == FirstIn; })) WorkQueue.push_back(MP); } PhisToCheck.clear(); @@ -533,7 +608,8 @@ private: } } }; -} + +} // end anonymous namespace /// Determine if the memory referenced by LaterInst is from the same heap /// version as EarlierInst. @@ -663,6 +739,12 @@ bool EarlyCSE::processNode(DomTreeNode *Node) { continue; } + // Skip sideeffect intrinsics, for the same reason as assume intrinsics. + if (match(Inst, m_Intrinsic<Intrinsic::sideeffect>())) { + DEBUG(dbgs() << "EarlyCSE skipping sideeffect: " << *Inst << '\n'); + continue; + } + // Skip invariant.start intrinsics since they only read memory, and we can // forward values across it. Also, we dont need to consume the last store // since the semantics of invariant.start allow us to perform DSE of the @@ -1014,6 +1096,7 @@ PreservedAnalyses EarlyCSEPass::run(Function &F, } namespace { + /// \brief A simple and fast domtree-based CSE pass. 
/// /// This pass does a simple depth-first walk over the dominator tree, @@ -1062,7 +1145,8 @@ public: AU.setPreservesCFG(); } }; -} + +} // end anonymous namespace using EarlyCSELegacyPass = EarlyCSELegacyCommonPass</*UseMemorySSA=*/false>; diff --git a/lib/Transforms/Scalar/GVN.cpp b/lib/Transforms/Scalar/GVN.cpp index ea28705e684d..e2c1eaf58e43 100644 --- a/lib/Transforms/Scalar/GVN.cpp +++ b/lib/Transforms/Scalar/GVN.cpp @@ -20,39 +20,64 @@ #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CFG.h" -#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" -#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/PHITransAddr.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugLoc.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/GlobalVariable.h" -#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" #include "llvm/Transforms/Utils/VNCoercion.h" - +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <utility> #include <vector> + using namespace llvm; using namespace llvm::gvn; using namespace llvm::VNCoercion; @@ -80,6 +105,7 @@ MaxRecurseDepth("max-recurse-depth", cl::Hidden, cl::init(1000), cl::ZeroOrMore, struct llvm::GVN::Expression { uint32_t opcode; Type *type; + bool commutative = false; SmallVector<uint32_t, 4> varargs; Expression(uint32_t o = ~2U) : opcode(o) {} @@ -104,20 +130,23 @@ struct llvm::GVN::Expression { }; namespace llvm { + template <> struct DenseMapInfo<GVN::Expression> { static inline GVN::Expression getEmptyKey() { return ~0U; } - static inline GVN::Expression getTombstoneKey() { return ~1U; } static unsigned getHashValue(const GVN::Expression &e) { using llvm::hash_value; + return static_cast<unsigned>(hash_value(e)); } + static bool isEqual(const GVN::Expression &LHS, const GVN::Expression &RHS) { return LHS 
== RHS; } }; -} // End llvm namespace. + +} // end namespace llvm /// Represents a particular available value that we know how to materialize. /// Materialization of an AvailableValue never fails. An AvailableValue is @@ -216,6 +245,7 @@ struct llvm::gvn::AvailableValueInBlock { unsigned Offset = 0) { return get(BB, AvailableValue::get(V, Offset)); } + static AvailableValueInBlock getUndef(BasicBlock *BB) { return get(BB, AvailableValue::getUndef()); } @@ -246,6 +276,7 @@ GVN::Expression GVN::ValueTable::createExpr(Instruction *I) { assert(I->getNumOperands() == 2 && "Unsupported commutative instruction!"); if (e.varargs[0] > e.varargs[1]) std::swap(e.varargs[0], e.varargs[1]); + e.commutative = true; } if (CmpInst *C = dyn_cast<CmpInst>(I)) { @@ -256,6 +287,7 @@ GVN::Expression GVN::ValueTable::createExpr(Instruction *I) { Predicate = CmpInst::getSwappedPredicate(Predicate); } e.opcode = (C->getOpcode() << 8) | Predicate; + e.commutative = true; } else if (InsertValueInst *E = dyn_cast<InsertValueInst>(I)) { for (InsertValueInst::idx_iterator II = E->idx_begin(), IE = E->idx_end(); II != IE; ++II) @@ -281,6 +313,7 @@ GVN::Expression GVN::ValueTable::createCmpExpr(unsigned Opcode, Predicate = CmpInst::getSwappedPredicate(Predicate); } e.opcode = (Opcode << 8) | Predicate; + e.commutative = true; return e; } @@ -340,7 +373,7 @@ GVN::Expression GVN::ValueTable::createExtractvalueExpr(ExtractValueInst *EI) { // ValueTable External Functions //===----------------------------------------------------------------------===// -GVN::ValueTable::ValueTable() : nextValueNumber(1) {} +GVN::ValueTable::ValueTable() = default; GVN::ValueTable::ValueTable(const ValueTable &) = default; GVN::ValueTable::ValueTable(ValueTable &&) = default; GVN::ValueTable::~ValueTable() = default; @@ -348,25 +381,25 @@ GVN::ValueTable::~ValueTable() = default; /// add - Insert a value into the table with a specified value number. 
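The new commutative flag records that an expression's operands were put into canonical order by value number; for compares, swapping the operands also requires swapping the predicate, which is why createCmpExpr flips it. A small illustration with hypothetical enums and a CmpKey type that is not part of the ValueTable:

#include <cassert>
#include <utility>

enum Pred { SLT, SGT };

static Pred swappedPred(Pred P) { return P == SLT ? SGT : SLT; }

struct CmpKey {
  Pred P;
  unsigned LHS, RHS; // value numbers of the operands
  bool operator==(const CmpKey &O) const {
    return P == O.P && LHS == O.LHS && RHS == O.RHS;
  }
};

// Canonicalize so that the smaller value number comes first; swapping the
// operands of a compare means using the swapped predicate.
static CmpKey makeCmpKey(Pred P, unsigned L, unsigned R) {
  if (L > R) {
    std::swap(L, R);
    P = swappedPred(P);
  }
  return {P, L, R};
}

int main() {
  // icmp sgt %a, %b and icmp slt %b, %a get the same canonical key.
  assert(makeCmpKey(SGT, /*a*/ 7, /*b*/ 3) == makeCmpKey(SLT, /*b*/ 3, /*a*/ 7));
}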
void GVN::ValueTable::add(Value *V, uint32_t num) { valueNumbering.insert(std::make_pair(V, num)); + if (PHINode *PN = dyn_cast<PHINode>(V)) + NumberingPhi[num] = PN; } uint32_t GVN::ValueTable::lookupOrAddCall(CallInst *C) { if (AA->doesNotAccessMemory(C)) { Expression exp = createExpr(C); - uint32_t &e = expressionNumbering[exp]; - if (!e) e = nextValueNumber++; + uint32_t e = assignExpNewValueNum(exp).first; valueNumbering[C] = e; return e; } else if (AA->onlyReadsMemory(C)) { Expression exp = createExpr(C); - uint32_t &e = expressionNumbering[exp]; - if (!e) { - e = nextValueNumber++; - valueNumbering[C] = e; - return e; + auto ValNum = assignExpNewValueNum(exp); + if (ValNum.second) { + valueNumbering[C] = ValNum.first; + return ValNum.first; } if (!MD) { - e = nextValueNumber++; + uint32_t e = assignExpNewValueNum(exp).first; valueNumbering[C] = e; return e; } @@ -452,7 +485,6 @@ uint32_t GVN::ValueTable::lookupOrAddCall(CallInst *C) { uint32_t v = lookupOrAdd(cdep); valueNumbering[C] = v; return v; - } else { valueNumbering[C] = nextValueNumber; return nextValueNumber++; @@ -522,23 +554,29 @@ uint32_t GVN::ValueTable::lookupOrAdd(Value *V) { case Instruction::ExtractValue: exp = createExtractvalueExpr(cast<ExtractValueInst>(I)); break; + case Instruction::PHI: + valueNumbering[V] = nextValueNumber; + NumberingPhi[nextValueNumber] = cast<PHINode>(V); + return nextValueNumber++; default: valueNumbering[V] = nextValueNumber; return nextValueNumber++; } - uint32_t& e = expressionNumbering[exp]; - if (!e) e = nextValueNumber++; + uint32_t e = assignExpNewValueNum(exp).first; valueNumbering[V] = e; return e; } /// Returns the value number of the specified value. Fails if /// the value has not yet been numbered. -uint32_t GVN::ValueTable::lookup(Value *V) const { +uint32_t GVN::ValueTable::lookup(Value *V, bool Verify) const { DenseMap<Value*, uint32_t>::const_iterator VI = valueNumbering.find(V); - assert(VI != valueNumbering.end() && "Value not numbered?"); - return VI->second; + if (Verify) { + assert(VI != valueNumbering.end() && "Value not numbered?"); + return VI->second; + } + return (VI != valueNumbering.end()) ? VI->second : 0; } /// Returns the value number of the given comparison, @@ -549,21 +587,28 @@ uint32_t GVN::ValueTable::lookupOrAddCmp(unsigned Opcode, CmpInst::Predicate Predicate, Value *LHS, Value *RHS) { Expression exp = createCmpExpr(Opcode, Predicate, LHS, RHS); - uint32_t& e = expressionNumbering[exp]; - if (!e) e = nextValueNumber++; - return e; + return assignExpNewValueNum(exp).first; } /// Remove all entries from the ValueTable. void GVN::ValueTable::clear() { valueNumbering.clear(); expressionNumbering.clear(); + NumberingPhi.clear(); + PhiTranslateTable.clear(); nextValueNumber = 1; + Expressions.clear(); + ExprIdx.clear(); + nextExprNumber = 0; } /// Remove a value from the value numbering. void GVN::ValueTable::erase(Value *V) { + uint32_t Num = valueNumbering.lookup(V); valueNumbering.erase(V); + // If V is PHINode, V <--> value number is an one-to-one mapping. + if (isa<PHINode>(V)) + NumberingPhi.erase(Num); } /// verifyRemoved - Verify that the value is removed from all internal data @@ -693,9 +738,6 @@ SpeculationFailure: return false; } - - - /// Given a set of loads specified by ValuesPerBlock, /// construct SSA form, allowing us to eliminate LI. This returns the value /// that should be used at LI's definition site. 
@@ -789,6 +831,7 @@ static void reportMayClobberedLoad(LoadInst *LI, MemDepResult DepInfo, DominatorTree *DT, OptimizationRemarkEmitter *ORE) { using namespace ore; + User *OtherAccess = nullptr; OptimizationRemarkMissed R(DEBUG_TYPE, "LoadClobbered", LI); @@ -817,7 +860,6 @@ static void reportMayClobberedLoad(LoadInst *LI, MemDepResult DepInfo, bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, Value *Address, AvailableValue &Res) { - assert((DepInfo.isDef() || DepInfo.isClobber()) && "expected a local dependence"); assert(LI->isUnordered() && "rules below are incorrect for ordered access"); @@ -879,8 +921,7 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, Instruction *I = DepInfo.getInst(); dbgs() << " is clobbered by " << *I << '\n'; ); - - if (ORE->allowExtraAnalysis()) + if (ORE->allowExtraAnalysis(DEBUG_TYPE)) reportMayClobberedLoad(LI, DepInfo, DT, ORE); return false; @@ -949,7 +990,6 @@ bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, void GVN::AnalyzeLoadAvailability(LoadInst *LI, LoadDepVect &Deps, AvailValInBlkVect &ValuesPerBlock, UnavailBlkVect &UnavailableBlocks) { - // Filter out useless results (non-locals, etc). Keep track of the blocks // where we have a value available in repl, also keep track of whether we see // dependencies that produce an unknown value for the load (such as a call @@ -1009,7 +1049,32 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, // backwards through predecessors if needed. BasicBlock *LoadBB = LI->getParent(); BasicBlock *TmpBB = LoadBB; + bool IsSafeToSpeculativelyExecute = isSafeToSpeculativelyExecute(LI); + // Check that there is no implicit control flow instructions above our load in + // its block. If there is an instruction that doesn't always pass the + // execution to the following instruction, then moving through it may become + // invalid. For example: + // + // int arr[LEN]; + // int index = ???; + // ... + // guard(0 <= index && index < LEN); + // use(arr[index]); + // + // It is illegal to move the array access to any point above the guard, + // because if the index is out of bounds we should deoptimize rather than + // access the array. + // Check that there is no guard in this block above our intruction. + if (!IsSafeToSpeculativelyExecute) { + auto It = FirstImplicitControlFlowInsts.find(TmpBB); + if (It != FirstImplicitControlFlowInsts.end()) { + assert(It->second->getParent() == TmpBB && + "Implicit control flow map broken?"); + if (OI->dominates(It->second, LI)) + return false; + } + } while (TmpBB->getSinglePredecessor()) { TmpBB = TmpBB->getSinglePredecessor(); if (TmpBB == LoadBB) // Infinite (unreachable) loop. @@ -1024,6 +1089,11 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, // which it was not previously executed. if (TmpBB->getTerminator()->getNumSuccessors() != 1) return false; + + // Check that there is no implicit control flow in a block above. + if (!IsSafeToSpeculativelyExecute && + FirstImplicitControlFlowInsts.count(TmpBB)) + return false; } assert(TmpBB); @@ -1128,8 +1198,7 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, if (!CanDoPRE) { while (!NewInsts.empty()) { Instruction *I = NewInsts.pop_back_val(); - if (MD) MD->removeInstruction(I); - I->eraseFromParent(); + markInstructionForDeletion(I); } // HINT: Don't revert the edge-splitting as following transformation may // also need to split these critical edges. 
@@ -1206,8 +1275,10 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, if (V->getType()->isPtrOrPtrVectorTy()) MD->invalidateCachedPointerInfo(V); markInstructionForDeletion(LI); - ORE->emit(OptimizationRemark(DEBUG_TYPE, "LoadPRE", LI) - << "load eliminated by PRE"); + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "LoadPRE", LI) + << "load eliminated by PRE"; + }); ++NumPRELoad; return true; } @@ -1215,17 +1286,23 @@ bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock, static void reportLoadElim(LoadInst *LI, Value *AvailableValue, OptimizationRemarkEmitter *ORE) { using namespace ore; - ORE->emit(OptimizationRemark(DEBUG_TYPE, "LoadElim", LI) - << "load of type " << NV("Type", LI->getType()) << " eliminated" - << setExtraArgs() << " in favor of " - << NV("InfavorOfValue", AvailableValue)); + + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "LoadElim", LI) + << "load of type " << NV("Type", LI->getType()) << " eliminated" + << setExtraArgs() << " in favor of " + << NV("InfavorOfValue", AvailableValue); + }); } /// Attempt to eliminate a load whose dependencies are /// non-local by performing PHI construction. bool GVN::processNonLocalLoad(LoadInst *LI) { // non-local speculations are not allowed under asan. - if (LI->getParent()->getParent()->hasFnAttribute(Attribute::SanitizeAddress)) + if (LI->getParent()->getParent()->hasFnAttribute( + Attribute::SanitizeAddress) || + LI->getParent()->getParent()->hasFnAttribute( + Attribute::SanitizeHWAddress)) return false; // Step 1: Find the non-local dependencies of the load. @@ -1322,6 +1399,11 @@ bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) { } markInstructionForDeletion(IntrinsicI); return false; + } else if (isa<Constant>(V)) { + // If it's not false, and constant, it must evaluate to true. This means our + // assume is assume(true), and thus, pointless, and we don't want to do + // anything more here. + return false; } Constant *True = ConstantInt::getTrue(V->getContext()); @@ -1452,6 +1534,106 @@ bool GVN::processLoad(LoadInst *L) { return false; } +/// Return a pair the first field showing the value number of \p Exp and the +/// second field showing whether it is a value number newly created. +std::pair<uint32_t, bool> +GVN::ValueTable::assignExpNewValueNum(Expression &Exp) { + uint32_t &e = expressionNumbering[Exp]; + bool CreateNewValNum = !e; + if (CreateNewValNum) { + Expressions.push_back(Exp); + if (ExprIdx.size() < nextValueNumber + 1) + ExprIdx.resize(nextValueNumber * 2); + e = nextValueNumber; + ExprIdx[nextValueNumber++] = nextExprNumber++; + } + return {e, CreateNewValNum}; +} + +/// Return whether all the values related with the same \p num are +/// defined in \p BB. +bool GVN::ValueTable::areAllValsInBB(uint32_t Num, const BasicBlock *BB, + GVN &Gvn) { + LeaderTableEntry *Vals = &Gvn.LeaderTable[Num]; + while (Vals && Vals->BB == BB) + Vals = Vals->Next; + return !Vals; +} + +/// Wrap phiTranslateImpl to provide caching functionality. +uint32_t GVN::ValueTable::phiTranslate(const BasicBlock *Pred, + const BasicBlock *PhiBlock, uint32_t Num, + GVN &Gvn) { + auto FindRes = PhiTranslateTable.find({Num, Pred}); + if (FindRes != PhiTranslateTable.end()) + return FindRes->second; + uint32_t NewNum = phiTranslateImpl(Pred, PhiBlock, Num, Gvn); + PhiTranslateTable.insert({{Num, Pred}, NewNum}); + return NewNum; +} + +/// Translate value number \p Num using phis, so that it has the values of +/// the phis in BB. 
+uint32_t GVN::ValueTable::phiTranslateImpl(const BasicBlock *Pred, + const BasicBlock *PhiBlock, + uint32_t Num, GVN &Gvn) { + if (PHINode *PN = NumberingPhi[Num]) { + for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) { + if (PN->getParent() == PhiBlock && PN->getIncomingBlock(i) == Pred) + if (uint32_t TransVal = lookup(PN->getIncomingValue(i), false)) + return TransVal; + } + return Num; + } + + // If there is any value related with Num is defined in a BB other than + // PhiBlock, it cannot depend on a phi in PhiBlock without going through + // a backedge. We can do an early exit in that case to save compile time. + if (!areAllValsInBB(Num, PhiBlock, Gvn)) + return Num; + + if (Num >= ExprIdx.size() || ExprIdx[Num] == 0) + return Num; + Expression Exp = Expressions[ExprIdx[Num]]; + + for (unsigned i = 0; i < Exp.varargs.size(); i++) { + // For InsertValue and ExtractValue, some varargs are index numbers + // instead of value numbers. Those index numbers should not be + // translated. + if ((i > 1 && Exp.opcode == Instruction::InsertValue) || + (i > 0 && Exp.opcode == Instruction::ExtractValue)) + continue; + Exp.varargs[i] = phiTranslate(Pred, PhiBlock, Exp.varargs[i], Gvn); + } + + if (Exp.commutative) { + assert(Exp.varargs.size() == 2 && "Unsupported commutative expression!"); + if (Exp.varargs[0] > Exp.varargs[1]) { + std::swap(Exp.varargs[0], Exp.varargs[1]); + uint32_t Opcode = Exp.opcode >> 8; + if (Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) + Exp.opcode = (Opcode << 8) | + CmpInst::getSwappedPredicate( + static_cast<CmpInst::Predicate>(Exp.opcode & 255)); + } + } + + if (uint32_t NewNum = expressionNumbering[Exp]) + return NewNum; + return Num; +} + +/// Erase stale entry from phiTranslate cache so phiTranslate can be computed +/// again. +void GVN::ValueTable::eraseTranslateCacheEntry(uint32_t Num, + const BasicBlock &CurrBlock) { + for (const BasicBlock *Pred : predecessors(&CurrBlock)) { + auto FindRes = PhiTranslateTable.find({Num, Pred}); + if (FindRes != PhiTranslateTable.end()) + PhiTranslateTable.erase(FindRes); + } +} + // In order to find a leader for a given value number at a // specific basic block, we first obtain the list of all Values for that number, // and then scan the list to find one whose block dominates the block in @@ -1496,6 +1678,13 @@ static bool isOnlyReachableViaThisEdge(const BasicBlockEdge &E, return Pred != nullptr; } +void GVN::assignBlockRPONumber(Function &F) { + uint32_t NextBlockNumber = 1; + ReversePostOrderTraversal<Function *> RPOT(&F); + for (BasicBlock *BB : RPOT) + BlockRPONumber[BB] = NextBlockNumber++; +} + // Tries to replace instruction with const, using information from // ReplaceWithConstMap. bool GVN::replaceOperandsWithConsts(Instruction *Instr) const { @@ -1827,6 +2016,8 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, TLI = &RunTLI; VN.setAliasAnalysis(&RunAA); MD = RunMD; + OrderedInstructions OrderedInstrs(DT); + OI = &OrderedInstrs; VN.setMemDep(MD); ORE = RunORE; @@ -1857,6 +2048,7 @@ bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, // Fabricate val-num for dead-code in order to suppress assertion in // performPRE(). 
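phiTranslateImpl rewrites an expression's operand value numbers through the incoming values of phis in the given predecessor, so that the later leader lookup can succeed in that predecessor. A toy model of the translation step, using ordinary maps and invented value numbers rather than the real ValueTable:

#include <cassert>
#include <map>
#include <utility>

// Toy model: a phi's value number maps, per predecessor, to the value
// number of its incoming value on that edge.
using PhiIncoming =
    std::map<std::pair<unsigned /*PhiVN*/, int /*PredId*/>, unsigned>;

static unsigned translate(unsigned VN, int Pred, const PhiIncoming &Incoming) {
  auto It = Incoming.find({VN, Pred});
  return It != Incoming.end() ? It->second : VN;
}

int main() {
  // %p = phi [%a, pred0], [%b, pred1]; suppose VN(%p)=10, VN(%a)=3, VN(%b)=4.
  PhiIncoming Incoming = {{{10, 0}, 3}, {{10, 1}, 4}};
  // An expression "add VN(%p), VN(%x)" with VN(%x)=7 becomes, through pred0,
  // "add 3, 7", which can now be looked up among pred0's leaders.
  std::pair<unsigned, unsigned> Expr = {10, 7};
  std::pair<unsigned, unsigned> Translated = {
      translate(Expr.first, 0, Incoming), translate(Expr.second, 0, Incoming)};
  assert(Translated == std::make_pair(3u, 7u));
}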
assignValNumForDeadCode(); + assignBlockRPONumber(F); bool PREChanged = true; while (PREChanged) { PREChanged = performPRE(F); @@ -1908,14 +2100,26 @@ bool GVN::processBlock(BasicBlock *BB) { if (!AtStart) --BI; - for (SmallVectorImpl<Instruction *>::iterator I = InstrsToErase.begin(), - E = InstrsToErase.end(); I != E; ++I) { - DEBUG(dbgs() << "GVN removed: " << **I << '\n'); - if (MD) MD->removeInstruction(*I); - DEBUG(verifyRemoved(*I)); - (*I)->eraseFromParent(); + bool InvalidateImplicitCF = false; + const Instruction *MaybeFirstICF = FirstImplicitControlFlowInsts.lookup(BB); + for (auto *I : InstrsToErase) { + assert(I->getParent() == BB && "Removing instruction from wrong block?"); + DEBUG(dbgs() << "GVN removed: " << *I << '\n'); + if (MD) MD->removeInstruction(I); + DEBUG(verifyRemoved(I)); + if (MaybeFirstICF == I) { + // We have erased the first ICF in block. The map needs to be updated. + InvalidateImplicitCF = true; + // Do not keep dangling pointer on the erased instruction. + MaybeFirstICF = nullptr; + } + I->eraseFromParent(); } + + OI->invalidateBlock(BB); InstrsToErase.clear(); + if (InvalidateImplicitCF) + fillImplicitControlFlowInfo(BB); if (AtStart) BI = BB->begin(); @@ -1928,7 +2132,7 @@ bool GVN::processBlock(BasicBlock *BB) { // Instantiate an expression in a predecessor that lacked it. bool GVN::performScalarPREInsertion(Instruction *Instr, BasicBlock *Pred, - unsigned int ValNo) { + BasicBlock *Curr, unsigned int ValNo) { // Because we are going top-down through the block, all value numbers // will be available in the predecessor by the time we need them. Any // that weren't originally present will have been instantiated earlier @@ -1946,7 +2150,9 @@ bool GVN::performScalarPREInsertion(Instruction *Instr, BasicBlock *Pred, success = false; break; } - if (Value *V = findLeader(Pred, VN.lookup(Op))) { + uint32_t TValNo = + VN.phiTranslate(Pred, Curr, VN.lookup(Op), *this); + if (Value *V = findLeader(Pred, TValNo)) { Instr->setOperand(i, V); } else { success = false; @@ -1963,10 +2169,12 @@ bool GVN::performScalarPREInsertion(Instruction *Instr, BasicBlock *Pred, Instr->insertBefore(Pred->getTerminator()); Instr->setName(Instr->getName() + ".pre"); Instr->setDebugLoc(Instr->getDebugLoc()); - VN.add(Instr, ValNo); + + unsigned Num = VN.lookupOrAdd(Instr); + VN.add(Instr, Num); // Update the availability map to include the new instruction. - addToLeaderTable(ValNo, Instr, Pred); + addToLeaderTable(Num, Instr, Pred); return true; } @@ -2004,18 +2212,27 @@ bool GVN::performScalarPRE(Instruction *CurInst) { SmallVector<std::pair<Value *, BasicBlock *>, 8> predMap; for (BasicBlock *P : predecessors(CurrentBlock)) { - // We're not interested in PRE where the block is its - // own predecessor, or in blocks with predecessors - // that are not reachable. - if (P == CurrentBlock) { + // We're not interested in PRE where blocks with predecessors that are + // not reachable. + if (!DT->isReachableFromEntry(P)) { NumWithout = 2; break; - } else if (!DT->isReachableFromEntry(P)) { + } + // It is not safe to do PRE when P->CurrentBlock is a loop backedge, and + // when CurInst has operand defined in CurrentBlock (so it may be defined + // by phi in the loop header). 
+ if (BlockRPONumber[P] >= BlockRPONumber[CurrentBlock] && + llvm::any_of(CurInst->operands(), [&](const Use &U) { + if (auto *Inst = dyn_cast<Instruction>(U.get())) + return Inst->getParent() == CurrentBlock; + return false; + })) { NumWithout = 2; break; } - Value *predV = findLeader(P, ValNo); + uint32_t TValNo = VN.phiTranslate(P, CurrentBlock, ValNo, *this); + Value *predV = findLeader(P, TValNo); if (!predV) { predMap.push_back(std::make_pair(static_cast<Value *>(nullptr), P)); PREPred = P; @@ -2041,6 +2258,20 @@ bool GVN::performScalarPRE(Instruction *CurInst) { Instruction *PREInstr = nullptr; if (NumWithout != 0) { + if (!isSafeToSpeculativelyExecute(CurInst)) { + // It is only valid to insert a new instruction if the current instruction + // is always executed. An instruction with implicit control flow could + // prevent us from doing it. If we cannot speculate the execution, then + // PRE should be prohibited. + auto It = FirstImplicitControlFlowInsts.find(CurrentBlock); + if (It != FirstImplicitControlFlowInsts.end()) { + assert(It->second->getParent() == CurrentBlock && + "Implicit control flow map broken?"); + if (OI->dominates(It->second, CurInst)) + return false; + } + } + // Don't do PRE across indirect branch. if (isa<IndirectBrInst>(PREPred->getTerminator())) return false; @@ -2055,7 +2286,7 @@ bool GVN::performScalarPRE(Instruction *CurInst) { } // We need to insert somewhere, so let's give it a shot PREInstr = CurInst->clone(); - if (!performScalarPREInsertion(PREInstr, PREPred, ValNo)) { + if (!performScalarPREInsertion(PREInstr, PREPred, CurrentBlock, ValNo)) { // If we failed insertion, make sure we remove the instruction. DEBUG(verifyRemoved(PREInstr)); PREInstr->deleteValue(); @@ -2065,7 +2296,7 @@ bool GVN::performScalarPRE(Instruction *CurInst) { // Either we should have filled in the PRE instruction, or we should // not have needed insertions. - assert (PREInstr != nullptr || NumWithout == 0); + assert(PREInstr != nullptr || NumWithout == 0); ++NumGVNPRE; @@ -2074,13 +2305,19 @@ bool GVN::performScalarPRE(Instruction *CurInst) { PHINode::Create(CurInst->getType(), predMap.size(), CurInst->getName() + ".pre-phi", &CurrentBlock->front()); for (unsigned i = 0, e = predMap.size(); i != e; ++i) { - if (Value *V = predMap[i].first) + if (Value *V = predMap[i].first) { + // If we use an existing value in this phi, we have to patch the original + // value because the phi will be used to replace a later value. + patchReplacementInstruction(CurInst, V); Phi->addIncoming(V, predMap[i].second); - else + } else Phi->addIncoming(PREInstr, PREPred); } VN.add(Phi, ValNo); + // After creating a new PHI for ValNo, the phi translate result for ValNo will + // be changed, so erase the related stale entries in phi translate cache. + VN.eraseTranslateCacheEntry(ValNo, *CurrentBlock); addToLeaderTable(ValNo, Phi, CurrentBlock); Phi->setDebugLoc(CurInst->getDebugLoc()); CurInst->replaceAllUsesWith(Phi); @@ -2093,7 +2330,14 @@ bool GVN::performScalarPRE(Instruction *CurInst) { if (MD) MD->removeInstruction(CurInst); DEBUG(verifyRemoved(CurInst)); + bool InvalidateImplicitCF = + FirstImplicitControlFlowInsts.lookup(CurInst->getParent()) == CurInst; + // FIXME: Intended to be markInstructionForDeletion(CurInst), but it causes + // some assertion failures. 
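The PRE safety check above depends on a per-block record of the first instruction that may not transfer execution to its successor (a guard, for instance); if that instruction comes before the PRE candidate, insertion is refused. A simplified standalone sketch of that bookkeeping, with invented Inst and Block types standing in for the GVN members:

#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

struct Inst {
  bool MayNotTransferExecution; // e.g. a guard or a throwing call
};
using Block = std::vector<Inst>;

// Record, per block, the index of the first instruction with implicit
// control flow, if any (mirrors FirstImplicitControlFlowInsts).
static std::map<const Block *, std::size_t> buildFirstICFMap(const Block &BB) {
  std::map<const Block *, std::size_t> Map;
  for (std::size_t I = 0; I < BB.size(); ++I)
    if (BB[I].MayNotTransferExecution) {
      Map[&BB] = I;
      break;
    }
  return Map;
}

// PRE of the instruction at index Cand is unsafe if an implicit-control-flow
// instruction in the same block comes before it.
static bool safeToPRE(const Block &BB, std::size_t Cand,
                      const std::map<const Block *, std::size_t> &FirstICF) {
  auto It = FirstICF.find(&BB);
  return It == FirstICF.end() || It->second >= Cand;
}

int main() {
  Block BB = {{false}, {true /*guard*/}, {false}};
  auto FirstICF = buildFirstICFMap(BB);
  assert(safeToPRE(BB, 1, FirstICF));  // candidate at or before the guard
  assert(!safeToPRE(BB, 2, FirstICF)); // candidate after the guard: refuse
}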
+ OI->invalidateBlock(CurrentBlock); CurInst->eraseFromParent(); + if (InvalidateImplicitCF) + fillImplicitControlFlowInfo(CurrentBlock); ++NumGVNInstr; return true; @@ -2160,6 +2404,9 @@ bool GVN::iterateOnFunction(Function &F) { // RPOT walks the graph in its constructor and will not be invalidated during // processBlock. ReversePostOrderTraversal<Function *> RPOT(&F); + + for (BasicBlock *BB : RPOT) + fillImplicitControlFlowInfo(BB); for (BasicBlock *BB : RPOT) Changed |= processBlock(BB); @@ -2169,7 +2416,50 @@ bool GVN::iterateOnFunction(Function &F) { void GVN::cleanupGlobalSets() { VN.clear(); LeaderTable.clear(); + BlockRPONumber.clear(); TableAllocator.Reset(); + FirstImplicitControlFlowInsts.clear(); +} + +void +GVN::fillImplicitControlFlowInfo(BasicBlock *BB) { + // Make sure that all marked instructions are actually deleted by this point, + // so that we don't need to care about omitting them. + assert(InstrsToErase.empty() && "Filling before removed all marked insns?"); + auto MayNotTransferExecutionToSuccessor = [&](const Instruction *I) { + // If a block's instruction doesn't always pass the control to its successor + // instruction, mark the block as having implicit control flow. We use them + // to avoid wrong assumptions of sort "if A is executed and B post-dominates + // A, then B is also executed". This is not true is there is an implicit + // control flow instruction (e.g. a guard) between them. + // + // TODO: Currently, isGuaranteedToTransferExecutionToSuccessor returns false + // for volatile stores and loads because they can trap. The discussion on + // whether or not it is correct is still ongoing. We might want to get rid + // of this logic in the future. Anyways, trapping instructions shouldn't + // introduce implicit control flow, so we explicitly allow them here. This + // must be removed once isGuaranteedToTransferExecutionToSuccessor is fixed. + if (isGuaranteedToTransferExecutionToSuccessor(I)) + return false; + if (isa<LoadInst>(I)) { + assert(cast<LoadInst>(I)->isVolatile() && + "Non-volatile load should transfer execution to successor!"); + return false; + } + if (isa<StoreInst>(I)) { + assert(cast<StoreInst>(I)->isVolatile() && + "Non-volatile store should transfer execution to successor!"); + return false; + } + return true; + }; + FirstImplicitControlFlowInsts.erase(BB); + + for (auto &I : *BB) + if (MayNotTransferExecutionToSuccessor(&I)) { + FirstImplicitControlFlowInsts[BB] = &I; + break; + } } /// Verify that the specified instruction does not occur in our @@ -2317,6 +2607,7 @@ void GVN::assignValNumForDeadCode() { class llvm::gvn::GVNLegacyPass : public FunctionPass { public: static char ID; // Pass identification, replacement for typeid + explicit GVNLegacyPass(bool NoLoads = false) : FunctionPass(ID), NoLoads(NoLoads) { initializeGVNLegacyPassPass(*PassRegistry::getPassRegistry()); @@ -2360,11 +2651,6 @@ private: char GVNLegacyPass::ID = 0; -// The public interface to this file... 
-FunctionPass *llvm::createGVNPass(bool NoLoads) { - return new GVNLegacyPass(NoLoads); -} - INITIALIZE_PASS_BEGIN(GVNLegacyPass, "gvn", "Global Value Numbering", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) @@ -2374,3 +2660,8 @@ INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(GVNLegacyPass, "gvn", "Global Value Numbering", false, false) + +// The public interface to this file... +FunctionPass *llvm::createGVNPass(bool NoLoads) { + return new GVNLegacyPass(NoLoads); +} diff --git a/lib/Transforms/Scalar/GVNHoist.cpp b/lib/Transforms/Scalar/GVNHoist.cpp index 29de792bd248..c0cd1ea74a74 100644 --- a/lib/Transforms/Scalar/GVNHoist.cpp +++ b/lib/Transforms/Scalar/GVNHoist.cpp @@ -13,44 +13,72 @@ // 1. To reduce the code size. // 2. In some cases reduce critical path (by exposing more ILP). // +// The algorithm factors out the reachability of values such that multiple +// queries to find reachability of values are fast. This is based on finding the +// ANTIC points in the CFG which do not change during hoisting. The ANTIC points +// are basically the dominance-frontiers in the inverse graph. So we introduce a +// data structure (CHI nodes) to keep track of values flowing out of a basic +// block. We only do this for values with multiple occurrences in the function +// as they are the potential hoistable candidates. This approach allows us to +// hoist instructions to a basic block with more than two successors, as well as +// deal with infinite loops in a trivial way. +// +// Limitations: This pass does not hoist fully redundant expressions because +// they are already handled by GVN-PRE. It is advisable to run gvn-hoist before +// and after gvn-pre because gvn-pre creates opportunities for more instructions +// to be hoisted. +// // Hoisting may affect the performance in some cases. To mitigate that, hoisting // is disabled in the following cases. // 1. Scalars across calls. // 2. geps when corresponding load/store cannot be hoisted. -// -// TODO: Hoist from >2 successors. Currently GVNHoist will not hoist stores -// in this case because it works on two instructions at a time. 
-// entry: -// switch i32 %c1, label %exit1 [ -// i32 0, label %sw0 -// i32 1, label %sw1 -// ] -// -// sw0: -// store i32 1, i32* @G -// br label %exit -// -// sw1: -// store i32 1, i32* @G -// br label %exit -// -// exit1: -// store i32 1, i32* @G -// ret void -// exit: -// ret void //===----------------------------------------------------------------------===// #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/IteratedDominanceFrontier.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/PostDominators.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVN.h" #include "llvm/Transforms/Utils/Local.h" +#include <algorithm> +#include <cassert> +#include <iterator> +#include <memory> +#include <utility> +#include <vector> using namespace llvm; @@ -69,6 +97,7 @@ static cl::opt<int> MaxHoistedThreshold("gvn-max-hoisted", cl::Hidden, cl::init(-1), cl::desc("Max number of instructions to hoist " "(default unlimited = -1)")); + static cl::opt<int> MaxNumberOfBBSInPath( "gvn-hoist-max-bbs", cl::Hidden, cl::init(4), cl::desc("Max number of basic blocks on the path between " @@ -86,34 +115,50 @@ static cl::opt<int> namespace llvm { -// Provides a sorting function based on the execution order of two instructions. -struct SortByDFSIn { -private: - DenseMap<const Value *, unsigned> &DFSNumber; +using BBSideEffectsSet = DenseMap<const BasicBlock *, bool>; +using SmallVecInsn = SmallVector<Instruction *, 4>; +using SmallVecImplInsn = SmallVectorImpl<Instruction *>; -public: - SortByDFSIn(DenseMap<const Value *, unsigned> &D) : DFSNumber(D) {} - - // Returns true when A executes before B. - bool operator()(const Instruction *A, const Instruction *B) const { - const BasicBlock *BA = A->getParent(); - const BasicBlock *BB = B->getParent(); - unsigned ADFS, BDFS; - if (BA == BB) { - ADFS = DFSNumber.lookup(A); - BDFS = DFSNumber.lookup(B); - } else { - ADFS = DFSNumber.lookup(BA); - BDFS = DFSNumber.lookup(BB); - } - assert(ADFS && BDFS); - return ADFS < BDFS; - } -}; +// Each element of a hoisting list contains the basic block where to hoist and +// a list of instructions to be hoisted. +using HoistingPointInfo = std::pair<BasicBlock *, SmallVecInsn>; + +using HoistingPointList = SmallVector<HoistingPointInfo, 4>; // A map from a pair of VNs to all the instructions with those VNs. 
-typedef DenseMap<std::pair<unsigned, unsigned>, SmallVector<Instruction *, 4>> - VNtoInsns; +using VNType = std::pair<unsigned, unsigned>; + +using VNtoInsns = DenseMap<VNType, SmallVector<Instruction *, 4>>; + +// CHI keeps information about values flowing out of a basic block. It is +// similar to PHI but in the inverse graph, and used for outgoing values on each +// edge. For conciseness, it is computed only for instructions with multiple +// occurrences in the CFG because they are the only hoistable candidates. +// A (CHI[{V, B, I1}, {V, C, I2}] +// / \ +// / \ +// B(I1) C (I2) +// The Value number for both I1 and I2 is V, the CHI node will save the +// instruction as well as the edge where the value is flowing to. +struct CHIArg { + VNType VN; + + // Edge destination (shows the direction of flow), may not be where the I is. + BasicBlock *Dest; + + // The instruction (VN) which uses the values flowing out of CHI. + Instruction *I; + + bool operator==(const CHIArg &A) { return VN == A.VN; } + bool operator!=(const CHIArg &A) { return !(*this == A); } +}; + +using CHIIt = SmallVectorImpl<CHIArg>::iterator; +using CHIArgs = iterator_range<CHIIt>; +using OutValuesType = DenseMap<BasicBlock *, SmallVector<CHIArg, 2>>; +using InValuesType = + DenseMap<BasicBlock *, SmallVector<std::pair<VNType, Instruction *>, 2>>; + // An invalid value number Used when inserting a single value number into // VNtoInsns. enum : unsigned { InvalidVN = ~2U }; @@ -192,16 +237,10 @@ public: } const VNtoInsns &getScalarVNTable() const { return VNtoCallsScalars; } - const VNtoInsns &getLoadVNTable() const { return VNtoCallsLoads; } - const VNtoInsns &getStoreVNTable() const { return VNtoCallsStores; } }; -typedef DenseMap<const BasicBlock *, bool> BBSideEffectsSet; -typedef SmallVector<Instruction *, 4> SmallVecInsn; -typedef SmallVectorImpl<Instruction *> SmallVecImplInsn; - static void combineKnownMetadata(Instruction *ReplInst, Instruction *I) { static const unsigned KnownIDs[] = { LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, @@ -216,15 +255,13 @@ static void combineKnownMetadata(Instruction *ReplInst, Instruction *I) { // cases reduce critical path (by exposing more ILP). class GVNHoist { public: - GVNHoist(DominatorTree *DT, AliasAnalysis *AA, MemoryDependenceResults *MD, - MemorySSA *MSSA) - : DT(DT), AA(AA), MD(MD), MSSA(MSSA), - MSSAUpdater(make_unique<MemorySSAUpdater>(MSSA)), - HoistingGeps(false), - HoistedCtr(0) - { } + GVNHoist(DominatorTree *DT, PostDominatorTree *PDT, AliasAnalysis *AA, + MemoryDependenceResults *MD, MemorySSA *MSSA) + : DT(DT), PDT(PDT), AA(AA), MD(MD), MSSA(MSSA), + MSSAUpdater(llvm::make_unique<MemorySSAUpdater>(MSSA)) {} bool run(Function &F) { + NumFuncArgs = F.arg_size(); VN.setDomTree(DT); VN.setAliasAnalysis(AA); VN.setMemDep(MD); @@ -241,7 +278,7 @@ public: int ChainLength = 0; // FIXME: use lazy evaluation of VN to avoid the fix-point computation. - while (1) { + while (true) { if (MaxChainLength != -1 && ++ChainLength >= MaxChainLength) return Res; @@ -261,18 +298,48 @@ public: return Res; } + // Copied from NewGVN.cpp + // This function provides global ranking of operations so that we can place + // them in a canonical order. Note that rank alone is not necessarily enough + // for a complete ordering, as constants all have the same rank. However, + // generally, we will simplify an operation with all constants so that it + // doesn't matter what order they appear in. 
+ unsigned int rank(const Value *V) const { + // Prefer constants to undef to anything else + // Undef is a constant, have to check it first. + // Prefer smaller constants to constantexprs + if (isa<ConstantExpr>(V)) + return 2; + if (isa<UndefValue>(V)) + return 1; + if (isa<Constant>(V)) + return 0; + else if (auto *A = dyn_cast<Argument>(V)) + return 3 + A->getArgNo(); + + // Need to shift the instruction DFS by number of arguments + 3 to account + // for the constant and argument ranking above. + auto Result = DFSNumber.lookup(V); + if (Result > 0) + return 4 + NumFuncArgs + Result; + // Unreachable or something else, just return a really large number. + return ~0; + } + private: GVN::ValueTable VN; DominatorTree *DT; + PostDominatorTree *PDT; AliasAnalysis *AA; MemoryDependenceResults *MD; MemorySSA *MSSA; std::unique_ptr<MemorySSAUpdater> MSSAUpdater; - const bool HoistingGeps; DenseMap<const Value *, unsigned> DFSNumber; BBSideEffectsSet BBSideEffects; - DenseSet<const BasicBlock*> HoistBarrier; - int HoistedCtr; + DenseSet<const BasicBlock *> HoistBarrier; + SmallVector<BasicBlock *, 32> IDFBlocks; + unsigned NumFuncArgs; + const bool HoistingGeps = false; enum InsKind { Unknown, Scalar, Load, Store }; @@ -305,45 +372,7 @@ private: return false; } - // Return true when all paths from HoistBB to the end of the function pass - // through one of the blocks in WL. - bool hoistingFromAllPaths(const BasicBlock *HoistBB, - SmallPtrSetImpl<const BasicBlock *> &WL) { - - // Copy WL as the loop will remove elements from it. - SmallPtrSet<const BasicBlock *, 2> WorkList(WL.begin(), WL.end()); - - for (auto It = df_begin(HoistBB), E = df_end(HoistBB); It != E;) { - // There exists a path from HoistBB to the exit of the function if we are - // still iterating in DF traversal and we removed all instructions from - // the work list. - if (WorkList.empty()) - return false; - - const BasicBlock *BB = *It; - if (WorkList.erase(BB)) { - // Stop DFS traversal when BB is in the work list. - It.skipChildren(); - continue; - } - - // We reached the leaf Basic Block => not all paths have this instruction. - if (!BB->getTerminator()->getNumSuccessors()) - return false; - - // When reaching the back-edge of a loop, there may be a path through the - // loop that does not pass through B or C before exiting the loop. - if (successorDominate(BB, HoistBB)) - return false; - - // Increment DFS traversal when not skipping children. - ++It; - } - - return true; - } - - /* Return true when I1 appears before I2 in the instructions of BB. */ + // Return true when I1 appears before I2 in the instructions of BB. bool firstInBB(const Instruction *I1, const Instruction *I2) { assert(I1->getParent() == I2->getParent()); unsigned I1DFS = DFSNumber.lookup(I1); @@ -387,6 +416,25 @@ private: return false; } + bool hasEHhelper(const BasicBlock *BB, const BasicBlock *SrcBB, + int &NBBsOnAllPaths) { + // Stop walk once the limit is reached. + if (NBBsOnAllPaths == 0) + return true; + + // Impossible to hoist with exceptions on the path. + if (hasEH(BB)) + return true; + + // No such instruction after HoistBarrier in a basic block was + // selected for hoisting so instructions selected within basic block with + // a hoist barrier can be hoisted. + if ((BB != SrcBB) && HoistBarrier.count(BB)) + return true; + + return false; + } + // Return true when there are exception handling or loads of memory Def // between Def and NewPt. This function is only called for stores: Def is // the MemoryDef of the store to be hoisted. 
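
The rank() helper added above gives plain constants the lowest rank, then undef, then constant expressions, then function arguments by position, and finally instructions by DFS number shifted past the argument range, so later sorting by rank prefers simpler values. A minimal standalone sketch of that same ordering, using a toy value model rather than LLVM types (ToyValue, Kind and the main() driver are illustrative only, not LLVM APIs):

#include <iostream>

enum class Kind { Constant, Undef, ConstantExpr, Argument, Instruction };

struct ToyValue {
  Kind K;
  unsigned ArgNo = 0;  // meaningful for Argument
  unsigned DFSNum = 0; // meaningful for Instruction (0 means "unknown")
};

unsigned rank(const ToyValue &V, unsigned NumFuncArgs) {
  switch (V.K) {
  case Kind::Constant:     return 0;           // prefer plain constants
  case Kind::Undef:        return 1;           // then undef
  case Kind::ConstantExpr: return 2;           // then constant expressions
  case Kind::Argument:     return 3 + V.ArgNo; // then arguments, by position
  case Kind::Instruction:
    if (V.DFSNum > 0)
      return 4 + NumFuncArgs + V.DFSNum;       // shift past the argument range
    return ~0u;                                // unreachable or unnumbered
  }
  return ~0u;
}

int main() {
  unsigned NumFuncArgs = 2;
  ToyValue C{Kind::Constant}, A{Kind::Argument, 1}, I{Kind::Instruction, 0, 7};
  std::cout << rank(C, NumFuncArgs) << " < " << rank(A, NumFuncArgs) << " < "
            << rank(I, NumFuncArgs) << "\n"; // prints: 0 < 4 < 13
  return 0;
}
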
@@ -414,18 +462,7 @@ private: continue; } - // Stop walk once the limit is reached. - if (NBBsOnAllPaths == 0) - return true; - - // Impossible to hoist with exceptions on the path. - if (hasEH(BB)) - return true; - - // No such instruction after HoistBarrier in a basic block was - // selected for hoisting so instructions selected within basic block with - // a hoist barrier can be hoisted. - if ((BB != OldBB) && HoistBarrier.count(BB)) + if (hasEHhelper(BB, OldBB, NBBsOnAllPaths)) return true; // Check that we do not move a store past loads. @@ -463,18 +500,7 @@ private: continue; } - // Stop walk once the limit is reached. - if (NBBsOnAllPaths == 0) - return true; - - // Impossible to hoist with exceptions on the path. - if (hasEH(BB)) - return true; - - // No such instruction after HoistBarrier in a basic block was - // selected for hoisting so instructions selected within basic block with - // a hoist barrier can be hoisted. - if ((BB != SrcBB) && HoistBarrier.count(BB)) + if (hasEHhelper(BB, SrcBB, NBBsOnAllPaths)) return true; // -1 is unlimited number of blocks on all paths. @@ -491,7 +517,6 @@ private: // to NewPt. bool safeToHoistLdSt(const Instruction *NewPt, const Instruction *OldPt, MemoryUseOrDef *U, InsKind K, int &NBBsOnAllPaths) { - // In place hoisting is safe. if (NewPt == OldPt) return true; @@ -533,141 +558,258 @@ private: // Return true when it is safe to hoist scalar instructions from all blocks in // WL to HoistBB. - bool safeToHoistScalar(const BasicBlock *HoistBB, - SmallPtrSetImpl<const BasicBlock *> &WL, + bool safeToHoistScalar(const BasicBlock *HoistBB, const BasicBlock *BB, int &NBBsOnAllPaths) { - // Check that the hoisted expression is needed on all paths. - if (!hoistingFromAllPaths(HoistBB, WL)) - return false; + return !hasEHOnPath(HoistBB, BB, NBBsOnAllPaths); + } - for (const BasicBlock *BB : WL) - if (hasEHOnPath(HoistBB, BB, NBBsOnAllPaths)) + // In the inverse CFG, the dominance frontier of basic block (BB) is the + // point where ANTIC needs to be computed for instructions which are going + // to be hoisted. Since this point does not change during gvn-hoist, + // we compute it only once (on demand). + // The ides is inspired from: + // "Partial Redundancy Elimination in SSA Form" + // ROBERT KENNEDY, SUN CHAN, SHIN-MING LIU, RAYMOND LO, PENG TU and FRED CHOW + // They use similar idea in the forward graph to to find fully redundant and + // partially redundant expressions, here it is used in the inverse graph to + // find fully anticipable instructions at merge point (post-dominator in + // the inverse CFG). + // Returns the edge via which an instruction in BB will get the values from. + + // Returns true when the values are flowing out to each edge. + bool valueAnticipable(CHIArgs C, TerminatorInst *TI) const { + if (TI->getNumSuccessors() > (unsigned)std::distance(C.begin(), C.end())) + return false; // Not enough args in this CHI. + + for (auto CHI : C) { + BasicBlock *Dest = CHI.Dest; + // Find if all the edges have values flowing out of BB. + bool Found = llvm::any_of(TI->successors(), [Dest](const BasicBlock *BB) { + return BB == Dest; }); + if (!Found) return false; - + } return true; } - // Each element of a hoisting list contains the basic block where to hoist and - // a list of instructions to be hoisted. - typedef std::pair<BasicBlock *, SmallVecInsn> HoistingPointInfo; - typedef SmallVector<HoistingPointInfo, 4> HoistingPointList; - - // Partition InstructionsToHoist into a set of candidates which can share a - // common hoisting point. 
The partitions are collected in HPL. IsScalar is - // true when the instructions in InstructionsToHoist are scalars. IsLoad is - // true when the InstructionsToHoist are loads, false when they are stores. - void partitionCandidates(SmallVecImplInsn &InstructionsToHoist, - HoistingPointList &HPL, InsKind K) { - // No need to sort for two instructions. - if (InstructionsToHoist.size() > 2) { - SortByDFSIn Pred(DFSNumber); - std::sort(InstructionsToHoist.begin(), InstructionsToHoist.end(), Pred); - } - + // Check if it is safe to hoist values tracked by CHI in the range + // [Begin, End) and accumulate them in Safe. + void checkSafety(CHIArgs C, BasicBlock *BB, InsKind K, + SmallVectorImpl<CHIArg> &Safe) { int NumBBsOnAllPaths = MaxNumberOfBBSInPath; - - SmallVecImplInsn::iterator II = InstructionsToHoist.begin(); - SmallVecImplInsn::iterator Start = II; - Instruction *HoistPt = *II; - BasicBlock *HoistBB = HoistPt->getParent(); - MemoryUseOrDef *UD; - if (K != InsKind::Scalar) - UD = MSSA->getMemoryAccess(HoistPt); - - for (++II; II != InstructionsToHoist.end(); ++II) { - Instruction *Insn = *II; - BasicBlock *BB = Insn->getParent(); - BasicBlock *NewHoistBB; - Instruction *NewHoistPt; - - if (BB == HoistBB) { // Both are in the same Basic Block. - NewHoistBB = HoistBB; - NewHoistPt = firstInBB(Insn, HoistPt) ? Insn : HoistPt; + for (auto CHI : C) { + Instruction *Insn = CHI.I; + if (!Insn) // No instruction was inserted in this CHI. + continue; + if (K == InsKind::Scalar) { + if (safeToHoistScalar(BB, Insn->getParent(), NumBBsOnAllPaths)) + Safe.push_back(CHI); } else { - // If the hoisting point contains one of the instructions, - // then hoist there, otherwise hoist before the terminator. - NewHoistBB = DT->findNearestCommonDominator(HoistBB, BB); - if (NewHoistBB == BB) - NewHoistPt = Insn; - else if (NewHoistBB == HoistBB) - NewHoistPt = HoistPt; - else - NewHoistPt = NewHoistBB->getTerminator(); + MemoryUseOrDef *UD = MSSA->getMemoryAccess(Insn); + if (safeToHoistLdSt(BB->getTerminator(), Insn, UD, K, NumBBsOnAllPaths)) + Safe.push_back(CHI); } + } + } - SmallPtrSet<const BasicBlock *, 2> WL; - WL.insert(HoistBB); - WL.insert(BB); + using RenameStackType = DenseMap<VNType, SmallVector<Instruction *, 2>>; + + // Push all the VNs corresponding to BB into RenameStack. + void fillRenameStack(BasicBlock *BB, InValuesType &ValueBBs, + RenameStackType &RenameStack) { + auto it1 = ValueBBs.find(BB); + if (it1 != ValueBBs.end()) { + // Iterate in reverse order to keep lower ranked values on the top. + for (std::pair<VNType, Instruction *> &VI : reverse(it1->second)) { + // Get the value of instruction I + DEBUG(dbgs() << "\nPushing on stack: " << *VI.second); + RenameStack[VI.first].push_back(VI.second); + } + } + } - if (K == InsKind::Scalar) { - if (safeToHoistScalar(NewHoistBB, WL, NumBBsOnAllPaths)) { - // Extend HoistPt to NewHoistPt. - HoistPt = NewHoistPt; - HoistBB = NewHoistBB; - continue; - } - } else { - // When NewBB already contains an instruction to be hoisted, the - // expression is needed on all paths. - // Check that the hoisted expression is needed on all paths: it is - // unsafe to hoist loads to a place where there may be a path not - // loading from the same address: for instance there may be a branch on - // which the address of the load may not be initialized. - if ((HoistBB == NewHoistBB || BB == NewHoistBB || - hoistingFromAllPaths(NewHoistBB, WL)) && - // Also check that it is safe to move the load or store from HoistPt - // to NewHoistPt, and from Insn to NewHoistPt. 
- safeToHoistLdSt(NewHoistPt, HoistPt, UD, K, NumBBsOnAllPaths) && - safeToHoistLdSt(NewHoistPt, Insn, MSSA->getMemoryAccess(Insn), - K, NumBBsOnAllPaths)) { - // Extend HoistPt to NewHoistPt. - HoistPt = NewHoistPt; - HoistBB = NewHoistBB; - continue; - } + void fillChiArgs(BasicBlock *BB, OutValuesType &CHIBBs, + RenameStackType &RenameStack) { + // For each *predecessor* (because Post-DOM) of BB check if it has a CHI + for (auto Pred : predecessors(BB)) { + auto P = CHIBBs.find(Pred); + if (P == CHIBBs.end()) { + continue; } + DEBUG(dbgs() << "\nLooking at CHIs in: " << Pred->getName();); + // A CHI is found (BB -> Pred is an edge in the CFG) + // Pop the stack until Top(V) = Ve. + auto &VCHI = P->second; + for (auto It = VCHI.begin(), E = VCHI.end(); It != E;) { + CHIArg &C = *It; + if (!C.Dest) { + auto si = RenameStack.find(C.VN); + // The Basic Block where CHI is must dominate the value we want to + // track in a CHI. In the PDom walk, there can be values in the + // stack which are not control dependent e.g., nested loop. + if (si != RenameStack.end() && si->second.size() && + DT->dominates(Pred, si->second.back()->getParent())) { + C.Dest = BB; // Assign the edge + C.I = si->second.pop_back_val(); // Assign the argument + DEBUG(dbgs() << "\nCHI Inserted in BB: " << C.Dest->getName() + << *C.I << ", VN: " << C.VN.first << ", " + << C.VN.second); + } + // Move to next CHI of a different value + It = std::find_if(It, VCHI.end(), + [It](CHIArg &A) { return A != *It; }); + } else + ++It; + } + } + } + + // Walk the post-dominator tree top-down and use a stack for each value to + // store the last value you see. When you hit a CHI from a given edge, the + // value to use as the argument is at the top of the stack, add the value to + // CHI and pop. + void insertCHI(InValuesType &ValueBBs, OutValuesType &CHIBBs) { + auto Root = PDT->getNode(nullptr); + if (!Root) + return; + // Depth first walk on PDom tree to fill the CHIargs at each PDF. + RenameStackType RenameStack; + for (auto Node : depth_first(Root)) { + BasicBlock *BB = Node->getBlock(); + if (!BB) + continue; + + // Collect all values in BB and push to stack. + fillRenameStack(BB, ValueBBs, RenameStack); - // At this point it is not safe to extend the current hoisting to - // NewHoistPt: save the hoisting list so far. - if (std::distance(Start, II) > 1) - HPL.push_back({HoistBB, SmallVecInsn(Start, II)}); - - // Start over from BB. - Start = II; - if (K != InsKind::Scalar) - UD = MSSA->getMemoryAccess(*Start); - HoistPt = Insn; - HoistBB = BB; - NumBBsOnAllPaths = MaxNumberOfBBSInPath; + // Fill outgoing values in each CHI corresponding to BB. + fillChiArgs(BB, CHIBBs, RenameStack); } + } + + // Walk all the CHI-nodes to find ones which have a empty-entry and remove + // them Then collect all the instructions which are safe to hoist and see if + // they form a list of anticipable values. OutValues contains CHIs + // corresponding to each basic block. + void findHoistableCandidates(OutValuesType &CHIBBs, InsKind K, + HoistingPointList &HPL) { + auto cmpVN = [](const CHIArg &A, const CHIArg &B) { return A.VN < B.VN; }; + + // CHIArgs now have the outgoing values, so check for anticipability and + // accumulate hoistable candidates in HPL. + for (std::pair<BasicBlock *, SmallVector<CHIArg, 2>> &A : CHIBBs) { + BasicBlock *BB = A.first; + SmallVectorImpl<CHIArg> &CHIs = A.second; + // Vector of PHIs contains PHIs for different instructions. 
+ // Sort the args according to their VNs, such that identical + // instructions are together. + std::stable_sort(CHIs.begin(), CHIs.end(), cmpVN); + auto TI = BB->getTerminator(); + auto B = CHIs.begin(); + // [PreIt, PHIIt) form a range of CHIs which have identical VNs. + auto PHIIt = std::find_if(CHIs.begin(), CHIs.end(), + [B](CHIArg &A) { return A != *B; }); + auto PrevIt = CHIs.begin(); + while (PrevIt != PHIIt) { + // Collect values which satisfy safety checks. + SmallVector<CHIArg, 2> Safe; + // We check for safety first because there might be multiple values in + // the same path, some of which are not safe to be hoisted, but overall + // each edge has at least one value which can be hoisted, making the + // value anticipable along that path. + checkSafety(make_range(PrevIt, PHIIt), BB, K, Safe); + + // List of safe values should be anticipable at TI. + if (valueAnticipable(make_range(Safe.begin(), Safe.end()), TI)) { + HPL.push_back({BB, SmallVecInsn()}); + SmallVecInsn &V = HPL.back().second; + for (auto B : Safe) + V.push_back(B.I); + } - // Save the last partition. - if (std::distance(Start, II) > 1) - HPL.push_back({HoistBB, SmallVecInsn(Start, II)}); + // Check other VNs + PrevIt = PHIIt; + PHIIt = std::find_if(PrevIt, CHIs.end(), + [PrevIt](CHIArg &A) { return A != *PrevIt; }); + } + } } - // Initialize HPL from Map. + // Compute insertion points for each values which can be fully anticipated at + // a dominator. HPL contains all such values. void computeInsertionPoints(const VNtoInsns &Map, HoistingPointList &HPL, InsKind K) { + // Sort VNs based on their rankings + std::vector<VNType> Ranks; for (const auto &Entry : Map) { - if (MaxHoistedThreshold != -1 && ++HoistedCtr > MaxHoistedThreshold) - return; + Ranks.push_back(Entry.first); + } - const SmallVecInsn &V = Entry.second; + // TODO: Remove fully-redundant expressions. + // Get instruction from the Map, assume that all the Instructions + // with same VNs have same rank (this is an approximation). + std::sort(Ranks.begin(), Ranks.end(), + [this, &Map](const VNType &r1, const VNType &r2) { + return (rank(*Map.lookup(r1).begin()) < + rank(*Map.lookup(r2).begin())); + }); + + // - Sort VNs according to their rank, and start with lowest ranked VN + // - Take a VN and for each instruction with same VN + // - Find the dominance frontier in the inverse graph (PDF) + // - Insert the chi-node at PDF + // - Remove the chi-nodes with missing entries + // - Remove values from CHI-nodes which do not truly flow out, e.g., + // modified along the path. + // - Collect the remaining values that are still anticipable + SmallVector<BasicBlock *, 2> IDFBlocks; + ReverseIDFCalculator IDFs(*PDT); + OutValuesType OutValue; + InValuesType InValue; + for (const auto &R : Ranks) { + const SmallVecInsn &V = Map.lookup(R); if (V.size() < 2) continue; - - // Compute the insertion point and the list of expressions to be hoisted. - SmallVecInsn InstructionsToHoist; - for (auto I : V) - // We don't need to check for hoist-barriers here because if - // I->getParent() is a barrier then I precedes the barrier. 
- if (!hasEH(I->getParent())) - InstructionsToHoist.push_back(I); - - if (!InstructionsToHoist.empty()) - partitionCandidates(InstructionsToHoist, HPL, K); + const VNType &VN = R; + SmallPtrSet<BasicBlock *, 2> VNBlocks; + for (auto &I : V) { + BasicBlock *BBI = I->getParent(); + if (!hasEH(BBI)) + VNBlocks.insert(BBI); + } + // Compute the Post Dominance Frontiers of each basic block + // The dominance frontier of a live block X in the reverse + // control graph is the set of blocks upon which X is control + // dependent. The following sequence computes the set of blocks + // which currently have dead terminators that are control + // dependence sources of a block which is in NewLiveBlocks. + IDFs.setDefiningBlocks(VNBlocks); + IDFs.calculate(IDFBlocks); + + // Make a map of BB vs instructions to be hoisted. + for (unsigned i = 0; i < V.size(); ++i) { + InValue[V[i]->getParent()].push_back(std::make_pair(VN, V[i])); + } + // Insert empty CHI node for this VN. This is used to factor out + // basic blocks where the ANTIC can potentially change. + for (auto IDFB : IDFBlocks) { // TODO: Prune out useless CHI insertions. + for (unsigned i = 0; i < V.size(); ++i) { + CHIArg C = {VN, nullptr, nullptr}; + // Ignore spurious PDFs. + if (DT->properlyDominates(IDFB, V[i]->getParent())) { + OutValue[IDFB].push_back(C); + DEBUG(dbgs() << "\nInsertion a CHI for BB: " << IDFB->getName() + << ", for Insn: " << *V[i]); + } + } + } } + + // Insert CHI args at each PDF to iterate on factored graph of + // control dependence. + insertCHI(InValue, OutValue); + // Using the CHI args inserted at each PDF, find fully anticipable values. + findHoistableCandidates(OutValue, K, HPL); } // Return true when all operands of Instr are available at insertion point @@ -714,7 +856,6 @@ private: Instruction *ClonedGep = Gep->clone(); for (unsigned i = 0, e = Gep->getNumOperands(); i != e; ++i) if (Instruction *Op = dyn_cast<Instruction>(Gep->getOperand(i))) { - // Check whether the operand is already available. if (DT->dominates(Op->getParent(), HoistPt)) continue; @@ -748,6 +889,88 @@ private: Repl->replaceUsesOfWith(Gep, ClonedGep); } + void updateAlignment(Instruction *I, Instruction *Repl) { + if (auto *ReplacementLoad = dyn_cast<LoadInst>(Repl)) { + ReplacementLoad->setAlignment( + std::min(ReplacementLoad->getAlignment(), + cast<LoadInst>(I)->getAlignment())); + ++NumLoadsRemoved; + } else if (auto *ReplacementStore = dyn_cast<StoreInst>(Repl)) { + ReplacementStore->setAlignment( + std::min(ReplacementStore->getAlignment(), + cast<StoreInst>(I)->getAlignment())); + ++NumStoresRemoved; + } else if (auto *ReplacementAlloca = dyn_cast<AllocaInst>(Repl)) { + ReplacementAlloca->setAlignment( + std::max(ReplacementAlloca->getAlignment(), + cast<AllocaInst>(I)->getAlignment())); + } else if (isa<CallInst>(Repl)) { + ++NumCallsRemoved; + } + } + + // Remove all the instructions in Candidates and replace their usage with Repl. + // Returns the number of instructions removed. + unsigned rauw(const SmallVecInsn &Candidates, Instruction *Repl, + MemoryUseOrDef *NewMemAcc) { + unsigned NR = 0; + for (Instruction *I : Candidates) { + if (I != Repl) { + ++NR; + updateAlignment(I, Repl); + if (NewMemAcc) { + // Update the uses of the old MSSA access with NewMemAcc. + MemoryAccess *OldMA = MSSA->getMemoryAccess(I); + OldMA->replaceAllUsesWith(NewMemAcc); + MSSAUpdater->removeMemoryAccess(OldMA); + } + + Repl->andIRFlags(I); + combineKnownMetadata(Repl, I); + I->replaceAllUsesWith(Repl); + // Also invalidate the Alias Analysis cache. 
+ MD->removeInstruction(I); + I->eraseFromParent(); + } + } + return NR; + } + + // Replace all Memory PHI usage with NewMemAcc. + void raMPHIuw(MemoryUseOrDef *NewMemAcc) { + SmallPtrSet<MemoryPhi *, 4> UsePhis; + for (User *U : NewMemAcc->users()) + if (MemoryPhi *Phi = dyn_cast<MemoryPhi>(U)) + UsePhis.insert(Phi); + + for (MemoryPhi *Phi : UsePhis) { + auto In = Phi->incoming_values(); + if (llvm::all_of(In, [&](Use &U) { return U == NewMemAcc; })) { + Phi->replaceAllUsesWith(NewMemAcc); + MSSAUpdater->removeMemoryAccess(Phi); + } + } + } + + // Remove all other instructions and replace them with Repl. + unsigned removeAndReplace(const SmallVecInsn &Candidates, Instruction *Repl, + BasicBlock *DestBB, bool MoveAccess) { + MemoryUseOrDef *NewMemAcc = MSSA->getMemoryAccess(Repl); + if (MoveAccess && NewMemAcc) { + // The definition of this ld/st will not change: ld/st hoisting is + // legal when the ld/st is not moved past its current definition. + MSSAUpdater->moveToPlace(NewMemAcc, DestBB, MemorySSA::End); + } + + // Replace all other instructions with Repl with memory access NewMemAcc. + unsigned NR = rauw(Candidates, Repl, NewMemAcc); + + // Remove MemorySSA phi nodes with the same arguments. + if (NewMemAcc) + raMPHIuw(NewMemAcc); + return NR; + } + // In the case Repl is a load or a store, we make all their GEPs // available: GEPs are not hoisted by default to avoid the address // computations to be hoisted without the associated load or store. @@ -789,11 +1012,11 @@ private: for (const HoistingPointInfo &HP : HPL) { // Find out whether we already have one of the instructions in HoistPt, // in which case we do not have to move it. - BasicBlock *HoistPt = HP.first; + BasicBlock *DestBB = HP.first; const SmallVecInsn &InstructionsToHoist = HP.second; Instruction *Repl = nullptr; for (Instruction *I : InstructionsToHoist) - if (I->getParent() == HoistPt) + if (I->getParent() == DestBB) // If there are two instructions in HoistPt to be hoisted in place: // update Repl to be the first one, such that we can rename the uses // of the second based on the first. @@ -805,7 +1028,7 @@ private: bool MoveAccess = true; if (Repl) { // Repl is already in HoistPt: it remains in place. - assert(allOperandsAvailable(Repl, HoistPt) && + assert(allOperandsAvailable(Repl, DestBB) && "instruction depends on operands that are not available"); MoveAccess = false; } else { @@ -816,40 +1039,26 @@ private: // We can move Repl in HoistPt only when all operands are available. // The order in which hoistings are done may influence the availability // of operands. - if (!allOperandsAvailable(Repl, HoistPt)) { - + if (!allOperandsAvailable(Repl, DestBB)) { // When HoistingGeps there is nothing more we can do to make the // operands available: just continue. if (HoistingGeps) continue; // When not HoistingGeps we need to copy the GEPs. - if (!makeGepOperandsAvailable(Repl, HoistPt, InstructionsToHoist)) + if (!makeGepOperandsAvailable(Repl, DestBB, InstructionsToHoist)) continue; } // Move the instruction at the end of HoistPt. - Instruction *Last = HoistPt->getTerminator(); + Instruction *Last = DestBB->getTerminator(); MD->removeInstruction(Repl); Repl->moveBefore(Last); DFSNumber[Repl] = DFSNumber[Last]++; } - MemoryAccess *NewMemAcc = MSSA->getMemoryAccess(Repl); - - if (MoveAccess) { - if (MemoryUseOrDef *OldMemAcc = - dyn_cast_or_null<MemoryUseOrDef>(NewMemAcc)) { - // The definition of this ld/st will not change: ld/st hoisting is - // legal when the ld/st is not moved past its current definition. 
- MemoryAccess *Def = OldMemAcc->getDefiningAccess(); - NewMemAcc = - MSSAUpdater->createMemoryAccessInBB(Repl, Def, HoistPt, MemorySSA::End); - OldMemAcc->replaceAllUsesWith(NewMemAcc); - MSSAUpdater->removeMemoryAccess(OldMemAcc); - } - } + NR += removeAndReplace(InstructionsToHoist, Repl, DestBB, MoveAccess); if (isa<LoadInst>(Repl)) ++NL; @@ -859,59 +1068,6 @@ private: ++NC; else // Scalar ++NI; - - // Remove and rename all other instructions. - for (Instruction *I : InstructionsToHoist) - if (I != Repl) { - ++NR; - if (auto *ReplacementLoad = dyn_cast<LoadInst>(Repl)) { - ReplacementLoad->setAlignment( - std::min(ReplacementLoad->getAlignment(), - cast<LoadInst>(I)->getAlignment())); - ++NumLoadsRemoved; - } else if (auto *ReplacementStore = dyn_cast<StoreInst>(Repl)) { - ReplacementStore->setAlignment( - std::min(ReplacementStore->getAlignment(), - cast<StoreInst>(I)->getAlignment())); - ++NumStoresRemoved; - } else if (auto *ReplacementAlloca = dyn_cast<AllocaInst>(Repl)) { - ReplacementAlloca->setAlignment( - std::max(ReplacementAlloca->getAlignment(), - cast<AllocaInst>(I)->getAlignment())); - } else if (isa<CallInst>(Repl)) { - ++NumCallsRemoved; - } - - if (NewMemAcc) { - // Update the uses of the old MSSA access with NewMemAcc. - MemoryAccess *OldMA = MSSA->getMemoryAccess(I); - OldMA->replaceAllUsesWith(NewMemAcc); - MSSAUpdater->removeMemoryAccess(OldMA); - } - - Repl->andIRFlags(I); - combineKnownMetadata(Repl, I); - I->replaceAllUsesWith(Repl); - // Also invalidate the Alias Analysis cache. - MD->removeInstruction(I); - I->eraseFromParent(); - } - - // Remove MemorySSA phi nodes with the same arguments. - if (NewMemAcc) { - SmallPtrSet<MemoryPhi *, 4> UsePhis; - for (User *U : NewMemAcc->users()) - if (MemoryPhi *Phi = dyn_cast<MemoryPhi>(U)) - UsePhis.insert(Phi); - - for (auto *Phi : UsePhis) { - auto In = Phi->incoming_values(); - if (all_of(In, [&](Use &U) { return U == NewMemAcc; })) { - Phi->replaceAllUsesWith(NewMemAcc); - MSSAUpdater->removeMemoryAccess(Phi); - } - } - } } NumHoisted += NL + NS + NC + NI; @@ -935,8 +1091,8 @@ private: // If I1 cannot guarantee progress, subsequent instructions // in BB cannot be hoisted anyways. if (!isGuaranteedToTransferExecutionToSuccessor(&I1)) { - HoistBarrier.insert(BB); - break; + HoistBarrier.insert(BB); + break; } // Only hoist the first instructions in BB up to MaxDepthInBB. Hoisting // deeper may increase the register pressure and compilation time. 
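
The loop shown above records hoist barriers: as soon as an instruction in a block is not guaranteed to transfer execution to its successor, the block is marked as a barrier and the scan of that block stops, and the per-block walk is additionally capped (MaxDepthInBB) to limit register pressure and compile time. A small standalone sketch of that scan on a toy instruction model (ToyInst and collectCandidates are illustrative stand-ins, not the pass's actual code):

#include <cstddef>
#include <vector>

struct ToyInst {
  bool TransfersExecution; // false models a guard or call that may not return
  bool Hoistable;          // e.g. a load, store, call or plain scalar op
};

// Scan one block in program order: stop at the first instruction that may not
// transfer execution to its successor (the block becomes a hoist barrier), and
// never consider more than MaxDepthInBB instructions.
std::vector<std::size_t> collectCandidates(const std::vector<ToyInst> &Block,
                                           std::size_t MaxDepthInBB,
                                           bool &IsBarrier) {
  std::vector<std::size_t> Candidates;
  IsBarrier = false;
  for (std::size_t Depth = 0; Depth < Block.size(); ++Depth) {
    if (!Block[Depth].TransfersExecution) {
      IsBarrier = true; // nothing below this point may be hoisted across it
      break;
    }
    if (Depth >= MaxDepthInBB) // bound work per block
      break;
    if (Block[Depth].Hoistable)
      Candidates.push_back(Depth);
  }
  return Candidates;
}

int main() {
  std::vector<ToyInst> BB = {{true, true}, {false, false}, {true, true}};
  bool Barrier;
  auto C = collectCandidates(BB, 100, Barrier);
  return (Barrier && C.size() == 1) ? 0 : 1; // first inst collected, block is a barrier
}
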
@@ -954,7 +1110,8 @@ private: else if (auto *Call = dyn_cast<CallInst>(&I1)) { if (auto *Intr = dyn_cast<IntrinsicInst>(Call)) { if (isa<DbgInfoIntrinsic>(Intr) || - Intr->getIntrinsicID() == Intrinsic::assume) + Intr->getIntrinsicID() == Intrinsic::assume || + Intr->getIntrinsicID() == Intrinsic::sideeffect) continue; } if (Call->mayHaveSideEffects()) @@ -996,16 +1153,18 @@ public: if (skipFunction(F)) return false; auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); auto &MD = getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); - GVNHoist G(&DT, &AA, &MD, &MSSA); + GVNHoist G(&DT, &PDT, &AA, &MD, &MSSA); return G.run(F); } void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); AU.addRequired<MemoryDependenceWrapperPass>(); AU.addRequired<MemorySSAWrapperPass>(); @@ -1014,14 +1173,16 @@ public: AU.addPreserved<GlobalsAAWrapperPass>(); } }; -} // namespace + +} // end namespace llvm PreservedAnalyses GVNHoistPass::run(Function &F, FunctionAnalysisManager &AM) { DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + PostDominatorTree &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); AliasAnalysis &AA = AM.getResult<AAManager>(F); MemoryDependenceResults &MD = AM.getResult<MemoryDependenceAnalysis>(F); MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA(); - GVNHoist G(&DT, &AA, &MD, &MSSA); + GVNHoist G(&DT, &PDT, &AA, &MD, &MSSA); if (!G.run(F)) return PreservedAnalyses::all(); @@ -1033,6 +1194,7 @@ PreservedAnalyses GVNHoistPass::run(Function &F, FunctionAnalysisManager &AM) { } char GVNHoistLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(GVNHoistLegacyPass, "gvn-hoist", "Early GVN Hoisting of Expressions", false, false) INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) diff --git a/lib/Transforms/Scalar/GVNSink.cpp b/lib/Transforms/Scalar/GVNSink.cpp index 5fd2dfc118b4..814a62cd7d65 100644 --- a/lib/Transforms/Scalar/GVNSink.cpp +++ b/lib/Transforms/Scalar/GVNSink.cpp @@ -1,4 +1,4 @@ -//===- GVNSink.cpp - sink expressions into successors -------------------===// +//===- GVNSink.cpp - sink expressions into successors ---------------------===// // // The LLVM Compiler Infrastructure // @@ -31,33 +31,54 @@ /// replace %a1 with %c1, will it contribute in an equivalent way to all /// successive instructions?". The PostValueTable class in GVN provides this /// mapping. 
-/// +// //===----------------------------------------------------------------------===// +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/PostOrderIterator.h" -#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringExtras.h" #include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/MemorySSA.h" -#include "llvm/Analysis/PostDominators.h" -#include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/Verifier.h" -#include "llvm/Support/MathExtras.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/ArrayRecycler.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVN.h" #include "llvm/Transforms/Scalar/GVNExpression.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" -#include <unordered_set> +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <utility> + using namespace llvm; #define DEBUG_TYPE "gvn-sink" @@ -72,8 +93,8 @@ LLVM_DUMP_METHOD void Expression::dump() const { dbgs() << "\n"; } -} -} +} // end namespace GVNExpression +} // end namespace llvm namespace { @@ -97,7 +118,7 @@ static bool isMemoryInst(const Instruction *I) { /// list returned by operator*. class LockstepReverseIterator { ArrayRef<BasicBlock *> Blocks; - SmallPtrSet<BasicBlock *, 4> ActiveBlocks; + SmallSetVector<BasicBlock *, 4> ActiveBlocks; SmallVector<Instruction *, 4> Insts; bool Fail; @@ -115,7 +136,7 @@ public: for (BasicBlock *BB : Blocks) { if (BB->size() <= 1) { // Block wasn't big enough - only contained a terminator. - ActiveBlocks.erase(BB); + ActiveBlocks.remove(BB); continue; } Insts.push_back(BB->getTerminator()->getPrevNode()); @@ -126,13 +147,20 @@ public: bool isValid() const { return !Fail; } ArrayRef<Instruction *> operator*() const { return Insts; } - SmallPtrSet<BasicBlock *, 4> &getActiveBlocks() { return ActiveBlocks; } - void restrictToBlocks(SmallPtrSetImpl<BasicBlock *> &Blocks) { + // Note: This needs to return a SmallSetVector as the elements of + // ActiveBlocks will be later copied to Blocks using std::copy. The + // resultant order of elements in Blocks needs to be deterministic. + // Using SmallPtrSet instead causes non-deterministic order while + // copying. And we cannot simply sort Blocks as they need to match the + // corresponding Values. 
+ SmallSetVector<BasicBlock *, 4> &getActiveBlocks() { return ActiveBlocks; } + + void restrictToBlocks(SmallSetVector<BasicBlock *, 4> &Blocks) { for (auto II = Insts.begin(); II != Insts.end();) { if (std::find(Blocks.begin(), Blocks.end(), (*II)->getParent()) == Blocks.end()) { - ActiveBlocks.erase((*II)->getParent()); + ActiveBlocks.remove((*II)->getParent()); II = Insts.erase(II); } else { ++II; @@ -146,7 +174,7 @@ public: SmallVector<Instruction *, 4> NewInsts; for (auto *Inst : Insts) { if (Inst == &Inst->getParent()->front()) - ActiveBlocks.erase(Inst->getParent()); + ActiveBlocks.remove(Inst->getParent()); else NewInsts.push_back(Inst->getPrevNode()); } @@ -180,14 +208,14 @@ struct SinkingInstructionCandidate { NumExtraPHIs) // PHIs are expensive, so make sure they're worth it. - SplitEdgeCost; } + bool operator>(const SinkingInstructionCandidate &Other) const { return Cost > Other.Cost; } }; #ifndef NDEBUG -llvm::raw_ostream &operator<<(llvm::raw_ostream &OS, - const SinkingInstructionCandidate &C) { +raw_ostream &operator<<(raw_ostream &OS, const SinkingInstructionCandidate &C) { OS << "<Candidate Cost=" << C.Cost << " #Blocks=" << C.NumBlocks << " #Insts=" << C.NumInstructions << " #PHIs=" << C.NumPHIs << ">"; return OS; @@ -204,17 +232,20 @@ class ModelledPHI { SmallVector<BasicBlock *, 4> Blocks; public: - ModelledPHI() {} + ModelledPHI() = default; + ModelledPHI(const PHINode *PN) { + // BasicBlock comes first so we sort by basic block pointer order, then by value pointer order. + SmallVector<std::pair<BasicBlock *, Value *>, 4> Ops; for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) - Blocks.push_back(PN->getIncomingBlock(I)); - std::sort(Blocks.begin(), Blocks.end()); - - // This assumes the PHI is already well-formed and there aren't conflicting - // incoming values for the same block. - for (auto *B : Blocks) - Values.push_back(PN->getIncomingValueForBlock(B)); + Ops.push_back({PN->getIncomingBlock(I), PN->getIncomingValue(I)}); + std::sort(Ops.begin(), Ops.end()); + for (auto &P : Ops) { + Blocks.push_back(P.first); + Values.push_back(P.second); + } } + /// Create a dummy ModelledPHI that will compare unequal to any other ModelledPHI /// without the same ID. /// \note This is specifically for DenseMapInfo - do not use this! @@ -241,7 +272,7 @@ public: /// Restrict the PHI's contents down to only \c NewBlocks. /// \c NewBlocks must be a subset of \c this->Blocks. 
- void restrictToBlocks(const SmallPtrSetImpl<BasicBlock *> &NewBlocks) { + void restrictToBlocks(const SmallSetVector<BasicBlock *, 4> &NewBlocks) { auto BI = Blocks.begin(); auto VI = Values.begin(); while (BI != Blocks.end()) { @@ -261,19 +292,23 @@ public: ArrayRef<Value *> getValues() const { return Values; } bool areAllIncomingValuesSame() const { - return all_of(Values, [&](Value *V) { return V == Values[0]; }); + return llvm::all_of(Values, [&](Value *V) { return V == Values[0]; }); } + bool areAllIncomingValuesSameType() const { - return all_of( + return llvm::all_of( Values, [&](Value *V) { return V->getType() == Values[0]->getType(); }); } + bool areAnyIncomingValuesConstant() const { - return any_of(Values, [&](Value *V) { return isa<Constant>(V); }); + return llvm::any_of(Values, [&](Value *V) { return isa<Constant>(V); }); } + // Hash functor unsigned hash() const { return (unsigned)hash_combine_range(Values.begin(), Values.end()); } + bool operator==(const ModelledPHI &Other) const { return Values == Other.Values && Blocks == Other.Blocks; } @@ -284,17 +319,20 @@ template <typename ModelledPHI> struct DenseMapInfo { static ModelledPHI Dummy = ModelledPHI::createDummy(0); return Dummy; } + static inline ModelledPHI &getTombstoneKey() { static ModelledPHI Dummy = ModelledPHI::createDummy(1); return Dummy; } + static unsigned getHashValue(const ModelledPHI &V) { return V.hash(); } + static bool isEqual(const ModelledPHI &LHS, const ModelledPHI &RHS) { return LHS == RHS; } }; -typedef DenseSet<ModelledPHI, DenseMapInfo<ModelledPHI>> ModelledPHISet; +using ModelledPHISet = DenseSet<ModelledPHI, DenseMapInfo<ModelledPHI>>; //===----------------------------------------------------------------------===// // ValueTable @@ -325,10 +363,11 @@ public: op_push_back(U.getUser()); std::sort(op_begin(), op_end()); } + void setMemoryUseOrder(unsigned MUO) { MemoryUseOrder = MUO; } void setVolatile(bool V) { Volatile = V; } - virtual hash_code getHashValue() const { + hash_code getHashValue() const override { return hash_combine(GVNExpression::BasicExpression::getHashValue(), MemoryUseOrder, Volatile); } @@ -348,7 +387,7 @@ class ValueTable { DenseMap<size_t, uint32_t> HashNumbering; BumpPtrAllocator Allocator; ArrayRecycler<Value *> Recycler; - uint32_t nextValueNumber; + uint32_t nextValueNumber = 1; /// Create an expression for I based on its opcode and its uses. If I /// touches or reads memory, the expression is also based upon its memory @@ -378,6 +417,8 @@ class ValueTable { } public: + ValueTable() = default; + /// Returns the value number for the specified value, assigning /// it a new number if it did not have one before. uint32_t lookupOrAdd(Value *V) { @@ -483,8 +524,6 @@ public: nextValueNumber = 1; } - ValueTable() : nextValueNumber(1) {} - /// \c Inst uses or touches memory. Return an ID describing the memory state /// at \c Inst such that if getMemoryUseOrder(I1) == getMemoryUseOrder(I2), /// the exact same memory operations happen after I1 and I2. 
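
ModelledPHI above canonicalizes a candidate PHI as its (incoming block, incoming value) pairs sorted by block, so two PHIs that would merge the same values from the same predecessors compare equal and are stored only once in ModelledPHISet. A small standalone sketch of that idea with a toy model; ToyModelledPHI is illustrative only, and it uses an ordered std::set with operator< for brevity where the LLVM class hashes into a DenseSet:

#include <algorithm>
#include <set>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

struct ToyModelledPHI {
  std::vector<std::string> Blocks; // predecessor names, sorted
  std::vector<std::string> Values; // incoming value per predecessor, same order

  explicit ToyModelledPHI(std::vector<std::pair<std::string, std::string>> Ops) {
    std::sort(Ops.begin(), Ops.end()); // block first, then value
    for (auto &P : Ops) {
      Blocks.push_back(P.first);
      Values.push_back(P.second);
    }
  }

  bool operator<(const ToyModelledPHI &O) const {
    return std::tie(Blocks, Values) < std::tie(O.Blocks, O.Values);
  }
};

int main() {
  // The same merge written with its operands in a different order models equal.
  ToyModelledPHI A({{"bb1", "%a"}, {"bb2", "%b"}});
  ToyModelledPHI B({{"bb2", "%b"}, {"bb1", "%a"}});
  std::set<ToyModelledPHI> Seen;
  Seen.insert(A);
  bool AlreadyModelled = !Seen.insert(B).second; // true: B duplicates A
  return AlreadyModelled ? 0 : 1;
}
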
@@ -519,7 +558,8 @@ public: class GVNSink { public: - GVNSink() : VN() {} + GVNSink() = default; + bool run(Function &F) { DEBUG(dbgs() << "GVNSink: running on function @" << F.getName() << "\n"); @@ -576,8 +616,9 @@ private: void foldPointlessPHINodes(BasicBlock *BB) { auto I = BB->begin(); while (PHINode *PN = dyn_cast<PHINode>(I++)) { - if (!all_of(PN->incoming_values(), - [&](const Value *V) { return V == PN->getIncomingValue(0); })) + if (!llvm::all_of(PN->incoming_values(), [&](const Value *V) { + return V == PN->getIncomingValue(0); + })) continue; if (PN->getIncomingValue(0) != PN) PN->replaceAllUsesWith(PN->getIncomingValue(0)); @@ -624,7 +665,7 @@ Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking( SmallVector<Instruction *, 4> NewInsts; for (auto *I : Insts) { if (VN.lookup(I) != VNumToSink) - ActivePreds.erase(I->getParent()); + ActivePreds.remove(I->getParent()); else NewInsts.push_back(I); } @@ -794,7 +835,7 @@ void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, SmallVector<Value *, 4> NewOperands; for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) { - bool NeedPHI = any_of(Insts, [&I0, O](const Instruction *I) { + bool NeedPHI = llvm::any_of(Insts, [&I0, O](const Instruction *I) { return I->getOperand(O) != I0->getOperand(O); }); if (!NeedPHI) { @@ -860,7 +901,8 @@ public: AU.addPreserved<GlobalsAAWrapperPass>(); } }; -} // namespace + +} // end anonymous namespace PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) { GVNSink G; @@ -873,6 +915,7 @@ PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) { } char GVNSinkLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(GVNSinkLegacyPass, "gvn-sink", "Early GVN sinking of Expressions", false, false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) diff --git a/lib/Transforms/Scalar/GuardWidening.cpp b/lib/Transforms/Scalar/GuardWidening.cpp index fb7c6e15758d..c4aeccb85ca7 100644 --- a/lib/Transforms/Scalar/GuardWidening.cpp +++ b/lib/Transforms/Scalar/GuardWidening.cpp @@ -664,6 +664,7 @@ PreservedAnalyses GuardWideningPass::run(Function &F, return PA; } +#ifndef NDEBUG StringRef GuardWideningImpl::scoreTypeToString(WideningScore WS) { switch (WS) { case WS_IllegalOrNegative: @@ -678,6 +679,7 @@ StringRef GuardWideningImpl::scoreTypeToString(WideningScore WS) { llvm_unreachable("Fully covered switch above!"); } +#endif char GuardWideningLegacyPass::ID = 0; diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp index 10782963177c..74d6014d3e3d 100644 --- a/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -25,27 +25,54 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/IndVarSimplify.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" #include 
"llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" @@ -53,6 +80,10 @@ #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SimplifyIndVar.h" +#include <cassert> +#include <cstdint> +#include <utility> + using namespace llvm; #define DEBUG_TYPE "indvars" @@ -91,6 +122,7 @@ DisableLFTR("disable-lftr", cl::Hidden, cl::init(false), cl::desc("Disable Linear Function Test Replace optimization")); namespace { + struct RewritePhi; class IndVarSimplify { @@ -131,7 +163,8 @@ public: bool run(Loop *L); }; -} + +} // end anonymous namespace /// Return true if the SCEV expansion generated by the rewriter can replace the /// original value. SCEV guarantees that it produces the same value, but the way @@ -251,7 +284,6 @@ static bool ConvertToSInt(const APFloat &APF, int64_t &IntVal) { /// is converted into /// for(int i = 0; i < 10000; ++i) /// bar((double)i); -/// void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0)); unsigned BackEdge = IncomingEdge^1; @@ -305,7 +337,6 @@ void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { L->contains(TheBr->getSuccessor(1)))) return; - // If it isn't a comparison with an integer-as-fp (the exit value), we can't // transform it. ConstantFP *ExitValueVal = dyn_cast<ConstantFP>(Compare->getOperand(1)); @@ -373,7 +404,6 @@ void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) { // transform the IV. if (Leftover != 0 && int32_t(ExitValue+IncValue) < ExitValue) return; - } else { // If we have a negative stride, we require the init to be greater than the // exit value. @@ -452,7 +482,6 @@ void IndVarSimplify::rewriteNonIntegerIVs(Loop *L) { // First step. Check to see if there are any floating-point recurrences. // If there are, change them into integer recurrences, permitting analysis by // the SCEV routines. - // BasicBlock *Header = L->getHeader(); SmallVector<WeakTrackingVH, 8> PHIs; @@ -472,18 +501,26 @@ void IndVarSimplify::rewriteNonIntegerIVs(Loop *L) { } namespace { + // Collect information about PHI nodes which can be transformed in // rewriteLoopExitValues. struct RewritePhi { PHINode *PN; - unsigned Ith; // Ith incoming value. - Value *Val; // Exit value after expansion. - bool HighCost; // High Cost when expansion. 
+ + // Ith incoming value. + unsigned Ith; + + // Exit value after expansion. + Value *Val; + + // High Cost when expansion. + bool HighCost; RewritePhi(PHINode *P, unsigned I, Value *V, bool H) : PN(P), Ith(I), Val(V), HighCost(H) {} }; -} + +} // end anonymous namespace Value *IndVarSimplify::expandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S, Loop *L, Instruction *InsertPt, @@ -747,7 +784,6 @@ void IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) { /// aggressively. bool IndVarSimplify::canLoopBeDeleted( Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet) { - BasicBlock *Preheader = L->getLoopPreheader(); // If there is no preheader, the loop will not be deleted. if (!Preheader) @@ -790,7 +826,9 @@ bool IndVarSimplify::canLoopBeDeleted( } for (auto *BB : L->blocks()) - if (any_of(*BB, [](Instruction &I) { return I.mayHaveSideEffects(); })) + if (llvm::any_of(*BB, [](Instruction &I) { + return I.mayHaveSideEffects(); + })) return false; return true; @@ -801,15 +839,21 @@ bool IndVarSimplify::canLoopBeDeleted( //===----------------------------------------------------------------------===// namespace { + // Collect information about induction variables that are used by sign/zero // extend operations. This information is recorded by CollectExtend and provides // the input to WidenIV. struct WideIVInfo { PHINode *NarrowIV = nullptr; - Type *WidestNativeType = nullptr; // Widest integer type created [sz]ext - bool IsSigned = false; // Was a sext user seen before a zext? + + // Widest integer type created [sz]ext + Type *WidestNativeType = nullptr; + + // Was a sext user seen before a zext? + bool IsSigned = false; }; -} + +} // end anonymous namespace /// Update information about the induction variable that is extended by this /// sign or zero extend operation. This is used to determine the final width of @@ -885,7 +929,6 @@ struct NarrowIVDefUse { /// creating any new induction variables. To do this, it creates a new phi of /// the wider type and redirects all users, either removing extends or inserting /// truncs whenever we stop propagating the type. -/// class WidenIV { // Parameters PHINode *OrigPhi; @@ -902,22 +945,24 @@ class WidenIV { bool HasGuards; // Result - PHINode *WidePhi; - Instruction *WideInc; - const SCEV *WideIncExpr; + PHINode *WidePhi = nullptr; + Instruction *WideInc = nullptr; + const SCEV *WideIncExpr = nullptr; SmallVectorImpl<WeakTrackingVH> &DeadInsts; SmallPtrSet<Instruction *,16> Widened; SmallVector<NarrowIVDefUse, 8> NarrowIVUsers; enum ExtendKind { ZeroExtended, SignExtended, Unknown }; + // A map tracking the kind of extension used to widen each narrow IV // and narrow IV user. // Key: pointer to a narrow IV or IV user. // Value: the kind of extension used to widen this Instruction. DenseMap<AssertingVH<Instruction>, ExtendKind> ExtendKindMap; - typedef std::pair<AssertingVH<Value>, AssertingVH<Instruction>> DefUserPair; + using DefUserPair = std::pair<AssertingVH<Value>, AssertingVH<Instruction>>; + // A map with control-dependent ranges for post increment IV uses. The key is // a pair of IV def and a use of this def denoting the context. 
The value is // a ConstantRange representing possible values of the def at the given @@ -935,6 +980,7 @@ class WidenIV { void calculatePostIncRanges(PHINode *OrigPhi); void calculatePostIncRange(Instruction *NarrowDef, Instruction *NarrowUser); + void updatePostIncRangeInfo(Value *Def, Instruction *UseI, ConstantRange R) { DefUserPair Key(Def, UseI); auto It = PostIncRangeInfos.find(Key); @@ -950,8 +996,7 @@ public: bool HasGuards) : OrigPhi(WI.NarrowIV), WideType(WI.WidestNativeType), LI(LInfo), L(LI->getLoopFor(OrigPhi->getParent())), SE(SEv), DT(DTree), - HasGuards(HasGuards), WidePhi(nullptr), WideInc(nullptr), - WideIncExpr(nullptr), DeadInsts(DI) { + HasGuards(HasGuards), DeadInsts(DI) { assert(L->getHeader() == OrigPhi->getParent() && "Phi must be an IV"); ExtendKindMap[OrigPhi] = WI.IsSigned ? SignExtended : ZeroExtended; } @@ -969,7 +1014,7 @@ protected: ExtendKind getExtendKind(Instruction *I); - typedef std::pair<const SCEVAddRecExpr *, ExtendKind> WidenedRecTy; + using WidenedRecTy = std::pair<const SCEVAddRecExpr *, ExtendKind>; WidenedRecTy getWideRecurrence(NarrowIVDefUse DU); @@ -984,7 +1029,8 @@ protected: void pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef); }; -} // anonymous namespace + +} // end anonymous namespace /// Perform a quick domtree based check for loop invariance assuming that V is /// used within the loop. LoopInfo::isLoopInvariant() seems gratuitous for this @@ -1182,7 +1228,6 @@ const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS, /// operands is an AddRec for this loop, return the AddRec and the kind of /// extension used. WidenIV::WidenedRecTy WidenIV::getExtendedOperandRecurrence(NarrowIVDefUse DU) { - // Handle the common case of add<nsw/nuw> const unsigned OpCode = DU.NarrowUse->getOpcode(); // Only Add/Sub/Mul instructions supported yet. @@ -1310,7 +1355,7 @@ bool WidenIV::widenLoopCompare(NarrowIVDefUse DU) { Value *Op = Cmp->getOperand(Cmp->getOperand(0) == DU.NarrowDef ? 1 : 0); unsigned CastWidth = SE->getTypeSizeInBits(Op->getType()); unsigned IVWidth = SE->getTypeSizeInBits(WideType); - assert (CastWidth <= IVWidth && "Unexpected width while widening compare."); + assert(CastWidth <= IVWidth && "Unexpected width while widening compare."); // Widen the compare instruction. IRBuilder<> Builder( @@ -1461,7 +1506,6 @@ Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { } /// Add eligible users of NarrowDef to NarrowIVUsers. -/// void WidenIV::pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef) { const SCEV *NarrowSCEV = SE->getSCEV(NarrowDef); bool NonNegativeDef = @@ -1494,7 +1538,6 @@ void WidenIV::pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef) { /// /// It would be simpler to delete uses as they are processed, but we must avoid /// invalidating SCEV expressions. -/// PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { // Is this phi an induction variable? const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(OrigPhi)); @@ -1581,6 +1624,15 @@ PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { if (DU.NarrowDef->use_empty()) DeadInsts.emplace_back(DU.NarrowDef); } + + // Attach any debug information to the new PHI. Since OrigPhi and WidePHI + // evaluate the same recurrence, we can just copy the debug info over. 
+ SmallVector<DbgValueInst *, 1> DbgValues; + llvm::findDbgValues(DbgValues, OrigPhi); + auto *MDPhi = MetadataAsValue::get(WidePhi->getContext(), + ValueAsMetadata::get(WidePhi)); + for (auto &DbgValue : DbgValues) + DbgValue->setOperand(0, MDPhi); return WidePhi; } @@ -1696,12 +1748,12 @@ void WidenIV::calculatePostIncRanges(PHINode *OrigPhi) { // Live IV Reduction - Minimize IVs live across the loop. //===----------------------------------------------------------------------===// - //===----------------------------------------------------------------------===// // Simplification of IV users based on SCEV evaluation. //===----------------------------------------------------------------------===// namespace { + class IndVarSimplifyVisitor : public IVVisitor { ScalarEvolution *SE; const TargetTransformInfo *TTI; @@ -1721,14 +1773,14 @@ public: // Implement the interface used by simplifyUsersOfIV. void visitCast(CastInst *Cast) override { visitIVCast(Cast, WI, SE, TTI); } }; -} + +} // end anonymous namespace /// Iteratively perform simplification on a worklist of IV users. Each /// successive simplification may push more users which may themselves be /// candidates for simplification. /// /// Sign/Zero extend elimination is interleaved with IV simplification. -/// void IndVarSimplify::simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI) { @@ -1759,7 +1811,8 @@ void IndVarSimplify::simplifyAndExtend(Loop *L, // Information about sign/zero extensions of CurrIV. IndVarSimplifyVisitor Visitor(CurrIV, SE, TTI, DT); - Changed |= simplifyUsersOfIV(CurrIV, SE, DT, LI, DeadInsts, &Visitor); + Changed |= + simplifyUsersOfIV(CurrIV, SE, DT, LI, DeadInsts, Rewriter, &Visitor); if (Visitor.WI.WidestNativeType) { WideIVs.push_back(Visitor.WI); @@ -2501,8 +2554,10 @@ PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, } namespace { + struct IndVarSimplifyLegacyPass : public LoopPass { static char ID; // Pass identification, replacement for typeid + IndVarSimplifyLegacyPass() : LoopPass(ID) { initializeIndVarSimplifyLegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -2529,9 +2584,11 @@ struct IndVarSimplifyLegacyPass : public LoopPass { getLoopAnalysisUsage(AU); } }; -} + +} // end anonymous namespace char IndVarSimplifyLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(IndVarSimplifyLegacyPass, "indvars", "Induction Variable Simplification", false, false) INITIALIZE_PASS_DEPENDENCY(LoopPass) diff --git a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp index 99b4458ea0fa..5c4d55bfbb2b 100644 --- a/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp +++ b/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -1,4 +1,4 @@ -//===-- InductiveRangeCheckElimination.cpp - ------------------------------===// +//===- InductiveRangeCheckElimination.cpp - -------------------------------===// // // The LLVM Compiler Infrastructure // @@ -6,6 +6,7 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// // The InductiveRangeCheckElimination pass splits a loop's iteration space into // three disjoint ranges. It does that in a way such that the loop running in // the middle loop provably does not need range checks. 
As an example, it will @@ -39,30 +40,61 @@ // throw_out_of_bounds(); // } // } +// //===----------------------------------------------------------------------===// +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include <algorithm> +#include <cassert> +#include <iterator> +#include <limits> +#include <utility> +#include <vector> using namespace llvm; +using namespace llvm::PatternMatch; static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden, cl::init(64)); @@ -79,6 +111,9 @@ static cl::opt<int> MaxExitProbReciprocal("irce-max-exit-prob-reciprocal", static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks", cl::Hidden, cl::init(false)); +static cl::opt<bool> AllowUnsignedLatchCondition("irce-allow-unsigned-latch", + cl::Hidden, cl::init(true)); + static const char *ClonedLoopTag = "irce.loop.clone"; #define DEBUG_TYPE "irce" @@ -114,15 +149,16 @@ class InductiveRangeCheck { static StringRef rangeCheckKindToStr(RangeCheckKind); - const SCEV *Offset = nullptr; - const SCEV *Scale = nullptr; - Value *Length = nullptr; + const SCEV *Begin = nullptr; + const SCEV *Step = nullptr; + const SCEV *End = nullptr; Use *CheckUse = nullptr; RangeCheckKind Kind = RANGE_CHECK_UNKNOWN; + bool IsSigned = true; static RangeCheckKind parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE, Value *&Index, - Value *&Length); + Value *&Length, bool &IsSigned); static void extractRangeChecksFromCond(Loop *L, ScalarEvolution &SE, Use &ConditionUse, @@ -130,20 +166,21 @@ class InductiveRangeCheck { SmallPtrSetImpl<Value *> &Visited); public: - const SCEV *getOffset() const { return Offset; } - const SCEV *getScale() const { return Scale; } - Value *getLength() const { return Length; } + const SCEV *getBegin() const { return Begin; } + const SCEV *getStep() const { return Step; } + const SCEV *getEnd() const { return End; } + bool isSigned() const { return IsSigned; } void print(raw_ostream &OS) const { OS << "InductiveRangeCheck:\n"; OS << " Kind: " << 
rangeCheckKindToStr(Kind) << "\n"; - OS << " Offset: "; - Offset->print(OS); - OS << " Scale: "; - Scale->print(OS); - OS << " Length: "; - if (Length) - Length->print(OS); + OS << " Begin: "; + Begin->print(OS); + OS << " Step: "; + Step->print(OS); + OS << " End: "; + if (End) + End->print(OS); else OS << "(null)"; OS << "\n CheckUse: "; @@ -173,6 +210,14 @@ public: Type *getType() const { return Begin->getType(); } const SCEV *getBegin() const { return Begin; } const SCEV *getEnd() const { return End; } + bool isEmpty(ScalarEvolution &SE, bool IsSigned) const { + if (Begin == End) + return true; + if (IsSigned) + return SE.isKnownPredicate(ICmpInst::ICMP_SGE, Begin, End); + else + return SE.isKnownPredicate(ICmpInst::ICMP_UGE, Begin, End); + } }; /// This is the value the condition of the branch needs to evaluate to for the @@ -183,7 +228,8 @@ public: /// check is redundant and can be constant-folded away. The induction /// variable is not required to be the canonical {0,+,1} induction variable. Optional<Range> computeSafeIterationSpace(ScalarEvolution &SE, - const SCEVAddRecExpr *IndVar) const; + const SCEVAddRecExpr *IndVar, + bool IsLatchSigned) const; /// Parse out a set of inductive range checks from \p BI and append them to \p /// Checks. @@ -199,6 +245,7 @@ public: class InductiveRangeCheckElimination : public LoopPass { public: static char ID; + InductiveRangeCheckElimination() : LoopPass(ID) { initializeInductiveRangeCheckEliminationPass( *PassRegistry::getPassRegistry()); @@ -212,8 +259,9 @@ public: bool runOnLoop(Loop *L, LPPassManager &LPM) override; }; +} // end anonymous namespace + char InductiveRangeCheckElimination::ID = 0; -} INITIALIZE_PASS_BEGIN(InductiveRangeCheckElimination, "irce", "Inductive range check elimination", false, false) @@ -247,12 +295,10 @@ StringRef InductiveRangeCheck::rangeCheckKindToStr( /// range checked, and set `Length` to the upper limit `Index` is being range /// checked with if (and only if) the range check type is stronger or equal to /// RANGE_CHECK_UPPER. 
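To make the classification concrete: the compares this parser recognizes correspond roughly to the source-level bounds checks sketched below. This is an illustrative standalone snippet, not code from the patch; the function and variable names are made up.

    #include <cstddef>
    #include <stdexcept>

    // Shapes of compares parseRangeCheckICmp classifies (branch continues when true):
    //   i >= 0              -> RANGE_CHECK_LOWER (signed compare against 0 or -1)
    //   i < len              -> RANGE_CHECK_UPPER (len non-negative, loop-invariant)
    //   (size_t)i < len      -> RANGE_CHECK_BOTH  (one unsigned compare covers both bounds)
    int getElement(const int *A, std::size_t Len, long I) {
      if (static_cast<std::size_t>(I) >= Len)   // negative I wraps to a huge value and fails too
        throw std::out_of_range("index out of range");
      return A[I];
    }

The single unsigned compare in the sketch is the RANGE_CHECK_BOTH form: because a negative index becomes a very large unsigned value, one comparison subsumes both the lower and the upper check.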
-/// InductiveRangeCheck::RangeCheckKind InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, ScalarEvolution &SE, Value *&Index, - Value *&Length) { - + Value *&Length, bool &IsSigned) { auto IsNonNegativeAndNotLoopVarying = [&SE, L](Value *V) { const SCEV *S = SE.getSCEV(V); if (isa<SCEVCouldNotCompute>(S)) @@ -262,8 +308,6 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, SE.isKnownNonNegative(S); }; - using namespace llvm::PatternMatch; - ICmpInst::Predicate Pred = ICI->getPredicate(); Value *LHS = ICI->getOperand(0); Value *RHS = ICI->getOperand(1); @@ -276,6 +320,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, std::swap(LHS, RHS); LLVM_FALLTHROUGH; case ICmpInst::ICMP_SGE: + IsSigned = true; if (match(RHS, m_ConstantInt<0>())) { Index = LHS; return RANGE_CHECK_LOWER; @@ -286,6 +331,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, std::swap(LHS, RHS); LLVM_FALLTHROUGH; case ICmpInst::ICMP_SGT: + IsSigned = true; if (match(RHS, m_ConstantInt<-1>())) { Index = LHS; return RANGE_CHECK_LOWER; @@ -302,6 +348,7 @@ InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, std::swap(LHS, RHS); LLVM_FALLTHROUGH; case ICmpInst::ICMP_UGT: + IsSigned = false; if (IsNonNegativeAndNotLoopVarying(LHS)) { Index = RHS; Length = LHS; @@ -317,42 +364,16 @@ void InductiveRangeCheck::extractRangeChecksFromCond( Loop *L, ScalarEvolution &SE, Use &ConditionUse, SmallVectorImpl<InductiveRangeCheck> &Checks, SmallPtrSetImpl<Value *> &Visited) { - using namespace llvm::PatternMatch; - Value *Condition = ConditionUse.get(); if (!Visited.insert(Condition).second) return; + // TODO: Do the same for OR, XOR, NOT etc? if (match(Condition, m_And(m_Value(), m_Value()))) { - SmallVector<InductiveRangeCheck, 8> SubChecks; extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(0), - SubChecks, Visited); + Checks, Visited); extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(1), - SubChecks, Visited); - - if (SubChecks.size() == 2) { - // Handle a special case where we know how to merge two checks separately - // checking the upper and lower bounds into a full range check. - const auto &RChkA = SubChecks[0]; - const auto &RChkB = SubChecks[1]; - if ((RChkA.Length == RChkB.Length || !RChkA.Length || !RChkB.Length) && - RChkA.Offset == RChkB.Offset && RChkA.Scale == RChkB.Scale) { - - // If RChkA.Kind == RChkB.Kind then we just found two identical checks. - // But if one of them is a RANGE_CHECK_LOWER and the other is a - // RANGE_CHECK_UPPER (only possibility if they're different) then - // together they form a RANGE_CHECK_BOTH. - SubChecks[0].Kind = - (InductiveRangeCheck::RangeCheckKind)(RChkA.Kind | RChkB.Kind); - SubChecks[0].Length = RChkA.Length ? RChkA.Length : RChkB.Length; - SubChecks[0].CheckUse = &ConditionUse; - - // We updated one of the checks in place, now erase the other. 
- SubChecks.pop_back(); - } - } - - Checks.insert(Checks.end(), SubChecks.begin(), SubChecks.end()); + Checks, Visited); return; } @@ -361,7 +382,8 @@ void InductiveRangeCheck::extractRangeChecksFromCond( return; Value *Length = nullptr, *Index; - auto RCKind = parseRangeCheckICmp(L, ICI, SE, Index, Length); + bool IsSigned; + auto RCKind = parseRangeCheckICmp(L, ICI, SE, Index, Length, IsSigned); if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) return; @@ -373,18 +395,18 @@ void InductiveRangeCheck::extractRangeChecksFromCond( return; InductiveRangeCheck IRC; - IRC.Length = Length; - IRC.Offset = IndexAddRec->getStart(); - IRC.Scale = IndexAddRec->getStepRecurrence(SE); + IRC.End = Length ? SE.getSCEV(Length) : nullptr; + IRC.Begin = IndexAddRec->getStart(); + IRC.Step = IndexAddRec->getStepRecurrence(SE); IRC.CheckUse = &ConditionUse; IRC.Kind = RCKind; + IRC.IsSigned = IsSigned; Checks.push_back(IRC); } void InductiveRangeCheck::extractRangeChecksFromBranch( BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo &BPI, SmallVectorImpl<InductiveRangeCheck> &Checks) { - if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch()) return; @@ -435,16 +457,16 @@ namespace { // kinds of loops we can deal with -- ones that have a single latch that is also // an exiting block *and* have a canonical induction variable. struct LoopStructure { - const char *Tag; + const char *Tag = ""; - BasicBlock *Header; - BasicBlock *Latch; + BasicBlock *Header = nullptr; + BasicBlock *Latch = nullptr; // `Latch's terminator instruction is `LatchBr', and it's `LatchBrExitIdx'th // successor is `LatchExit', the exit block of the loop. - BranchInst *LatchBr; - BasicBlock *LatchExit; - unsigned LatchBrExitIdx; + BranchInst *LatchBr = nullptr; + BasicBlock *LatchExit = nullptr; + unsigned LatchBrExitIdx = std::numeric_limits<unsigned>::max(); // The loop represented by this instance of LoopStructure is semantically // equivalent to: @@ -452,18 +474,17 @@ struct LoopStructure { // intN_ty inc = IndVarIncreasing ? 1 : -1; // pred_ty predicate = IndVarIncreasing ? ICMP_SLT : ICMP_SGT; // - // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarNext) + // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarBase) // ... body ... - Value *IndVarNext; - Value *IndVarStart; - Value *LoopExitAt; - bool IndVarIncreasing; + Value *IndVarBase = nullptr; + Value *IndVarStart = nullptr; + Value *IndVarStep = nullptr; + Value *LoopExitAt = nullptr; + bool IndVarIncreasing = false; + bool IsSignedPredicate = true; - LoopStructure() - : Tag(""), Header(nullptr), Latch(nullptr), LatchBr(nullptr), - LatchExit(nullptr), LatchBrExitIdx(-1), IndVarNext(nullptr), - IndVarStart(nullptr), LoopExitAt(nullptr), IndVarIncreasing(false) {} + LoopStructure() = default; template <typename M> LoopStructure map(M Map) const { LoopStructure Result; @@ -473,10 +494,12 @@ struct LoopStructure { Result.LatchBr = cast<BranchInst>(Map(LatchBr)); Result.LatchExit = cast<BasicBlock>(Map(LatchExit)); Result.LatchBrExitIdx = LatchBrExitIdx; - Result.IndVarNext = Map(IndVarNext); + Result.IndVarBase = Map(IndVarBase); Result.IndVarStart = Map(IndVarStart); + Result.IndVarStep = Map(IndVarStep); Result.LoopExitAt = Map(LoopExitAt); Result.IndVarIncreasing = IndVarIncreasing; + Result.IsSignedPredicate = IsSignedPredicate; return Result; } @@ -494,7 +517,6 @@ struct LoopStructure { /// loops to run any remaining iterations. 
The pre loop runs any iterations in /// which the induction variable is < Begin, and the post loop runs any /// iterations in which the induction variable is >= End. -/// class LoopConstrainer { // The representation of a clone of the original loop we started out with. struct ClonedLoop { @@ -511,13 +533,12 @@ class LoopConstrainer { // Result of rewriting the range of a loop. See changeIterationSpaceEnd for // more details on what these fields mean. struct RewrittenRangeInfo { - BasicBlock *PseudoExit; - BasicBlock *ExitSelector; + BasicBlock *PseudoExit = nullptr; + BasicBlock *ExitSelector = nullptr; std::vector<PHINode *> PHIValuesAtPseudoExit; - PHINode *IndVarEnd; + PHINode *IndVarEnd = nullptr; - RewrittenRangeInfo() - : PseudoExit(nullptr), ExitSelector(nullptr), IndVarEnd(nullptr) {} + RewrittenRangeInfo() = default; }; // Calculated subranges we restrict the iteration space of the main loop to. @@ -541,14 +562,12 @@ class LoopConstrainer { // Compute a safe set of limits for the main loop to run in -- effectively the // intersection of `Range' and the iteration space of the original loop. // Return None if unable to compute the set of subranges. - // - Optional<SubRanges> calculateSubRanges() const; + Optional<SubRanges> calculateSubRanges(bool IsSignedPredicate) const; // Clone `OriginalLoop' and return the result in CLResult. The IR after // running `cloneLoop' is well formed except for the PHI nodes in CLResult -- // the PHI nodes say that there is an incoming edge from `OriginalPreheader` // but there is no such edge. - // void cloneLoop(ClonedLoop &CLResult, const char *Tag) const; // Create the appropriate loop structure needed to describe a cloned copy of @@ -577,7 +596,6 @@ class LoopConstrainer { // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate // preheader because it is made to branch to the loop header only // conditionally. - // RewrittenRangeInfo changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader, Value *ExitLoopAt, @@ -585,7 +603,6 @@ class LoopConstrainer { // The loop denoted by `LS' has `OldPreheader' as its preheader. This // function creates a new preheader for `LS' and returns it. - // BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader, const char *Tag) const; @@ -613,12 +630,13 @@ class LoopConstrainer { // Information about the original loop we started out with. Loop &OriginalLoop; - const SCEV *LatchTakenCount; - BasicBlock *OriginalPreheader; + + const SCEV *LatchTakenCount = nullptr; + BasicBlock *OriginalPreheader = nullptr; // The preheader of the main loop. This may or may not be different from // `OriginalPreheader'. - BasicBlock *MainLoopPreheader; + BasicBlock *MainLoopPreheader = nullptr; // The range we need to run the main loop in. InductiveRangeCheck::Range Range; @@ -632,15 +650,14 @@ public: const LoopStructure &LS, ScalarEvolution &SE, DominatorTree &DT, InductiveRangeCheck::Range R) : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), - SE(SE), DT(DT), LPM(LPM), LI(LI), OriginalLoop(L), - LatchTakenCount(nullptr), OriginalPreheader(nullptr), - MainLoopPreheader(nullptr), Range(R), MainLoopStructure(LS) {} + SE(SE), DT(DT), LPM(LPM), LI(LI), OriginalLoop(L), Range(R), + MainLoopStructure(LS) {} // Entry point for the algorithm. Returns true on success. 
bool run(); }; -} +} // end anonymous namespace void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, BasicBlock *ReplaceBy) { @@ -649,22 +666,55 @@ void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, PN->setIncomingBlock(i, ReplaceBy); } -static bool CanBeSMax(ScalarEvolution &SE, const SCEV *S) { - APInt SMax = - APInt::getSignedMaxValue(cast<IntegerType>(S->getType())->getBitWidth()); - return SE.getSignedRange(S).contains(SMax) && - SE.getUnsignedRange(S).contains(SMax); +static bool CanBeMax(ScalarEvolution &SE, const SCEV *S, bool Signed) { + APInt Max = Signed ? + APInt::getSignedMaxValue(cast<IntegerType>(S->getType())->getBitWidth()) : + APInt::getMaxValue(cast<IntegerType>(S->getType())->getBitWidth()); + return SE.getSignedRange(S).contains(Max) && + SE.getUnsignedRange(S).contains(Max); +} + +static bool SumCanReachMax(ScalarEvolution &SE, const SCEV *S1, const SCEV *S2, + bool Signed) { + // S1 < INT_MAX - S2 ===> S1 + S2 < INT_MAX. + assert(SE.isKnownNonNegative(S2) && + "We expected the 2nd arg to be non-negative!"); + const SCEV *Max = SE.getConstant( + Signed ? APInt::getSignedMaxValue( + cast<IntegerType>(S1->getType())->getBitWidth()) + : APInt::getMaxValue( + cast<IntegerType>(S1->getType())->getBitWidth())); + const SCEV *CapForS1 = SE.getMinusSCEV(Max, S2); + return !SE.isKnownPredicate(Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, + S1, CapForS1); } -static bool CanBeSMin(ScalarEvolution &SE, const SCEV *S) { - APInt SMin = - APInt::getSignedMinValue(cast<IntegerType>(S->getType())->getBitWidth()); - return SE.getSignedRange(S).contains(SMin) && - SE.getUnsignedRange(S).contains(SMin); +static bool CanBeMin(ScalarEvolution &SE, const SCEV *S, bool Signed) { + APInt Min = Signed ? + APInt::getSignedMinValue(cast<IntegerType>(S->getType())->getBitWidth()) : + APInt::getMinValue(cast<IntegerType>(S->getType())->getBitWidth()); + return SE.getSignedRange(S).contains(Min) && + SE.getUnsignedRange(S).contains(Min); +} + +static bool SumCanReachMin(ScalarEvolution &SE, const SCEV *S1, const SCEV *S2, + bool Signed) { + // S1 > INT_MIN - S2 ===> S1 + S2 > INT_MIN. + assert(SE.isKnownNonPositive(S2) && + "We expected the 2nd arg to be non-positive!"); + const SCEV *Max = SE.getConstant( + Signed ? APInt::getSignedMinValue( + cast<IntegerType>(S1->getType())->getBitWidth()) + : APInt::getMinValue( + cast<IntegerType>(S1->getType())->getBitWidth())); + const SCEV *CapForS1 = SE.getMinusSCEV(Max, S2); + return !SE.isKnownPredicate(Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, + S1, CapForS1); } Optional<LoopStructure> -LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BPI, +LoopStructure::parseLoopStructure(ScalarEvolution &SE, + BranchProbabilityInfo &BPI, Loop &L, const char *&FailureReason) { if (!L.isLoopSimplifyForm()) { FailureReason = "loop not in LoopSimplify form"; @@ -766,7 +816,11 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap; }; - auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing) { + // Here we check whether the suggested AddRec is an induction variable that + // can be handled (i.e. with known constant step), and if yes, calculate its + // step and identify whether it is increasing or decreasing. 
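Put differently, the old requirement that the step be exactly +1 or -1 is relaxed to any non-zero constant step, with the direction read off the step's sign. A trivial standalone illustration of the direction test (not the patch's code):

    #include <cassert>
    #include <cstdint>

    // Direction of an affine recurrence {Start,+,Step} with a known constant step.
    bool isIncreasing(int64_t Step) {
      assert(Step != 0 && "Zero step?");
      return Step > 0;   // e.g. a loop doing i += 4 per iteration is increasing
    }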
+ auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing, + ConstantInt *&StepCI) { if (!AR->isAffine()) return false; @@ -778,11 +832,10 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP if (const SCEVConstant *StepExpr = dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) { - ConstantInt *StepCI = StepExpr->getValue(); - if (StepCI->isOne() || StepCI->isMinusOne()) { - IsIncreasing = StepCI->isOne(); - return true; - } + StepCI = StepExpr->getValue(); + assert(!StepCI->isZero() && "Zero step?"); + IsIncreasing = !StepCI->isNegative(); + return true; } return false; @@ -791,59 +844,87 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP // `ICI` is interpreted as taking the backedge if the *next* value of the // induction variable satisfies some constraint. - const SCEVAddRecExpr *IndVarNext = cast<SCEVAddRecExpr>(LeftSCEV); + const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV); bool IsIncreasing = false; - if (!IsInductionVar(IndVarNext, IsIncreasing)) { + bool IsSignedPredicate = true; + ConstantInt *StepCI; + if (!IsInductionVar(IndVarBase, IsIncreasing, StepCI)) { FailureReason = "LHS in icmp not induction variable"; return None; } - const SCEV *StartNext = IndVarNext->getStart(); - const SCEV *Addend = SE.getNegativeSCEV(IndVarNext->getStepRecurrence(SE)); + const SCEV *StartNext = IndVarBase->getStart(); + const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE)); const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); + const SCEV *Step = SE.getSCEV(StepCI); ConstantInt *One = ConstantInt::get(IndVarTy, 1); - // TODO: generalize the predicates here to also match their unsigned variants. if (IsIncreasing) { bool DecreasedRightValueByOne = false; - // Try to turn eq/ne predicates to those we can work with. - if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) - // while (++i != len) { while (++i < len) { - // ... ---> ... - // } } - Pred = ICmpInst::ICMP_SLT; - else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && - !CanBeSMin(SE, RightSCEV)) { - // while (true) { while (true) { - // if (++i == len) ---> if (++i > len - 1) - // break; break; - // ... ... - // } } - Pred = ICmpInst::ICMP_SGT; - RightSCEV = SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); - DecreasedRightValueByOne = true; + if (StepCI->isOne()) { + // Try to turn eq/ne predicates to those we can work with. + if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) + // while (++i != len) { while (++i < len) { + // ... ---> ... + // } } + // If both parts are known non-negative, it is profitable to use + // unsigned comparison in increasing loop. This allows us to make the + // comparison check against "RightSCEV + 1" more optimistic. + if (SE.isKnownNonNegative(IndVarStart) && + SE.isKnownNonNegative(RightSCEV)) + Pred = ICmpInst::ICMP_ULT; + else + Pred = ICmpInst::ICMP_SLT; + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && + !CanBeMin(SE, RightSCEV, /* IsSignedPredicate */ true)) { + // while (true) { while (true) { + // if (++i == len) ---> if (++i > len - 1) + // break; break; + // ... ... + // } } + // TODO: Insert ICMP_UGT if both are non-negative? 
+ Pred = ICmpInst::ICMP_SGT; + RightSCEV = SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); + DecreasedRightValueByOne = true; + } } + bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT); + bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT); bool FoundExpectedPred = - (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 1) || - (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 0); + (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0); if (!FoundExpectedPred) { FailureReason = "expected icmp slt semantically, found something else"; return None; } + IsSignedPredicate = + Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT; + + if (!IsSignedPredicate && !AllowUnsignedLatchCondition) { + FailureReason = "unsigned latch conditions are explicitly prohibited"; + return None; + } + + // The predicate that we need to check that the induction variable lies + // within bounds. + ICmpInst::Predicate BoundPred = + IsSignedPredicate ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; + if (LatchBrExitIdx == 0) { - if (CanBeSMax(SE, RightSCEV)) { + const SCEV *StepMinusOne = SE.getMinusSCEV(Step, + SE.getOne(Step->getType())); + if (SumCanReachMax(SE, RightSCEV, StepMinusOne, IsSignedPredicate)) { // TODO: this restriction is easily removable -- we just have to // remember that the icmp was an slt and not an sle. - FailureReason = "limit may overflow when coercing sle to slt"; + FailureReason = "limit may overflow when coercing le to lt"; return None; } if (!SE.isLoopEntryGuardedByCond( - &L, CmpInst::ICMP_SLT, IndVarStart, - SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())))) { + &L, BoundPred, IndVarStart, + SE.getAddExpr(RightSCEV, Step))) { FailureReason = "Induction variable start not bounded by upper limit"; return None; } @@ -855,8 +936,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP RightValue = B.CreateAdd(RightValue, One); } } else { - if (!SE.isLoopEntryGuardedByCond(&L, CmpInst::ICMP_SLT, IndVarStart, - RightSCEV)) { + if (!SE.isLoopEntryGuardedByCond(&L, BoundPred, IndVarStart, RightSCEV)) { FailureReason = "Induction variable start not bounded by upper limit"; return None; } @@ -865,43 +945,65 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP } } else { bool IncreasedRightValueByOne = false; - // Try to turn eq/ne predicates to those we can work with. - if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) - // while (--i != len) { while (--i > len) { - // ... ---> ... - // } } - Pred = ICmpInst::ICMP_SGT; - else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && - !CanBeSMax(SE, RightSCEV)) { - // while (true) { while (true) { - // if (--i == len) ---> if (--i < len + 1) - // break; break; - // ... ... - // } } - Pred = ICmpInst::ICMP_SLT; - RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); - IncreasedRightValueByOne = true; + if (StepCI->isMinusOne()) { + // Try to turn eq/ne predicates to those we can work with. + if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) + // while (--i != len) { while (--i > len) { + // ... ---> ... + // } } + // We intentionally don't turn the predicate into UGT even if we know + // that both operands are non-negative, because it will only pessimize + // our check against "RightSCEV - 1". 
+ Pred = ICmpInst::ICMP_SGT; + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && + !CanBeMax(SE, RightSCEV, /* IsSignedPredicate */ true)) { + // while (true) { while (true) { + // if (--i == len) ---> if (--i < len + 1) + // break; break; + // ... ... + // } } + // TODO: Insert ICMP_ULT if both are non-negative? + Pred = ICmpInst::ICMP_SLT; + RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); + IncreasedRightValueByOne = true; + } } + bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT); + bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT); + bool FoundExpectedPred = - (Pred == ICmpInst::ICMP_SGT && LatchBrExitIdx == 1) || - (Pred == ICmpInst::ICMP_SLT && LatchBrExitIdx == 0); + (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0); if (!FoundExpectedPred) { FailureReason = "expected icmp sgt semantically, found something else"; return None; } + IsSignedPredicate = + Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT; + + if (!IsSignedPredicate && !AllowUnsignedLatchCondition) { + FailureReason = "unsigned latch conditions are explicitly prohibited"; + return None; + } + + // The predicate that we need to check that the induction variable lies + // within bounds. + ICmpInst::Predicate BoundPred = + IsSignedPredicate ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT; + if (LatchBrExitIdx == 0) { - if (CanBeSMin(SE, RightSCEV)) { + const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType())); + if (SumCanReachMin(SE, RightSCEV, StepPlusOne, IsSignedPredicate)) { // TODO: this restriction is easily removable -- we just have to // remember that the icmp was an sgt and not an sge. - FailureReason = "limit may overflow when coercing sge to sgt"; + FailureReason = "limit may overflow when coercing ge to gt"; return None; } if (!SE.isLoopEntryGuardedByCond( - &L, CmpInst::ICMP_SGT, IndVarStart, + &L, BoundPred, IndVarStart, SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())))) { FailureReason = "Induction variable start not bounded by lower limit"; return None; @@ -914,8 +1016,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP RightValue = B.CreateSub(RightValue, One); } } else { - if (!SE.isLoopEntryGuardedByCond(&L, CmpInst::ICMP_SGT, IndVarStart, - RightSCEV)) { + if (!SE.isLoopEntryGuardedByCond(&L, BoundPred, IndVarStart, RightSCEV)) { FailureReason = "Induction variable start not bounded by lower limit"; return None; } @@ -923,7 +1024,6 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP "Right value can be increased only for LatchBrExitIdx == 0!"); } } - BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); assert(SE.getLoopDisposition(LatchCount, &L) == @@ -946,9 +1046,11 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP Result.LatchExit = LatchExit; Result.LatchBrExitIdx = LatchBrExitIdx; Result.IndVarStart = IndVarStartV; - Result.IndVarNext = LeftValue; + Result.IndVarStep = StepCI; + Result.IndVarBase = LeftValue; Result.IndVarIncreasing = IsIncreasing; Result.LoopExitAt = RightValue; + Result.IsSignedPredicate = IsSignedPredicate; FailureReason = nullptr; @@ -956,7 +1058,7 @@ LoopStructure::parseLoopStructure(ScalarEvolution &SE, BranchProbabilityInfo &BP } Optional<LoopConstrainer::SubRanges> -LoopConstrainer::calculateSubRanges() const { +LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType()); if 
(Range.getType() != Ty) @@ -999,26 +1101,31 @@ LoopConstrainer::calculateSubRanges() const { // that case, `Clamp` will always return `Smallest` and // [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`) // will be an empty range. Returning an empty range is always safe. - // Smallest = SE.getAddExpr(End, One); Greatest = SE.getAddExpr(Start, One); GreatestSeen = Start; } - auto Clamp = [this, Smallest, Greatest](const SCEV *S) { - return SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S)); + auto Clamp = [this, Smallest, Greatest, IsSignedPredicate](const SCEV *S) { + return IsSignedPredicate + ? SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S)) + : SE.getUMaxExpr(Smallest, SE.getUMinExpr(Greatest, S)); }; - // In some cases we can prove that we don't need a pre or post loop + // In some cases we can prove that we don't need a pre or post loop. + ICmpInst::Predicate PredLE = + IsSignedPredicate ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; + ICmpInst::Predicate PredLT = + IsSignedPredicate ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; bool ProvablyNoPreloop = - SE.isKnownPredicate(ICmpInst::ICMP_SLE, Range.getBegin(), Smallest); + SE.isKnownPredicate(PredLE, Range.getBegin(), Smallest); if (!ProvablyNoPreloop) Result.LowLimit = Clamp(Range.getBegin()); bool ProvablyNoPostLoop = - SE.isKnownPredicate(ICmpInst::ICMP_SLT, GreatestSeen, Range.getEnd()); + SE.isKnownPredicate(PredLT, GreatestSeen, Range.getEnd()); if (!ProvablyNoPostLoop) Result.HighLimit = Clamp(Range.getEnd()); @@ -1082,7 +1189,6 @@ void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt, BasicBlock *ContinuationBlock) const { - // We start with a loop with a single latch: // // +--------------------+ @@ -1153,7 +1259,6 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( // | original exit <----+ // | | // +--------------------+ - // RewrittenRangeInfo RRI; @@ -1165,22 +1270,35 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator()); bool Increasing = LS.IndVarIncreasing; + bool IsSignedPredicate = LS.IsSignedPredicate; IRBuilder<> B(PreheaderJump); // EnterLoopCond - is it okay to start executing this `LS'? - Value *EnterLoopCond = Increasing - ? B.CreateICmpSLT(LS.IndVarStart, ExitSubloopAt) - : B.CreateICmpSGT(LS.IndVarStart, ExitSubloopAt); + Value *EnterLoopCond = nullptr; + if (Increasing) + EnterLoopCond = IsSignedPredicate + ? B.CreateICmpSLT(LS.IndVarStart, ExitSubloopAt) + : B.CreateICmpULT(LS.IndVarStart, ExitSubloopAt); + else + EnterLoopCond = IsSignedPredicate + ? B.CreateICmpSGT(LS.IndVarStart, ExitSubloopAt) + : B.CreateICmpUGT(LS.IndVarStart, ExitSubloopAt); B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); PreheaderJump->eraseFromParent(); LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); B.SetInsertPoint(LS.LatchBr); - Value *TakeBackedgeLoopCond = - Increasing ? B.CreateICmpSLT(LS.IndVarNext, ExitSubloopAt) - : B.CreateICmpSGT(LS.IndVarNext, ExitSubloopAt); + Value *TakeBackedgeLoopCond = nullptr; + if (Increasing) + TakeBackedgeLoopCond = IsSignedPredicate + ? B.CreateICmpSLT(LS.IndVarBase, ExitSubloopAt) + : B.CreateICmpULT(LS.IndVarBase, ExitSubloopAt); + else + TakeBackedgeLoopCond = IsSignedPredicate + ? 
B.CreateICmpSGT(LS.IndVarBase, ExitSubloopAt) + : B.CreateICmpUGT(LS.IndVarBase, ExitSubloopAt); Value *CondForBranch = LS.LatchBrExitIdx == 1 ? TakeBackedgeLoopCond : B.CreateNot(TakeBackedgeLoopCond); @@ -1192,9 +1310,15 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( // IterationsLeft - are there any more iterations left, given the original // upper bound on the induction variable? If not, we branch to the "real" // exit. - Value *IterationsLeft = Increasing - ? B.CreateICmpSLT(LS.IndVarNext, LS.LoopExitAt) - : B.CreateICmpSGT(LS.IndVarNext, LS.LoopExitAt); + Value *IterationsLeft = nullptr; + if (Increasing) + IterationsLeft = IsSignedPredicate + ? B.CreateICmpSLT(LS.IndVarBase, LS.LoopExitAt) + : B.CreateICmpULT(LS.IndVarBase, LS.LoopExitAt); + else + IterationsLeft = IsSignedPredicate + ? B.CreateICmpSGT(LS.IndVarBase, LS.LoopExitAt) + : B.CreateICmpUGT(LS.IndVarBase, LS.LoopExitAt); B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); BranchInst *BranchToContinuation = @@ -1217,10 +1341,10 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( RRI.PHIValuesAtPseudoExit.push_back(NewPHI); } - RRI.IndVarEnd = PHINode::Create(LS.IndVarNext->getType(), 2, "indvar.end", + RRI.IndVarEnd = PHINode::Create(LS.IndVarBase->getType(), 2, "indvar.end", BranchToContinuation); RRI.IndVarEnd->addIncoming(LS.IndVarStart, Preheader); - RRI.IndVarEnd->addIncoming(LS.IndVarNext, RRI.ExitSelector); + RRI.IndVarEnd->addIncoming(LS.IndVarBase, RRI.ExitSelector); // The latch exit now has a branch from `RRI.ExitSelector' instead of // `LS.Latch'. The PHI nodes need to be updated to reflect that. @@ -1237,7 +1361,6 @@ LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( void LoopConstrainer::rewriteIncomingValuesForPHIs( LoopStructure &LS, BasicBlock *ContinuationBlock, const LoopConstrainer::RewrittenRangeInfo &RRI) const { - unsigned PHIIndex = 0; for (Instruction &I : *LS.Header) { auto *PN = dyn_cast<PHINode>(&I); @@ -1255,7 +1378,6 @@ void LoopConstrainer::rewriteIncomingValuesForPHIs( BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader, const char *Tag) const { - BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); BranchInst::Create(LS.Header, Preheader); @@ -1282,7 +1404,7 @@ void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) { Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, ValueToValueMapTy &VM) { - Loop &New = *new Loop(); + Loop &New = *LI.AllocateLoop(); if (Parent) Parent->addChildLoop(&New); else @@ -1311,7 +1433,8 @@ bool LoopConstrainer::run() { OriginalPreheader = Preheader; MainLoopPreheader = Preheader; - Optional<SubRanges> MaybeSR = calculateSubRanges(); + bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate; + Optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate); if (!MaybeSR.hasValue()) { DEBUG(dbgs() << "irce: could not compute subranges\n"); return false; @@ -1320,7 +1443,7 @@ bool LoopConstrainer::run() { SubRanges SR = MaybeSR.getValue(); bool Increasing = MainLoopStructure.IndVarIncreasing; IntegerType *IVTy = - cast<IntegerType>(MainLoopStructure.IndVarNext->getType()); + cast<IntegerType>(MainLoopStructure.IndVarBase->getType()); SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce"); Instruction *InsertPt = OriginalPreheader->getTerminator(); @@ -1345,7 +1468,7 @@ bool LoopConstrainer::run() { if (Increasing) ExitPreLoopAtSCEV = 
*SR.LowLimit; else { - if (CanBeSMin(SE, *SR.HighLimit)) { + if (CanBeMin(SE, *SR.HighLimit, IsSignedPredicate)) { DEBUG(dbgs() << "irce: could not prove no-overflow when computing " << "preloop exit limit. HighLimit = " << *(*SR.HighLimit) << "\n"); @@ -1354,6 +1477,13 @@ bool LoopConstrainer::run() { ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); } + if (!isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt, SE)) { + DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" + << " preloop exit limit " << *ExitPreLoopAtSCEV + << " at block " << InsertPt->getParent()->getName() << "\n"); + return false; + } + ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt); ExitPreLoopAt->setName("exit.preloop.at"); } @@ -1364,7 +1494,7 @@ bool LoopConstrainer::run() { if (Increasing) ExitMainLoopAtSCEV = *SR.HighLimit; else { - if (CanBeSMin(SE, *SR.LowLimit)) { + if (CanBeMin(SE, *SR.LowLimit, IsSignedPredicate)) { DEBUG(dbgs() << "irce: could not prove no-overflow when computing " << "mainloop exit limit. LowLimit = " << *(*SR.LowLimit) << "\n"); @@ -1373,6 +1503,13 @@ bool LoopConstrainer::run() { ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); } + if (!isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt, SE)) { + DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" + << " main loop exit limit " << *ExitMainLoopAtSCEV + << " at block " << InsertPt->getParent()->getName() << "\n"); + return false; + } + ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt); ExitMainLoopAt->setName("exit.mainloop.at"); } @@ -1463,34 +1600,27 @@ bool LoopConstrainer::run() { /// range, returns None. Optional<InductiveRangeCheck::Range> InductiveRangeCheck::computeSafeIterationSpace( - ScalarEvolution &SE, const SCEVAddRecExpr *IndVar) const { + ScalarEvolution &SE, const SCEVAddRecExpr *IndVar, + bool IsLatchSigned) const { // IndVar is of the form "A + B * I" (where "I" is the canonical induction // variable, that may or may not exist as a real llvm::Value in the loop) and // this inductive range check is a range check on the "C + D * I" ("C" is - // getOffset() and "D" is getScale()). We rewrite the value being range + // getBegin() and "D" is getStep()). We rewrite the value being range // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA". - // Currently we support this only for "B" = "D" = { 1 or -1 }, but the code - // can be generalized as needed. // // The actual inequalities we solve are of the form // // 0 <= M + 1 * IndVar < L given L >= 0 (i.e. N == 1) // - // The inequality is satisfied by -M <= IndVar < (L - M) [^1]. All additions - // and subtractions are twos-complement wrapping and comparisons are signed. - // - // Proof: - // - // If there exists IndVar such that -M <= IndVar < (L - M) then it follows - // that -M <= (-M + L) [== Eq. 1]. Since L >= 0, if (-M + L) sign-overflows - // then (-M + L) < (-M). Hence by [Eq. 1], (-M + L) could not have - // overflown. - // - // This means IndVar = t + (-M) for t in [0, L). Hence (IndVar + M) = t. - // Hence 0 <= (IndVar + M) < L - - // [^1]: Note that the solution does _not_ apply if L < 0; consider values M = - // 127, IndVar = 126 and L = -2 in an i8 world. + // Here L stands for upper limit of the safe iteration space. + // The inequality is satisfied by (0 - M) <= IndVar < (L - M). 
To avoid + // overflows when calculating (0 - M) and (L - M) we, depending on type of + // IV's iteration space, limit the calculations by borders of the iteration + // space. For example, if IndVar is unsigned, (0 - M) overflows for any M > 0. + // If we figured out that "anything greater than (-M) is safe", we strengthen + // this to "everything greater than 0 is safe", assuming that values between + // -M and 0 just do not exist in unsigned iteration space, and we don't want + // to deal with overflown values. if (!IndVar->isAffine()) return None; @@ -1499,42 +1629,89 @@ InductiveRangeCheck::computeSafeIterationSpace( const SCEVConstant *B = dyn_cast<SCEVConstant>(IndVar->getStepRecurrence(SE)); if (!B) return None; + assert(!B->isZero() && "Recurrence with zero step?"); - const SCEV *C = getOffset(); - const SCEVConstant *D = dyn_cast<SCEVConstant>(getScale()); + const SCEV *C = getBegin(); + const SCEVConstant *D = dyn_cast<SCEVConstant>(getStep()); if (D != B) return None; - ConstantInt *ConstD = D->getValue(); - if (!(ConstD->isMinusOne() || ConstD->isOne())) - return None; + assert(!D->getValue()->isZero() && "Recurrence with zero step?"); + unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth(); + const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); + // Substract Y from X so that it does not go through border of the IV + // iteration space. Mathematically, it is equivalent to: + // + // ClampedSubstract(X, Y) = min(max(X - Y, INT_MIN), INT_MAX). [1] + // + // In [1], 'X - Y' is a mathematical substraction (result is not bounded to + // any width of bit grid). But after we take min/max, the result is + // guaranteed to be within [INT_MIN, INT_MAX]. + // + // In [1], INT_MAX and INT_MIN are respectively signed and unsigned max/min + // values, depending on type of latch condition that defines IV iteration + // space. + auto ClampedSubstract = [&](const SCEV *X, const SCEV *Y) { + assert(SE.isKnownNonNegative(X) && + "We can only substract from values in [0; SINT_MAX]!"); + if (IsLatchSigned) { + // X is a number from signed range, Y is interpreted as signed. + // Even if Y is SINT_MAX, (X - Y) does not reach SINT_MIN. So the only + // thing we should care about is that we didn't cross SINT_MAX. + // So, if Y is positive, we substract Y safely. + // Rule 1: Y > 0 ---> Y. + // If 0 <= -Y <= (SINT_MAX - X), we substract Y safely. + // Rule 2: Y >=s (X - SINT_MAX) ---> Y. + // If 0 <= (SINT_MAX - X) < -Y, we can only substract (X - SINT_MAX). + // Rule 3: Y <s (X - SINT_MAX) ---> (X - SINT_MAX). + // It gives us smax(Y, X - SINT_MAX) to substract in all cases. + const SCEV *XMinusSIntMax = SE.getMinusSCEV(X, SIntMax); + return SE.getMinusSCEV(X, SE.getSMaxExpr(Y, XMinusSIntMax), + SCEV::FlagNSW); + } else + // X is a number from unsigned range, Y is interpreted as signed. + // Even if Y is SINT_MIN, (X - Y) does not reach UINT_MAX. So the only + // thing we should care about is that we didn't cross zero. + // So, if Y is negative, we substract Y safely. + // Rule 1: Y <s 0 ---> Y. + // If 0 <= Y <= X, we substract Y safely. + // Rule 2: Y <=s X ---> Y. + // If 0 <= X < Y, we should stop at 0 and can only substract X. + // Rule 3: Y >s X ---> X. + // It gives us smin(X, Y) to substract in all cases. 
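A plain-integer rendering of the two clamping rules may help; the sketch below mirrors the ClampedSubstract logic above for a 32-bit induction variable, with X assumed to be in [0, INT32_MAX] as the assert requires. It is illustrative only, not code from the patch.

    #include <algorithm>
    #include <cstdint>
    #include <limits>

    // Signed latch: subtract smax(Y, X - SINT_MAX) so the result never exceeds SINT_MAX.
    int64_t clampedSubSigned(int64_t X, int64_t Y) {
      const int64_t SIntMax = std::numeric_limits<int32_t>::max();
      return X - std::max<int64_t>(Y, X - SIntMax);
    }

    // Unsigned latch: subtract smin(X, Y) so the result never drops below zero.
    int64_t clampedSubUnsigned(int64_t X, int64_t Y) {
      return X - std::min<int64_t>(X, Y);
    }

For example, clampedSubUnsigned(0, 5) yields 0 rather than the wrapped-around value, which is exactly the strengthening described above ("anything greater than -M is safe" becomes "anything greater than 0 is safe"), and clampedSubSigned(0, INT32_MIN) yields INT32_MAX instead of the unrepresentable 2^31.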
+ return SE.getMinusSCEV(X, SE.getSMinExpr(X, Y), SCEV::FlagNUW); + }; const SCEV *M = SE.getMinusSCEV(C, A); - - const SCEV *Begin = SE.getNegativeSCEV(M); - const SCEV *UpperLimit = nullptr; + const SCEV *Zero = SE.getZero(M->getType()); + const SCEV *Begin = ClampedSubstract(Zero, M); + const SCEV *L = nullptr; // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L". // We can potentially do much better here. - if (Value *V = getLength()) { - UpperLimit = SE.getSCEV(V); - } else { + if (const SCEV *EndLimit = getEnd()) + L = EndLimit; + else { assert(Kind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!"); - unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth(); - UpperLimit = SE.getConstant(APInt::getSignedMaxValue(BitWidth)); + L = SIntMax; } - - const SCEV *End = SE.getMinusSCEV(UpperLimit, M); + const SCEV *End = ClampedSubstract(L, M); return InductiveRangeCheck::Range(Begin, End); } static Optional<InductiveRangeCheck::Range> -IntersectRange(ScalarEvolution &SE, - const Optional<InductiveRangeCheck::Range> &R1, - const InductiveRangeCheck::Range &R2) { +IntersectSignedRange(ScalarEvolution &SE, + const Optional<InductiveRangeCheck::Range> &R1, + const InductiveRangeCheck::Range &R2) { + if (R2.isEmpty(SE, /* IsSigned */ true)) + return None; if (!R1.hasValue()) return R2; auto &R1Value = R1.getValue(); + // We never return empty ranges from this function, and R1 is supposed to be + // a result of intersection. Thus, R1 is never empty. + assert(!R1Value.isEmpty(SE, /* IsSigned */ true) && + "We should never have empty R1!"); // TODO: we could widen the smaller range and have this work; but for now we // bail out to keep things simple. @@ -1544,7 +1721,40 @@ IntersectRange(ScalarEvolution &SE, const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin()); const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd()); - return InductiveRangeCheck::Range(NewBegin, NewEnd); + // If the resulting range is empty, just return None. + auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd); + if (Ret.isEmpty(SE, /* IsSigned */ true)) + return None; + return Ret; +} + +static Optional<InductiveRangeCheck::Range> +IntersectUnsignedRange(ScalarEvolution &SE, + const Optional<InductiveRangeCheck::Range> &R1, + const InductiveRangeCheck::Range &R2) { + if (R2.isEmpty(SE, /* IsSigned */ false)) + return None; + if (!R1.hasValue()) + return R2; + auto &R1Value = R1.getValue(); + // We never return empty ranges from this function, and R1 is supposed to be + // a result of intersection. Thus, R1 is never empty. + assert(!R1Value.isEmpty(SE, /* IsSigned */ false) && + "We should never have empty R1!"); + + // TODO: we could widen the smaller range and have this work; but for now we + // bail out to keep things simple. + if (R1Value.getType() != R2.getType()) + return None; + + const SCEV *NewBegin = SE.getUMaxExpr(R1Value.getBegin(), R2.getBegin()); + const SCEV *NewEnd = SE.getUMinExpr(R1Value.getEnd(), R2.getEnd()); + + // If the resulting range is empty, just return None. 
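The intersection itself is just max-of-begins and min-of-ends followed by an emptiness test; a plain half-open-interval sketch of the signed variant (illustrative, not the patch's code):

    #include <algorithm>
    #include <optional>
    #include <utility>

    using Range = std::pair<long, long>;   // half-open [Begin, End)

    std::optional<Range> intersectSigned(Range A, Range B) {
      Range R{std::max(A.first, B.first), std::min(A.second, B.second)};
      if (R.first >= R.second)    // empty intersection -> no safe iteration space
        return std::nullopt;
      return R;
    }

Intersecting [0, 100) with [50, 200) gives [50, 100), while [0, 10) against [20, 30) yields nullopt; the unsigned variant is identical except that it compares with unsigned min/max.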
+ auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd); + if (Ret.isEmpty(SE, /* IsSigned */ false)) + return None; + return Ret; } bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { @@ -1598,24 +1808,31 @@ bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) { return false; } LoopStructure LS = MaybeLoopStructure.getValue(); - bool Increasing = LS.IndVarIncreasing; - const SCEV *MinusOne = - SE.getConstant(LS.IndVarNext->getType(), Increasing ? -1 : 1, true); const SCEVAddRecExpr *IndVar = - cast<SCEVAddRecExpr>(SE.getAddExpr(SE.getSCEV(LS.IndVarNext), MinusOne)); + cast<SCEVAddRecExpr>(SE.getMinusSCEV(SE.getSCEV(LS.IndVarBase), SE.getSCEV(LS.IndVarStep))); Optional<InductiveRangeCheck::Range> SafeIterRange; Instruction *ExprInsertPt = Preheader->getTerminator(); SmallVector<InductiveRangeCheck, 4> RangeChecksToEliminate; + // Basing on the type of latch predicate, we interpret the IV iteration range + // as signed or unsigned range. We use different min/max functions (signed or + // unsigned) when intersecting this range with safe iteration ranges implied + // by range checks. + auto IntersectRange = + LS.IsSignedPredicate ? IntersectSignedRange : IntersectUnsignedRange; IRBuilder<> B(ExprInsertPt); for (InductiveRangeCheck &IRC : RangeChecks) { - auto Result = IRC.computeSafeIterationSpace(SE, IndVar); + auto Result = IRC.computeSafeIterationSpace(SE, IndVar, + LS.IsSignedPredicate); if (Result.hasValue()) { auto MaybeSafeIterRange = IntersectRange(SE, SafeIterRange, Result.getValue()); if (MaybeSafeIterRange.hasValue()) { + assert( + !MaybeSafeIterRange.getValue().isEmpty(SE, LS.IsSignedPredicate) && + "We should never return empty ranges!"); RangeChecksToEliminate.push_back(IRC); SafeIterRange = MaybeSafeIterRange.getValue(); } diff --git a/lib/Transforms/Scalar/InferAddressSpaces.cpp b/lib/Transforms/Scalar/InferAddressSpaces.cpp index 89b28f0aeee6..7d66c0f73821 100644 --- a/lib/Transforms/Scalar/InferAddressSpaces.cpp +++ b/lib/Transforms/Scalar/InferAddressSpaces.cpp @@ -1,4 +1,4 @@ -//===-- NVPTXInferAddressSpace.cpp - ---------------------*- C++ -*-===// +//===- InferAddressSpace.cpp - --------------------------------------------===// // // The LLVM Compiler Infrastructure // @@ -89,26 +89,54 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" +#include <cassert> +#include 
<iterator> +#include <limits> +#include <utility> +#include <vector> #define DEBUG_TYPE "infer-address-spaces" using namespace llvm; +static const unsigned UninitializedAddressSpace = + std::numeric_limits<unsigned>::max(); + namespace { -static const unsigned UninitializedAddressSpace = ~0u; using ValueToAddrSpaceMapTy = DenseMap<const Value *, unsigned>; @@ -146,10 +174,9 @@ private: // Changes the flat address expressions in function F to point to specific // address spaces if InferredAddrSpace says so. Postorder is the postorder of // all flat expressions in the use-def graph of function F. - bool - rewriteWithNewAddressSpaces(ArrayRef<WeakTrackingVH> Postorder, - const ValueToAddrSpaceMapTy &InferredAddrSpace, - Function *F) const; + bool rewriteWithNewAddressSpaces( + const TargetTransformInfo &TTI, ArrayRef<WeakTrackingVH> Postorder, + const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const; void appendsFlatAddressExpressionToPostorderStack( Value *V, std::vector<std::pair<Value *, bool>> &PostorderStack, @@ -170,13 +197,16 @@ private: SmallVectorImpl<const Use *> *UndefUsesToFix) const; unsigned joinAddressSpaces(unsigned AS1, unsigned AS2) const; }; + } // end anonymous namespace char InferAddressSpaces::ID = 0; namespace llvm { + void initializeInferAddressSpacesPass(PassRegistry &); -} + +} // end namespace llvm INITIALIZE_PASS(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces", false, false) @@ -454,11 +484,10 @@ static Value *cloneInstructionWithNewAddressSpace( NewGEP->setIsInBounds(GEP->isInBounds()); return NewGEP; } - case Instruction::Select: { + case Instruction::Select: assert(I->getType()->isPointerTy()); return SelectInst::Create(I->getOperand(0), NewPointerOperands[1], NewPointerOperands[2], "", nullptr, I); - } default: llvm_unreachable("Unexpected opcode"); } @@ -600,7 +629,7 @@ bool InferAddressSpaces::runOnFunction(Function &F) { // Changes the address spaces of the flat address expressions who are inferred // to point to a specific address space. - return rewriteWithNewAddressSpaces(Postorder, InferredAddrSpace, &F); + return rewriteWithNewAddressSpaces(TTI, Postorder, InferredAddrSpace, &F); } // Constants need to be tracked through RAUW to handle cases with nested @@ -708,24 +737,32 @@ Optional<unsigned> InferAddressSpaces::updateAddressSpace( /// \p returns true if \p U is the pointer operand of a memory instruction with /// a single pointer operand that can have its address space changed by simply -/// mutating the use to a new value. -static bool isSimplePointerUseValidToReplace(Use &U) { +/// mutating the use to a new value. If the memory instruction is volatile, +/// return true only if the target allows the memory instruction to be volatile +/// in the new address space. 
+static bool isSimplePointerUseValidToReplace(const TargetTransformInfo &TTI, + Use &U, unsigned AddrSpace) { User *Inst = U.getUser(); unsigned OpNo = U.getOperandNo(); + bool VolatileIsAllowed = false; + if (auto *I = dyn_cast<Instruction>(Inst)) + VolatileIsAllowed = TTI.hasVolatileVariant(I, AddrSpace); if (auto *LI = dyn_cast<LoadInst>(Inst)) - return OpNo == LoadInst::getPointerOperandIndex() && !LI->isVolatile(); + return OpNo == LoadInst::getPointerOperandIndex() && + (VolatileIsAllowed || !LI->isVolatile()); if (auto *SI = dyn_cast<StoreInst>(Inst)) - return OpNo == StoreInst::getPointerOperandIndex() && !SI->isVolatile(); + return OpNo == StoreInst::getPointerOperandIndex() && + (VolatileIsAllowed || !SI->isVolatile()); if (auto *RMW = dyn_cast<AtomicRMWInst>(Inst)) - return OpNo == AtomicRMWInst::getPointerOperandIndex() && !RMW->isVolatile(); + return OpNo == AtomicRMWInst::getPointerOperandIndex() && + (VolatileIsAllowed || !RMW->isVolatile()); - if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) { + if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) return OpNo == AtomicCmpXchgInst::getPointerOperandIndex() && - !CmpX->isVolatile(); - } + (VolatileIsAllowed || !CmpX->isVolatile()); return false; } @@ -818,7 +855,7 @@ static Value::use_iterator skipToNextUser(Value::use_iterator I, } bool InferAddressSpaces::rewriteWithNewAddressSpaces( - ArrayRef<WeakTrackingVH> Postorder, + const TargetTransformInfo &TTI, ArrayRef<WeakTrackingVH> Postorder, const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const { // For each address expression to be modified, creates a clone of it with its // pointer operands converted to the new address space. Since the pointer @@ -878,7 +915,8 @@ bool InferAddressSpaces::rewriteWithNewAddressSpaces( // to the next instruction. I = skipToNextUser(I, E); - if (isSimplePointerUseValidToReplace(U)) { + if (isSimplePointerUseValidToReplace( + TTI, U, V->getType()->getPointerAddressSpace())) { // If V is used as the pointer operand of a compatible memory operation, // sets the pointer operand to NewV. This replacement does not change // the element type, so the resultant load/store is still valid. 
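The net effect of the isSimplePointerUseValidToReplace change is that a volatile access no longer blocks the rewrite unconditionally; it is allowed exactly when the target reports a volatile variant for the new address space. A boolean sketch of that decision (illustrative, not the patch's code):

    // Decision for rewriting the pointer operand of a simple memory operation.
    bool canRewritePointerOperand(bool IsPointerOperand, bool IsVolatile,
                                  bool TargetHasVolatileVariant) {
      if (!IsPointerOperand)
        return false;                              // only the pointer operand is rewritten
      return !IsVolatile || TargetHasVolatileVariant;
    }

Before the patch, the volatile case always returned false; now the query is delegated to TargetTransformInfo::hasVolatileVariant for the inferred address space.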
@@ -933,6 +971,11 @@ bool InferAddressSpaces::rewriteWithNewAddressSpaces( if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) { unsigned NewAS = NewV->getType()->getPointerAddressSpace(); if (ASC->getDestAddressSpace() == NewAS) { + if (ASC->getType()->getPointerElementType() != + NewV->getType()->getPointerElementType()) { + NewV = CastInst::Create(Instruction::BitCast, NewV, + ASC->getType(), "", ASC); + } ASC->replaceAllUsesWith(NewV); DeadInstructions.push_back(ASC); continue; diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp index dc9143bebc45..6b0377e0ecb3 100644 --- a/lib/Transforms/Scalar/JumpThreading.cpp +++ b/lib/Transforms/Scalar/JumpThreading.cpp @@ -14,25 +14,50 @@ #include "llvm/Transforms/Scalar/JumpThreading.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/BlockFrequencyInfoImpl.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/ConstantRange.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/BlockFrequency.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -41,8 +66,15 @@ #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/SSAUpdater.h" +#include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <iterator> #include <memory> +#include <utility> + using namespace llvm; using namespace jumpthreading; @@ -70,6 +102,7 @@ static cl::opt<bool> PrintLVIAfterJumpThreading( cl::Hidden); namespace { + /// This pass performs 'jump threading', which looks at blocks that have /// multiple predecessors and multiple successors. If one or more of the /// predecessors of the block can be proven to always jump to one of the @@ -85,12 +118,12 @@ namespace { /// /// In this case, the unconditional branch at the end of the first if can be /// revectored to the false side of the second if. 
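// An illustrative instance of the situation the comment above describes (a
// sketch, not code taken from this file):
//
//   if (a) {          // first if: its "then" block ends with an
//     X = 4;          // unconditional branch into the second test
//   }
//   if (X < 3) { ... }
//
// On the path through the first "then" block the value of X is known, so the
// unconditional branch at the end of that block can be rewired straight to
// the false successor of the second if, duplicating the test block if it has
// other predecessors.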
- /// class JumpThreading : public FunctionPass { JumpThreadingPass Impl; public: static char ID; // Pass identification + JumpThreading(int T = -1) : FunctionPass(ID), Impl(T) { initializeJumpThreadingPass(*PassRegistry::getPassRegistry()); } @@ -108,9 +141,11 @@ namespace { void releaseMemory() override { Impl.releaseMemory(); } }; -} + +} // end anonymous namespace char JumpThreading::ID = 0; + INITIALIZE_PASS_BEGIN(JumpThreading, "jump-threading", "Jump Threading", false, false) INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) @@ -120,14 +155,125 @@ INITIALIZE_PASS_END(JumpThreading, "jump-threading", "Jump Threading", false, false) // Public interface to the Jump Threading pass -FunctionPass *llvm::createJumpThreadingPass(int Threshold) { return new JumpThreading(Threshold); } +FunctionPass *llvm::createJumpThreadingPass(int Threshold) { + return new JumpThreading(Threshold); +} JumpThreadingPass::JumpThreadingPass(int T) { BBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T); } -/// runOnFunction - Top level algorithm. -/// +// Update branch probability information according to conditional +// branch probablity. This is usually made possible for cloned branches +// in inline instances by the context specific profile in the caller. +// For instance, +// +// [Block PredBB] +// [Branch PredBr] +// if (t) { +// Block A; +// } else { +// Block B; +// } +// +// [Block BB] +// cond = PN([true, %A], [..., %B]); // PHI node +// [Branch CondBr] +// if (cond) { +// ... // P(cond == true) = 1% +// } +// +// Here we know that when block A is taken, cond must be true, which means +// P(cond == true | A) = 1 +// +// Given that P(cond == true) = P(cond == true | A) * P(A) + +// P(cond == true | B) * P(B) +// we get: +// P(cond == true ) = P(A) + P(cond == true | B) * P(B) +// +// which gives us: +// P(A) is less than P(cond == true), i.e. +// P(t == true) <= P(cond == true) +// +// In other words, if we know P(cond == true) is unlikely, we know +// that P(t == true) is also unlikely. +// +static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { + BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator()); + if (!CondBr) + return; + + BranchProbability BP; + uint64_t TrueWeight, FalseWeight; + if (!CondBr->extractProfMetadata(TrueWeight, FalseWeight)) + return; + + // Returns the outgoing edge of the dominating predecessor block + // that leads to the PhiNode's incoming block: + auto GetPredOutEdge = + [](BasicBlock *IncomingBB, + BasicBlock *PhiBB) -> std::pair<BasicBlock *, BasicBlock *> { + auto *PredBB = IncomingBB; + auto *SuccBB = PhiBB; + while (true) { + BranchInst *PredBr = dyn_cast<BranchInst>(PredBB->getTerminator()); + if (PredBr && PredBr->isConditional()) + return {PredBB, SuccBB}; + auto *SinglePredBB = PredBB->getSinglePredecessor(); + if (!SinglePredBB) + return {nullptr, nullptr}; + SuccBB = PredBB; + PredBB = SinglePredBB; + } + }; + + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + Value *PhiOpnd = PN->getIncomingValue(i); + ConstantInt *CI = dyn_cast<ConstantInt>(PhiOpnd); + + if (!CI || !CI->getType()->isIntegerTy(1)) + continue; + + BP = (CI->isOne() ? 
BranchProbability::getBranchProbability( + TrueWeight, TrueWeight + FalseWeight) + : BranchProbability::getBranchProbability( + FalseWeight, TrueWeight + FalseWeight)); + + auto PredOutEdge = GetPredOutEdge(PN->getIncomingBlock(i), BB); + if (!PredOutEdge.first) + return; + + BasicBlock *PredBB = PredOutEdge.first; + BranchInst *PredBr = cast<BranchInst>(PredBB->getTerminator()); + + uint64_t PredTrueWeight, PredFalseWeight; + // FIXME: We currently only set the profile data when it is missing. + // With PGO, this can be used to refine even existing profile data with + // context information. This needs to be done after more performance + // testing. + if (PredBr->extractProfMetadata(PredTrueWeight, PredFalseWeight)) + continue; + + // We can not infer anything useful when BP >= 50%, because BP is the + // upper bound probability value. + if (BP >= BranchProbability(50, 100)) + continue; + + SmallVector<uint32_t, 2> Weights; + if (PredBr->getSuccessor(0) == PredOutEdge.second) { + Weights.push_back(BP.getNumerator()); + Weights.push_back(BP.getCompl().getNumerator()); + } else { + Weights.push_back(BP.getCompl().getNumerator()); + Weights.push_back(BP.getNumerator()); + } + PredBr->setMetadata(LLVMContext::MD_prof, + MDBuilder(PredBr->getParent()->getContext()) + .createBranchWeights(Weights)); + } +} + +/// runOnFunction - Toplevel algorithm. bool JumpThreading::runOnFunction(Function &F) { if (skipFunction(F)) return false; @@ -155,7 +301,6 @@ bool JumpThreading::runOnFunction(Function &F) { PreservedAnalyses JumpThreadingPass::run(Function &F, FunctionAnalysisManager &AM) { - auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); auto &LVI = AM.getResult<LazyValueAnalysis>(F); auto &AA = AM.getResult<AAManager>(F); @@ -184,7 +329,6 @@ bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_, bool HasProfileData_, std::unique_ptr<BlockFrequencyInfo> BFI_, std::unique_ptr<BranchProbabilityInfo> BPI_) { - DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n"); TLI = TLI_; LVI = LVI_; @@ -384,7 +528,6 @@ static unsigned getJumpThreadDuplicationCost(BasicBlock *BB, /// within the loop (forming a nested loop). This simple analysis is not rich /// enough to track all of these properties and keep it up-to-date as the CFG /// mutates, so we don't allow any of these transformations. -/// void JumpThreadingPass::FindLoopHeaders(Function &F) { SmallVector<std::pair<const BasicBlock*,const BasicBlock*>, 32> Edges; FindFunctionBackedges(F, Edges); @@ -418,7 +561,6 @@ static Constant *getKnownConstant(Value *Val, ConstantPreference Preference) { /// BB in the result vector. /// /// This returns true if there were any known values. -/// bool JumpThreadingPass::ComputeValueKnownInPredecessors( Value *V, BasicBlock *BB, PredValueInfo &Result, ConstantPreference Preference, Instruction *CxtI) { @@ -507,8 +649,6 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( return true; } - PredValueInfoTy LHSVals, RHSVals; - // Handle some boolean conditions. 
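// A worked instance of the bound used in updatePredecessorProfileMetadata
// above (the numbers are illustrative): suppose the profile on CondBr gives
// TrueWeight = 1 and FalseWeight = 99, i.e. P(cond == true) = 1%, and the PHI
// feeds the constant true from predecessor block A. Since
//   P(cond == true) = P(A) + P(cond == true | B) * P(B) >= P(A),
// the edge that leads to A can be taken with probability at most 1%, so the
// dominating conditional branch toward A receives weights in a 1:99 ratio
// (and the update is skipped entirely once the bound reaches 50%).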
if (I->getType()->getPrimitiveSizeInBits() == 1) { assert(Preference == WantInteger && "One-bit non-integer type?"); @@ -516,6 +656,8 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( // X & false -> false if (I->getOpcode() == Instruction::Or || I->getOpcode() == Instruction::And) { + PredValueInfoTy LHSVals, RHSVals; + ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals, WantInteger, CxtI); ComputeValueKnownInPredecessors(I->getOperand(1), BB, RHSVals, @@ -655,6 +797,7 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( // x as a live-in. { using namespace PatternMatch; + Value *AddLHS; ConstantInt *AddConst; if (isa<ConstantInt>(CmpConst) && @@ -751,14 +894,11 @@ bool JumpThreadingPass::ComputeValueKnownInPredecessors( return !Result.empty(); } - - /// GetBestDestForBranchOnUndef - If we determine that the specified block ends /// in an undefined jump, decide which block is best to revector to. /// /// Since we can pick an arbitrary destination, we pick the successor with the /// fewest predecessors. This should reduce the in-degree of the others. -/// static unsigned GetBestDestForJumpOnUndef(BasicBlock *BB) { TerminatorInst *BBTerm = BB->getTerminator(); unsigned MinSucc = 0; @@ -979,7 +1119,6 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { // for loads that are used by a switch or by the condition for the branch. If // we see one, check to see if it's partially redundant. If so, insert a PHI // which can then be used to thread the values. - // Value *SimplifyValue = CondInst; if (CmpInst *CondCmp = dyn_cast<CmpInst>(SimplifyValue)) if (isa<Constant>(CondCmp->getOperand(1))) @@ -991,10 +1130,14 @@ bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) { if (SimplifyPartiallyRedundantLoad(LI)) return true; + // Before threading, try to propagate profile data backwards: + if (PHINode *PN = dyn_cast<PHINode>(CondInst)) + if (PN->getParent() == BB && isa<BranchInst>(BB->getTerminator())) + updatePredecessorProfileMetadata(PN, BB); + // Handle a variety of cases where we are branching on something derived from // a PHI node in the current block. If we can prove that any predecessors // compute a predictable value based on a PHI node, thread those predecessors. - // if (ProcessThreadableEdges(CondInst, BB, Preference, Terminator)) return true; @@ -1036,9 +1179,9 @@ bool JumpThreadingPass::ProcessImpliedCondition(BasicBlock *BB) { if (PBI->getSuccessor(0) != CurrentBB && PBI->getSuccessor(1) != CurrentBB) return false; - bool FalseDest = PBI->getSuccessor(1) == CurrentBB; + bool CondIsTrue = PBI->getSuccessor(0) == CurrentBB; Optional<bool> Implication = - isImpliedCondition(PBI->getCondition(), Cond, DL, FalseDest); + isImpliedCondition(PBI->getCondition(), Cond, DL, CondIsTrue); if (Implication) { BI->getSuccessor(*Implication ? 1 : 0)->removePredecessor(BB); BranchInst::Create(BI->getSuccessor(*Implication ? 0 : 1), BI); @@ -1124,7 +1267,9 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { LI->getAAMetadata(AATags); SmallPtrSet<BasicBlock*, 8> PredsScanned; - typedef SmallVector<std::pair<BasicBlock*, Value*>, 8> AvailablePredsTy; + + using AvailablePredsTy = SmallVector<std::pair<BasicBlock *, Value *>, 8>; + AvailablePredsTy AvailablePreds; BasicBlock *OneUnavailablePred = nullptr; SmallVector<LoadInst*, 8> CSELoads; @@ -1283,8 +1428,8 @@ bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) { /// the list. 
static BasicBlock * FindMostPopularDest(BasicBlock *BB, - const SmallVectorImpl<std::pair<BasicBlock*, - BasicBlock*> > &PredToDestList) { + const SmallVectorImpl<std::pair<BasicBlock *, + BasicBlock *>> &PredToDestList) { assert(!PredToDestList.empty()); // Determine popularity. If there are multiple possible destinations, we @@ -1502,7 +1647,6 @@ bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, /// ProcessBranchOnPHI - We have an otherwise unthreadable conditional branch on /// a PHI node in the current block. See if there are any simplifications we /// can do based on inputs to the phi node. -/// bool JumpThreadingPass::ProcessBranchOnPHI(PHINode *PN) { BasicBlock *BB = PN->getParent(); @@ -1532,7 +1676,6 @@ bool JumpThreadingPass::ProcessBranchOnPHI(PHINode *PN) { /// ProcessBranchOnXOR - We have an otherwise unthreadable conditional branch on /// a xor instruction in the current block. See if there are any /// simplifications we can do based on inputs to the xor. -/// bool JumpThreadingPass::ProcessBranchOnXOR(BinaryOperator *BO) { BasicBlock *BB = BO->getParent(); @@ -1637,7 +1780,6 @@ bool JumpThreadingPass::ProcessBranchOnXOR(BinaryOperator *BO) { return DuplicateCondBranchOnPHIIntoPred(BB, BlocksToFoldInto); } - /// AddPHINodeEntriesForMappedBlock - We're adding 'NewPred' as a new /// predecessor to the PHIBB block. If it has PHI nodes, add entries for /// NewPred using the entries from OldPred (suitably mapped). @@ -1677,10 +1819,15 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, // If threading this would thread across a loop header, don't thread the edge. // See the comments above FindLoopHeaders for justifications and caveats. - if (LoopHeaders.count(BB)) { - DEBUG(dbgs() << " Not threading across loop header BB '" << BB->getName() - << "' to dest BB '" << SuccBB->getName() - << "' - it might create an irreducible loop!\n"); + if (LoopHeaders.count(BB) || LoopHeaders.count(SuccBB)) { + DEBUG({ + bool BBIsHeader = LoopHeaders.count(BB); + bool SuccIsHeader = LoopHeaders.count(SuccBB); + dbgs() << " Not threading across " + << (BBIsHeader ? "loop header BB '" : "block BB '") << BB->getName() + << "' to dest " << (SuccIsHeader ? "loop header BB '" : "block BB '") + << SuccBB->getName() << "' - it might create an irreducible loop!\n"; + }); return false; } @@ -1795,7 +1942,6 @@ bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, DEBUG(dbgs() << "\n"); } - // Ok, NewBB is good to go. Update the terminator of PredBB to jump to // NewBB instead of BB. This eliminates predecessors from BB, which requires // us to simplify any PHI nodes in BB. @@ -2194,7 +2340,7 @@ bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) { /// %p = phi [0, %bb1], [1, %bb2], [0, %bb3], [1, %bb4], ... /// %c = cmp %p, 0 /// %s = select %c, trueval, falseval -// +/// /// And expand the select into a branch structure. This later enables /// jump-threading over bb in this pass. /// @@ -2280,6 +2426,7 @@ bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) { /// guard is then threaded to one of them. bool JumpThreadingPass::ProcessGuards(BasicBlock *BB) { using namespace PatternMatch; + // We only want to deal with two predecessors. BasicBlock *Pred1, *Pred2; auto PI = pred_begin(BB), PE = pred_end(BB); @@ -2331,8 +2478,7 @@ bool JumpThreadingPass::ThreadGuard(BasicBlock *BB, IntrinsicInst *Guard, TrueDestIsSafe = true; else { // False dest is safe if !BranchCond => GuardCond. 
- Impl = - isImpliedCondition(BranchCond, GuardCond, DL, /* InvertAPred */ true); + Impl = isImpliedCondition(BranchCond, GuardCond, DL, /* LHSIsTrue */ false); if (Impl && *Impl) FalseDestIsSafe = true; } diff --git a/lib/Transforms/Scalar/LICM.cpp b/lib/Transforms/Scalar/LICM.cpp index 37b9c4b1094e..4ea935793b80 100644 --- a/lib/Transforms/Scalar/LICM.cpp +++ b/lib/Transforms/Scalar/LICM.cpp @@ -42,7 +42,8 @@ #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/MemoryBuiltins.h" -#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" @@ -62,6 +63,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/SSAUpdater.h" @@ -88,15 +90,15 @@ static cl::opt<uint32_t> MaxNumUsesTraversed( "invariance in loop using invariant start (default = 8)")); static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI); -static bool isNotUsedInLoop(const Instruction &I, const Loop *CurLoop, - const LoopSafetyInfo *SafetyInfo); +static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, + const LoopSafetyInfo *SafetyInfo, + TargetTransformInfo *TTI, bool &FreeInLoop); static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE); -static bool sink(Instruction &I, const LoopInfo *LI, const DominatorTree *DT, - const Loop *CurLoop, AliasSetTracker *CurAST, - const LoopSafetyInfo *SafetyInfo, - OptimizationRemarkEmitter *ORE); +static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, + const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE, bool FreeInLoop); static bool isSafeToExecuteUnconditionally(Instruction &Inst, const DominatorTree *DT, const Loop *CurLoop, @@ -114,7 +116,8 @@ CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN, namespace { struct LoopInvariantCodeMotion { bool runOnLoop(Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, - TargetLibraryInfo *TLI, ScalarEvolution *SE, + TargetLibraryInfo *TLI, TargetTransformInfo *TTI, + ScalarEvolution *SE, MemorySSA *MSSA, OptimizationRemarkEmitter *ORE, bool DeleteAST); DenseMap<Loop *, AliasSetTracker *> &getLoopToAliasSetMap() { @@ -146,6 +149,9 @@ struct LegacyLICMPass : public LoopPass { } auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); + MemorySSA *MSSA = EnableMSSALoopDependency + ? (&getAnalysis<MemorySSAWrapperPass>().getMSSA()) + : nullptr; // For the old PM, we can't use OptimizationRemarkEmitter as an analysis // pass. Function analyses need to be preserved across loop transformations // but ORE cannot be preserved (see comment before the pass definition). @@ -155,7 +161,9 @@ struct LegacyLICMPass : public LoopPass { &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), &getAnalysis<DominatorTreeWrapperPass>().getDomTree(), &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), - SE ? &SE->getSE() : nullptr, &ORE, false); + &getAnalysis<TargetTransformInfoWrapperPass>().getTTI( + *L->getHeader()->getParent()), + SE ? 
&SE->getSE() : nullptr, MSSA, &ORE, false); } /// This transformation requires natural loop information & requires that @@ -164,6 +172,9 @@ struct LegacyLICMPass : public LoopPass { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.setPreservesCFG(); AU.addRequired<TargetLibraryInfoWrapperPass>(); + if (EnableMSSALoopDependency) + AU.addRequired<MemorySSAWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); getLoopAnalysisUsage(AU); } @@ -189,7 +200,7 @@ private: /// Simple Analysis hook. Delete loop L from alias set map. void deleteAnalysisLoop(Loop *L) override; }; -} +} // namespace PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, LPMUpdater &) { @@ -204,7 +215,8 @@ PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM, "cached at a higher level"); LoopInvariantCodeMotion LICM; - if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, &AR.TLI, &AR.SE, ORE, true)) + if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, &AR.TLI, &AR.TTI, &AR.SE, + AR.MSSA, ORE, true)) return PreservedAnalyses::all(); auto PA = getLoopPassPreservedAnalyses(); @@ -217,6 +229,8 @@ INITIALIZE_PASS_BEGIN(LegacyLICMPass, "licm", "Loop Invariant Code Motion", false, false) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) INITIALIZE_PASS_END(LegacyLICMPass, "licm", "Loop Invariant Code Motion", false, false) @@ -228,12 +242,10 @@ Pass *llvm::createLICMPass() { return new LegacyLICMPass(); } /// We should delete AST for inner loops in the new pass manager to avoid /// memory leak. /// -bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AliasAnalysis *AA, - LoopInfo *LI, DominatorTree *DT, - TargetLibraryInfo *TLI, - ScalarEvolution *SE, - OptimizationRemarkEmitter *ORE, - bool DeleteAST) { +bool LoopInvariantCodeMotion::runOnLoop( + Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT, + TargetLibraryInfo *TLI, TargetTransformInfo *TTI, ScalarEvolution *SE, + MemorySSA *MSSA, OptimizationRemarkEmitter *ORE, bool DeleteAST) { bool Changed = false; assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form."); @@ -258,7 +270,7 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AliasAnalysis *AA, // instructions, we perform another pass to hoist them out of the loop. // if (L->hasDedicatedExits()) - Changed |= sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, L, + Changed |= sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, TTI, L, CurAST, &SafetyInfo, ORE); if (Preheader) Changed |= hoistRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, L, @@ -292,10 +304,26 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AliasAnalysis *AA, bool Promoted = false; // Loop over all of the alias sets in the tracker object. - for (AliasSet &AS : *CurAST) - Promoted |= - promoteLoopAccessesToScalars(AS, ExitBlocks, InsertPts, PIC, LI, DT, - TLI, L, CurAST, &SafetyInfo, ORE); + for (AliasSet &AS : *CurAST) { + // We can promote this alias set if it has a store, if it is a "Must" + // alias set, if the pointer is loop invariant, and if we are not + // eliminating any volatile loads or stores. 
+ if (AS.isForwardingAliasSet() || !AS.isMod() || !AS.isMustAlias() || + AS.isVolatile() || !L->isLoopInvariant(AS.begin()->getValue())) + continue; + + assert( + !AS.empty() && + "Must alias set should have at least one pointer element in it!"); + + SmallSetVector<Value *, 8> PointerMustAliases; + for (const auto &ASI : AS) + PointerMustAliases.insert(ASI.getValue()); + + Promoted |= promoteLoopAccessesToScalars(PointerMustAliases, ExitBlocks, + InsertPts, PIC, LI, DT, TLI, L, + CurAST, &SafetyInfo, ORE); + } // Once we have promoted values across the loop body we have to // recursively reform LCSSA as any nested loop may now have values defined @@ -335,7 +363,8 @@ bool LoopInvariantCodeMotion::runOnLoop(Loop *L, AliasAnalysis *AA, /// definitions, allowing us to sink a loop body in one pass without iteration. /// bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, - DominatorTree *DT, TargetLibraryInfo *TLI, Loop *CurLoop, + DominatorTree *DT, TargetLibraryInfo *TLI, + TargetTransformInfo *TTI, Loop *CurLoop, AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE) { @@ -344,46 +373,50 @@ bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, CurLoop != nullptr && CurAST != nullptr && SafetyInfo != nullptr && "Unexpected input to sinkRegion"); - BasicBlock *BB = N->getBlock(); - // If this subregion is not in the top level loop at all, exit. - if (!CurLoop->contains(BB)) - return false; + // We want to visit children before parents. We will enque all the parents + // before their children in the worklist and process the worklist in reverse + // order. + SmallVector<DomTreeNode *, 16> Worklist = collectChildrenInLoop(N, CurLoop); - // We are processing blocks in reverse dfo, so process children first. bool Changed = false; - const std::vector<DomTreeNode *> &Children = N->getChildren(); - for (DomTreeNode *Child : Children) - Changed |= - sinkRegion(Child, AA, LI, DT, TLI, CurLoop, CurAST, SafetyInfo, ORE); - - // Only need to process the contents of this block if it is not part of a - // subloop (which would already have been processed). - if (inSubLoop(BB, CurLoop, LI)) - return Changed; + for (DomTreeNode *DTN : reverse(Worklist)) { + BasicBlock *BB = DTN->getBlock(); + // Only need to process the contents of this block if it is not part of a + // subloop (which would already have been processed). + if (inSubLoop(BB, CurLoop, LI)) + continue; - for (BasicBlock::iterator II = BB->end(); II != BB->begin();) { - Instruction &I = *--II; + for (BasicBlock::iterator II = BB->end(); II != BB->begin();) { + Instruction &I = *--II; - // If the instruction is dead, we would try to sink it because it isn't used - // in the loop, instead, just delete it. - if (isInstructionTriviallyDead(&I, TLI)) { - DEBUG(dbgs() << "LICM deleting dead inst: " << I << '\n'); - ++II; - CurAST->deleteValue(&I); - I.eraseFromParent(); - Changed = true; - continue; - } + // If the instruction is dead, we would try to sink it because it isn't + // used in the loop, instead, just delete it. + if (isInstructionTriviallyDead(&I, TLI)) { + DEBUG(dbgs() << "LICM deleting dead inst: " << I << '\n'); + ++II; + CurAST->deleteValue(&I); + I.eraseFromParent(); + Changed = true; + continue; + } - // Check to see if we can sink this instruction to the exit blocks - // of the loop. We can do this if the all users of the instruction are - // outside of the loop. In this case, it doesn't even matter if the - // operands of the instruction are loop invariant. 
- // - if (isNotUsedInLoop(I, CurLoop, SafetyInfo) && - canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, SafetyInfo, ORE)) { - ++II; - Changed |= sink(I, LI, DT, CurLoop, CurAST, SafetyInfo, ORE); + // Check to see if we can sink this instruction to the exit blocks + // of the loop. We can do this if the all users of the instruction are + // outside of the loop. In this case, it doesn't even matter if the + // operands of the instruction are loop invariant. + // + bool FreeInLoop = false; + if (isNotUsedOrFreeInLoop(I, CurLoop, SafetyInfo, TTI, FreeInLoop) && + canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, SafetyInfo, ORE)) { + if (sink(I, LI, DT, CurLoop, SafetyInfo, ORE, FreeInLoop)) { + if (!FreeInLoop) { + ++II; + CurAST->deleteValue(&I); + I.eraseFromParent(); + } + Changed = true; + } + } } } return Changed; @@ -403,73 +436,70 @@ bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI, CurLoop != nullptr && CurAST != nullptr && SafetyInfo != nullptr && "Unexpected input to hoistRegion"); - BasicBlock *BB = N->getBlock(); - - // If this subregion is not in the top level loop at all, exit. - if (!CurLoop->contains(BB)) - return false; + // We want to visit parents before children. We will enque all the parents + // before their children in the worklist and process the worklist in order. + SmallVector<DomTreeNode *, 16> Worklist = collectChildrenInLoop(N, CurLoop); - // Only need to process the contents of this block if it is not part of a - // subloop (which would already have been processed). bool Changed = false; - if (!inSubLoop(BB, CurLoop, LI)) - for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E;) { - Instruction &I = *II++; - // Try constant folding this instruction. If all the operands are - // constants, it is technically hoistable, but it would be better to just - // fold it. - if (Constant *C = ConstantFoldInstruction( - &I, I.getModule()->getDataLayout(), TLI)) { - DEBUG(dbgs() << "LICM folding inst: " << I << " --> " << *C << '\n'); - CurAST->copyValue(&I, C); - I.replaceAllUsesWith(C); - if (isInstructionTriviallyDead(&I, TLI)) { - CurAST->deleteValue(&I); - I.eraseFromParent(); + for (DomTreeNode *DTN : Worklist) { + BasicBlock *BB = DTN->getBlock(); + // Only need to process the contents of this block if it is not part of a + // subloop (which would already have been processed). + if (!inSubLoop(BB, CurLoop, LI)) + for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E;) { + Instruction &I = *II++; + // Try constant folding this instruction. If all the operands are + // constants, it is technically hoistable, but it would be better to + // just fold it. + if (Constant *C = ConstantFoldInstruction( + &I, I.getModule()->getDataLayout(), TLI)) { + DEBUG(dbgs() << "LICM folding inst: " << I << " --> " << *C << '\n'); + CurAST->copyValue(&I, C); + I.replaceAllUsesWith(C); + if (isInstructionTriviallyDead(&I, TLI)) { + CurAST->deleteValue(&I); + I.eraseFromParent(); + } + Changed = true; + continue; } - Changed = true; - continue; - } - // Attempt to remove floating point division out of the loop by converting - // it to a reciprocal multiplication. 
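// A sketch of the effect of that rewrite (illustrative names, assuming the
// divisor d is loop invariant and reciprocal fast math is allowed):
//
//   before:               after:
//     x = a[i] / d;         r = 1.0 / d;   // reciprocal, hoisted out
//                           x = a[i] * r;  // multiply stays in the loop
//
// The reciprocal is created next to the division and then moved to the
// preheader by the usual hoist() path, while the multiply keeps the original
// fast-math flags.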
- if (I.getOpcode() == Instruction::FDiv && - CurLoop->isLoopInvariant(I.getOperand(1)) && - I.hasAllowReciprocal()) { - auto Divisor = I.getOperand(1); - auto One = llvm::ConstantFP::get(Divisor->getType(), 1.0); - auto ReciprocalDivisor = BinaryOperator::CreateFDiv(One, Divisor); - ReciprocalDivisor->setFastMathFlags(I.getFastMathFlags()); - ReciprocalDivisor->insertBefore(&I); - - auto Product = BinaryOperator::CreateFMul(I.getOperand(0), - ReciprocalDivisor); - Product->setFastMathFlags(I.getFastMathFlags()); - Product->insertAfter(&I); - I.replaceAllUsesWith(Product); - I.eraseFromParent(); + // Attempt to remove floating point division out of the loop by + // converting it to a reciprocal multiplication. + if (I.getOpcode() == Instruction::FDiv && + CurLoop->isLoopInvariant(I.getOperand(1)) && + I.hasAllowReciprocal()) { + auto Divisor = I.getOperand(1); + auto One = llvm::ConstantFP::get(Divisor->getType(), 1.0); + auto ReciprocalDivisor = BinaryOperator::CreateFDiv(One, Divisor); + ReciprocalDivisor->setFastMathFlags(I.getFastMathFlags()); + ReciprocalDivisor->insertBefore(&I); + + auto Product = + BinaryOperator::CreateFMul(I.getOperand(0), ReciprocalDivisor); + Product->setFastMathFlags(I.getFastMathFlags()); + Product->insertAfter(&I); + I.replaceAllUsesWith(Product); + I.eraseFromParent(); - hoist(*ReciprocalDivisor, DT, CurLoop, SafetyInfo, ORE); - Changed = true; - continue; - } + hoist(*ReciprocalDivisor, DT, CurLoop, SafetyInfo, ORE); + Changed = true; + continue; + } - // Try hoisting the instruction out to the preheader. We can only do this - // if all of the operands of the instruction are loop invariant and if it - // is safe to hoist the instruction. - // - if (CurLoop->hasLoopInvariantOperands(&I) && - canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, SafetyInfo, ORE) && - isSafeToExecuteUnconditionally( - I, DT, CurLoop, SafetyInfo, ORE, - CurLoop->getLoopPreheader()->getTerminator())) - Changed |= hoist(I, DT, CurLoop, SafetyInfo, ORE); - } + // Try hoisting the instruction out to the preheader. We can only do + // this if all of the operands of the instruction are loop invariant and + // if it is safe to hoist the instruction. + // + if (CurLoop->hasLoopInvariantOperands(&I) && + canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, SafetyInfo, ORE) && + isSafeToExecuteUnconditionally( + I, DT, CurLoop, SafetyInfo, ORE, + CurLoop->getLoopPreheader()->getTerminator())) + Changed |= hoist(I, DT, CurLoop, SafetyInfo, ORE); + } + } - const std::vector<DomTreeNode *> &Children = N->getChildren(); - for (DomTreeNode *Child : Children) - Changed |= - hoistRegion(Child, AA, LI, DT, TLI, CurLoop, CurAST, SafetyInfo, ORE); return Changed; } @@ -492,7 +522,8 @@ void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) { // Iterate over loop instructions and compute safety info. // Skip header as it has been computed and stored in HeaderMayThrow. // The first block in loopinfo.Blocks is guaranteed to be the header. - assert(Header == *CurLoop->getBlocks().begin() && "First block must be header"); + assert(Header == *CurLoop->getBlocks().begin() && + "First block must be header"); for (Loop::block_iterator BB = std::next(CurLoop->block_begin()), BBE = CurLoop->block_end(); (BB != BBE) && !SafetyInfo->MayThrow; ++BB) @@ -510,9 +541,9 @@ void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) { } // Return true if LI is invariant within scope of the loop. 
LI is invariant if -// CurLoop is dominated by an invariant.start representing the same memory location -// and size as the memory location LI loads from, and also the invariant.start -// has no uses. +// CurLoop is dominated by an invariant.start representing the same memory +// location and size as the memory location LI loads from, and also the +// invariant.start has no uses. static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT, Loop *CurLoop) { Value *Addr = LI->getOperand(0); @@ -566,10 +597,13 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, Loop *CurLoop, AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo, OptimizationRemarkEmitter *ORE) { + // SafetyInfo is nullptr if we are checking for sinking from preheader to + // loop body. + const bool SinkingToLoopBody = !SafetyInfo; // Loads have extra constraints we have to verify before we can hoist them. if (LoadInst *LI = dyn_cast<LoadInst>(&I)) { if (!LI->isUnordered()) - return false; // Don't hoist volatile/atomic loads! + return false; // Don't sink/hoist volatile or ordered atomic loads! // Loads from constant memory are always safe to move, even if they end up // in the same alias set as something that ends up being modified. @@ -578,6 +612,9 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, if (LI->getMetadata(LLVMContext::MD_invariant_load)) return true; + if (LI->isAtomic() && SinkingToLoopBody) + return false; // Don't sink unordered atomic loads to loop body. + // This checks for an invariant.start dominating the load. if (isLoadInvariantInLoop(LI, DT, CurLoop)) return true; @@ -595,10 +632,12 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, // Check loop-invariant address because this may also be a sinkable load // whose address is not necessarily loop-invariant. if (ORE && Invalidated && CurLoop->isLoopInvariant(LI->getPointerOperand())) - ORE->emit(OptimizationRemarkMissed( - DEBUG_TYPE, "LoadWithLoopInvariantAddressInvalidated", LI) - << "failed to move load with loop-invariant address " - "because the loop may invalidate its value"); + ORE->emit([&]() { + return OptimizationRemarkMissed( + DEBUG_TYPE, "LoadWithLoopInvariantAddressInvalidated", LI) + << "failed to move load with loop-invariant address " + "because the loop may invalidate its value"; + }); return !Invalidated; } else if (CallInst *CI = dyn_cast<CallInst>(&I)) { @@ -653,9 +692,9 @@ bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT, !isa<InsertValueInst>(I)) return false; - // SafetyInfo is nullptr if we are checking for sinking from preheader to - // loop body. It will be always safe as there is no speculative execution. - if (!SafetyInfo) + // If we are checking for sinking from preheader to loop body it will be + // always safe as there is no speculative execution. + if (SinkingToLoopBody) return true; // TODO: Plumb the context instruction through to make hoisting and sinking @@ -677,13 +716,40 @@ static bool isTriviallyReplacablePHI(const PHINode &PN, const Instruction &I) { return true; } +/// Return true if the instruction is free in the loop. 
+static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop, + const TargetTransformInfo *TTI) { + + if (const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(&I)) { + if (TTI->getUserCost(GEP) != TargetTransformInfo::TCC_Free) + return false; + // For a GEP, we cannot simply use getUserCost because currently it + // optimistically assume that a GEP will fold into addressing mode + // regardless of its users. + const BasicBlock *BB = GEP->getParent(); + for (const User *U : GEP->users()) { + const Instruction *UI = cast<Instruction>(U); + if (CurLoop->contains(UI) && + (BB != UI->getParent() || + (!isa<StoreInst>(UI) && !isa<LoadInst>(UI)))) + return false; + } + return true; + } else + return TTI->getUserCost(&I) == TargetTransformInfo::TCC_Free; +} + /// Return true if the only users of this instruction are outside of /// the loop. If this is true, we can sink the instruction to the exit /// blocks of the loop. /// -static bool isNotUsedInLoop(const Instruction &I, const Loop *CurLoop, - const LoopSafetyInfo *SafetyInfo) { +/// We also return true if the instruction could be folded away in lowering. +/// (e.g., a GEP can be folded into a load as an addressing mode in the loop). +static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop, + const LoopSafetyInfo *SafetyInfo, + TargetTransformInfo *TTI, bool &FreeInLoop) { const auto &BlockColors = SafetyInfo->BlockColors; + bool IsFree = isFreeInLoop(I, CurLoop, TTI); for (const User *U : I.users()) { const Instruction *UI = cast<Instruction>(U); if (const PHINode *PN = dyn_cast<PHINode>(UI)) { @@ -698,30 +764,15 @@ static bool isNotUsedInLoop(const Instruction &I, const Loop *CurLoop, if (!BlockColors.empty() && BlockColors.find(const_cast<BasicBlock *>(BB))->second.size() != 1) return false; - - // A PHI node where all of the incoming values are this instruction are - // special -- they can just be RAUW'ed with the instruction and thus - // don't require a use in the predecessor. This is a particular important - // special case because it is the pattern found in LCSSA form. - if (isTriviallyReplacablePHI(*PN, I)) { - if (CurLoop->contains(PN)) - return false; - else - continue; - } - - // Otherwise, PHI node uses occur in predecessor blocks if the incoming - // values. Check for such a use being inside the loop. 
- for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (PN->getIncomingValue(i) == &I) - if (CurLoop->contains(PN->getIncomingBlock(i))) - return false; - - continue; } - if (CurLoop->contains(UI)) + if (CurLoop->contains(UI)) { + if (IsFree) { + FreeInLoop = true; + continue; + } return false; + } } return true; } @@ -787,77 +838,189 @@ CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN, return New; } +static Instruction *sinkThroughTriviallyReplacablePHI( + PHINode *TPN, Instruction *I, LoopInfo *LI, + SmallDenseMap<BasicBlock *, Instruction *, 32> &SunkCopies, + const LoopSafetyInfo *SafetyInfo, const Loop *CurLoop) { + assert(isTriviallyReplacablePHI(*TPN, *I) && + "Expect only trivially replacalbe PHI"); + BasicBlock *ExitBlock = TPN->getParent(); + Instruction *New; + auto It = SunkCopies.find(ExitBlock); + if (It != SunkCopies.end()) + New = It->second; + else + New = SunkCopies[ExitBlock] = + CloneInstructionInExitBlock(*I, *ExitBlock, *TPN, LI, SafetyInfo); + return New; +} + +static bool canSplitPredecessors(PHINode *PN) { + BasicBlock *BB = PN->getParent(); + if (!BB->canSplitPredecessors()) + return false; + for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { + BasicBlock *BBPred = *PI; + if (isa<IndirectBrInst>(BBPred->getTerminator())) + return false; + } + return true; +} + +static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT, + LoopInfo *LI, const Loop *CurLoop) { +#ifndef NDEBUG + SmallVector<BasicBlock *, 32> ExitBlocks; + CurLoop->getUniqueExitBlocks(ExitBlocks); + SmallPtrSet<BasicBlock *, 32> ExitBlockSet(ExitBlocks.begin(), + ExitBlocks.end()); +#endif + BasicBlock *ExitBB = PN->getParent(); + assert(ExitBlockSet.count(ExitBB) && "Expect the PHI is in an exit block."); + + // Split predecessors of the loop exit to make instructions in the loop are + // exposed to exit blocks through trivially replacable PHIs while keeping the + // loop in the canonical form where each predecessor of each exit block should + // be contained within the loop. For example, this will convert the loop below + // from + // + // LB1: + // %v1 = + // br %LE, %LB2 + // LB2: + // %v2 = + // br %LE, %LB1 + // LE: + // %p = phi [%v1, %LB1], [%v2, %LB2] <-- non-trivially replacable + // + // to + // + // LB1: + // %v1 = + // br %LE.split, %LB2 + // LB2: + // %v2 = + // br %LE.split2, %LB1 + // LE.split: + // %p1 = phi [%v1, %LB1] <-- trivially replacable + // br %LE + // LE.split2: + // %p2 = phi [%v2, %LB2] <-- trivially replacable + // br %LE + // LE: + // %p = phi [%p1, %LE.split], [%p2, %LE.split2] + // + SmallSetVector<BasicBlock *, 8> PredBBs(pred_begin(ExitBB), pred_end(ExitBB)); + while (!PredBBs.empty()) { + BasicBlock *PredBB = *PredBBs.begin(); + assert(CurLoop->contains(PredBB) && + "Expect all predecessors are in the loop"); + if (PN->getBasicBlockIndex(PredBB) >= 0) + SplitBlockPredecessors(ExitBB, PredBB, ".split.loop.exit", DT, LI, true); + PredBBs.remove(PredBB); + } +} + /// When an instruction is found to only be used outside of the loop, this /// function moves it to the exit blocks and patches up SSA form as needed. /// This method is guaranteed to remove the original instruction from its /// position, and may either delete it or move it to outside of the loop. 
/// -static bool sink(Instruction &I, const LoopInfo *LI, const DominatorTree *DT, - const Loop *CurLoop, AliasSetTracker *CurAST, - const LoopSafetyInfo *SafetyInfo, - OptimizationRemarkEmitter *ORE) { +static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT, + const Loop *CurLoop, const LoopSafetyInfo *SafetyInfo, + OptimizationRemarkEmitter *ORE, bool FreeInLoop) { DEBUG(dbgs() << "LICM sinking instruction: " << I << "\n"); - ORE->emit(OptimizationRemark(DEBUG_TYPE, "InstSunk", &I) - << "sinking " << ore::NV("Inst", &I)); + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "InstSunk", &I) + << "sinking " << ore::NV("Inst", &I); + }); bool Changed = false; if (isa<LoadInst>(I)) ++NumMovedLoads; else if (isa<CallInst>(I)) ++NumMovedCalls; ++NumSunk; - Changed = true; -#ifndef NDEBUG - SmallVector<BasicBlock *, 32> ExitBlocks; - CurLoop->getUniqueExitBlocks(ExitBlocks); - SmallPtrSet<BasicBlock *, 32> ExitBlockSet(ExitBlocks.begin(), - ExitBlocks.end()); -#endif + // Iterate over users to be ready for actual sinking. Replace users via + // unrechable blocks with undef and make all user PHIs trivially replcable. + SmallPtrSet<Instruction *, 8> VisitedUsers; + for (Value::user_iterator UI = I.user_begin(), UE = I.user_end(); UI != UE;) { + auto *User = cast<Instruction>(*UI); + Use &U = UI.getUse(); + ++UI; - // Clones of this instruction. Don't create more than one per exit block! - SmallDenseMap<BasicBlock *, Instruction *, 32> SunkCopies; + if (VisitedUsers.count(User) || CurLoop->contains(User)) + continue; - // If this instruction is only used outside of the loop, then all users are - // PHI nodes in exit blocks due to LCSSA form. Just RAUW them with clones of - // the instruction. - while (!I.use_empty()) { - Value::user_iterator UI = I.user_begin(); - auto *User = cast<Instruction>(*UI); if (!DT->isReachableFromEntry(User->getParent())) { - User->replaceUsesOfWith(&I, UndefValue::get(I.getType())); + U = UndefValue::get(I.getType()); + Changed = true; continue; } + // The user must be a PHI node. PHINode *PN = cast<PHINode>(User); // Surprisingly, instructions can be used outside of loops without any // exits. This can only happen in PHI nodes if the incoming block is // unreachable. - Use &U = UI.getUse(); BasicBlock *BB = PN->getIncomingBlock(U); if (!DT->isReachableFromEntry(BB)) { U = UndefValue::get(I.getType()); + Changed = true; continue; } - BasicBlock *ExitBlock = PN->getParent(); - assert(ExitBlockSet.count(ExitBlock) && - "The LCSSA PHI is not in an exit block!"); + VisitedUsers.insert(PN); + if (isTriviallyReplacablePHI(*PN, I)) + continue; - Instruction *New; - auto It = SunkCopies.find(ExitBlock); - if (It != SunkCopies.end()) - New = It->second; - else - New = SunkCopies[ExitBlock] = - CloneInstructionInExitBlock(I, *ExitBlock, *PN, LI, SafetyInfo); + if (!canSplitPredecessors(PN)) + return Changed; + + // Split predecessors of the PHI so that we can make users trivially + // replacable. + splitPredecessorsOfLoopExit(PN, DT, LI, CurLoop); + + // Should rebuild the iterators, as they may be invalidated by + // splitPredecessorsOfLoopExit(). + UI = I.user_begin(); + UE = I.user_end(); + } + if (VisitedUsers.empty()) + return Changed; + +#ifndef NDEBUG + SmallVector<BasicBlock *, 32> ExitBlocks; + CurLoop->getUniqueExitBlocks(ExitBlocks); + SmallPtrSet<BasicBlock *, 32> ExitBlockSet(ExitBlocks.begin(), + ExitBlocks.end()); +#endif + + // Clones of this instruction. Don't create more than one per exit block! 
+ SmallDenseMap<BasicBlock *, Instruction *, 32> SunkCopies; + + // If this instruction is only used outside of the loop, then all users are + // PHI nodes in exit blocks due to LCSSA form. Just RAUW them with clones of + // the instruction. + SmallSetVector<User*, 8> Users(I.user_begin(), I.user_end()); + for (auto *UI : Users) { + auto *User = cast<Instruction>(UI); + + if (CurLoop->contains(User)) + continue; + + PHINode *PN = cast<PHINode>(User); + assert(ExitBlockSet.count(PN->getParent()) && + "The LCSSA PHI is not in an exit block!"); + // The PHI must be trivially replacable. + Instruction *New = sinkThroughTriviallyReplacablePHI(PN, &I, LI, SunkCopies, + SafetyInfo, CurLoop); PN->replaceAllUsesWith(New); PN->eraseFromParent(); + Changed = true; } - - CurAST->deleteValue(&I); - I.eraseFromParent(); return Changed; } @@ -870,8 +1033,10 @@ static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop, auto *Preheader = CurLoop->getLoopPreheader(); DEBUG(dbgs() << "LICM hoisting to " << Preheader->getName() << ": " << I << "\n"); - ORE->emit(OptimizationRemark(DEBUG_TYPE, "Hoisted", &I) - << "hoisting " << ore::NV("Inst", &I)); + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "Hoisted", &I) << "hoisting " + << ore::NV("Inst", &I); + }); // Metadata can be dependent on conditions we are hoisting above. // Conservatively strip all metadata on the instruction unless we were @@ -921,10 +1086,12 @@ static bool isSafeToExecuteUnconditionally(Instruction &Inst, if (!GuaranteedToExecute) { auto *LI = dyn_cast<LoadInst>(&Inst); if (LI && CurLoop->isLoopInvariant(LI->getPointerOperand())) - ORE->emit(OptimizationRemarkMissed( - DEBUG_TYPE, "LoadWithLoopInvariantAddressCondExecuted", LI) - << "failed to hoist load with loop-invariant address " - "because load is conditionally executed"); + ORE->emit([&]() { + return OptimizationRemarkMissed( + DEBUG_TYPE, "LoadWithLoopInvariantAddressCondExecuted", LI) + << "failed to hoist load with loop-invariant address " + "because load is conditionally executed"; + }); } return GuaranteedToExecute; @@ -933,7 +1100,7 @@ static bool isSafeToExecuteUnconditionally(Instruction &Inst, namespace { class LoopPromoter : public LoadAndStorePromoter { Value *SomePtr; // Designated pointer to store to. 
- SmallPtrSetImpl<Value *> &PointerMustAliases; + const SmallSetVector<Value *, 8> &PointerMustAliases; SmallVectorImpl<BasicBlock *> &LoopExitBlocks; SmallVectorImpl<Instruction *> &LoopInsertPts; PredIteratorCache &PredCache; @@ -961,7 +1128,7 @@ class LoopPromoter : public LoadAndStorePromoter { public: LoopPromoter(Value *SP, ArrayRef<const Instruction *> Insts, SSAUpdater &S, - SmallPtrSetImpl<Value *> &PMA, + const SmallSetVector<Value *, 8> &PMA, SmallVectorImpl<BasicBlock *> &LEB, SmallVectorImpl<Instruction *> &LIP, PredIteratorCache &PIC, AliasSetTracker &ast, LoopInfo &li, DebugLoc dl, int alignment, @@ -969,7 +1136,7 @@ public: : LoadAndStorePromoter(Insts, S), SomePtr(SP), PointerMustAliases(PMA), LoopExitBlocks(LEB), LoopInsertPts(LIP), PredCache(PIC), AST(ast), LI(li), DL(std::move(dl)), Alignment(alignment), - UnorderedAtomic(UnorderedAtomic),AATags(AATags) {} + UnorderedAtomic(UnorderedAtomic), AATags(AATags) {} bool isInstInList(Instruction *I, const SmallVectorImpl<Instruction *> &) const override { @@ -1008,7 +1175,31 @@ public: } void instructionDeleted(Instruction *I) const override { AST.deleteValue(I); } }; -} // end anon namespace + + +/// Return true iff we can prove that a caller of this function can not inspect +/// the contents of the provided object in a well defined program. +bool isKnownNonEscaping(Value *Object, const TargetLibraryInfo *TLI) { + if (isa<AllocaInst>(Object)) + // Since the alloca goes out of scope, we know the caller can't retain a + // reference to it and be well defined. Thus, we don't need to check for + // capture. + return true; + + // For all other objects we need to know that the caller can't possibly + // have gotten a reference to the object. There are two components of + // that: + // 1) Object can't be escaped by this function. This is what + // PointerMayBeCaptured checks. + // 2) Object can't have been captured at definition site. For this, we + // need to know the return value is noalias. At the moment, we use a + // weaker condition and handle only AllocLikeFunctions (which are + // known to be noalias). TODO + return isAllocLikeFn(Object, TLI) && + !PointerMayBeCaptured(Object, true, true); +} + +} // namespace /// Try to promote memory values to scalars by sinking stores out of the /// loop and moving loads to before the loop. We do this by looping over @@ -1016,7 +1207,8 @@ public: /// loop invariant. /// bool llvm::promoteLoopAccessesToScalars( - AliasSet &AS, SmallVectorImpl<BasicBlock *> &ExitBlocks, + const SmallSetVector<Value *, 8> &PointerMustAliases, + SmallVectorImpl<BasicBlock *> &ExitBlocks, SmallVectorImpl<Instruction *> &InsertPts, PredIteratorCache &PIC, LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI, Loop *CurLoop, AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo, @@ -1026,17 +1218,7 @@ bool llvm::promoteLoopAccessesToScalars( CurAST != nullptr && SafetyInfo != nullptr && "Unexpected Input to promoteLoopAccessesToScalars"); - // We can promote this alias set if it has a store, if it is a "Must" alias - // set, if the pointer is loop invariant, and if we are not eliminating any - // volatile loads or stores. 
- if (AS.isForwardingAliasSet() || !AS.isMod() || !AS.isMustAlias() || - AS.isVolatile() || !CurLoop->isLoopInvariant(AS.begin()->getValue())) - return false; - - assert(!AS.empty() && - "Must alias set should have at least one pointer element in it!"); - - Value *SomePtr = AS.begin()->getValue(); + Value *SomePtr = *PointerMustAliases.begin(); BasicBlock *Preheader = CurLoop->getLoopPreheader(); // It isn't safe to promote a load/store from the loop if the load/store is @@ -1065,8 +1247,8 @@ bool llvm::promoteLoopAccessesToScalars( // is safe (i.e. proving dereferenceability on all paths through the loop). We // can use any access within the alias set to prove dereferenceability, // since they're all must alias. - // - // There are two ways establish (p2): + // + // There are two ways establish (p2): // a) Prove the location is thread-local. In this case the memory model // requirement does not apply, and stores are safe to insert. // b) Prove a store dominates every exit block. In this case, if an exit @@ -1080,55 +1262,36 @@ bool llvm::promoteLoopAccessesToScalars( bool SafeToInsertStore = false; SmallVector<Instruction *, 64> LoopUses; - SmallPtrSet<Value *, 4> PointerMustAliases; // We start with an alignment of one and try to find instructions that allow // us to prove better alignment. unsigned Alignment = 1; // Keep track of which types of access we see - bool SawUnorderedAtomic = false; + bool SawUnorderedAtomic = false; bool SawNotAtomic = false; AAMDNodes AATags; const DataLayout &MDL = Preheader->getModule()->getDataLayout(); - // Do we know this object does not escape ? - bool IsKnownNonEscapingObject = false; + bool IsKnownThreadLocalObject = false; if (SafetyInfo->MayThrow) { // If a loop can throw, we have to insert a store along each unwind edge. // That said, we can't actually make the unwind edge explicit. Therefore, - // we have to prove that the store is dead along the unwind edge. - // - // If the underlying object is not an alloca, nor a pointer that does not - // escape, then we can not effectively prove that the store is dead along - // the unwind edge. i.e. the caller of this function could have ways to - // access the pointed object. + // we have to prove that the store is dead along the unwind edge. We do + // this by proving that the caller can't have a reference to the object + // after return and thus can't possibly load from the object. Value *Object = GetUnderlyingObject(SomePtr, MDL); - // If this is a base pointer we do not understand, simply bail. - // We only handle alloca and return value from alloc-like fn right now. - if (!isa<AllocaInst>(Object)) { - if (!isAllocLikeFn(Object, TLI)) - return false; - // If this is an alloc like fn. There are more constraints we need to verify. - // More specifically, we must make sure that the pointer can not escape. - // - // NOTE: PointerMayBeCaptured is not enough as the pointer may have escaped - // even though its not captured by the enclosing function. Standard allocation - // functions like malloc, calloc, and operator new return values which can - // be assumed not to have previously escaped. - if (PointerMayBeCaptured(Object, true, true)) - return false; - IsKnownNonEscapingObject = true; - } + if (!isKnownNonEscaping(Object, TLI)) + return false; + // Subtlety: Alloca's aren't visible to callers, but *are* potentially + // visible to other threads if captured and used during their lifetimes. 
+ IsKnownThreadLocalObject = !isa<AllocaInst>(Object); } // Check that all of the pointers in the alias set have the same type. We // cannot (yet) promote a memory location that is loaded and stored in // different sizes. While we are at it, collect alignment and AA info. - for (const auto &ASI : AS) { - Value *ASIV = ASI.getValue(); - PointerMustAliases.insert(ASIV); - + for (Value *ASIV : PointerMustAliases) { // Check that all of the pointers in the alias set have the same type. We // cannot (yet) promote a memory location that is loaded and stored in // different sizes. @@ -1147,7 +1310,7 @@ bool llvm::promoteLoopAccessesToScalars( assert(!Load->isVolatile() && "AST broken"); if (!Load->isUnordered()) return false; - + SawUnorderedAtomic |= Load->isAtomic(); SawNotAtomic |= !Load->isAtomic(); @@ -1234,14 +1397,13 @@ bool llvm::promoteLoopAccessesToScalars( // stores along paths which originally didn't have them without violating the // memory model. if (!SafeToInsertStore) { - // If this is a known non-escaping object, it is safe to insert the stores. - if (IsKnownNonEscapingObject) + if (IsKnownThreadLocalObject) SafeToInsertStore = true; else { Value *Object = GetUnderlyingObject(SomePtr, MDL); SafeToInsertStore = - (isAllocLikeFn(Object, TLI) || isa<AllocaInst>(Object)) && - !PointerMayBeCaptured(Object, true, true); + (isAllocLikeFn(Object, TLI) || isa<AllocaInst>(Object)) && + !PointerMayBeCaptured(Object, true, true); } } @@ -1252,9 +1414,11 @@ bool llvm::promoteLoopAccessesToScalars( // Otherwise, this is safe to promote, lets do it! DEBUG(dbgs() << "LICM: Promoting value stored to in loop: " << *SomePtr << '\n'); - ORE->emit( - OptimizationRemark(DEBUG_TYPE, "PromoteLoopAccessesToScalar", LoopUses[0]) - << "Moving accesses to memory location out of the loop"); + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "PromoteLoopAccessesToScalar", + LoopUses[0]) + << "Moving accesses to memory location out of the loop"; + }); ++NumPromoted; // Grab a debug location for the inserted loads/stores; given that the @@ -1333,7 +1497,7 @@ LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI, auto mergeLoop = [&](Loop *L) { // Loop over the body of this loop, looking for calls, invokes, and stores. for (BasicBlock *BB : L->blocks()) - CurAST->add(*BB); // Incorporate the specified basic block + CurAST->add(*BB); // Incorporate the specified basic block }; // Add everything from the sub loops that are no longer directly available. 
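The promotion path reshaped above is easiest to see on a small example. A
minimal before/after sketch (illustrative, not taken from the patch) of what
promoteLoopAccessesToScalars does when the pointer is loop invariant, all
accesses in the alias set must-alias, none are volatile, and a store is
guaranteed on every exit path (or the object is provably thread local):

    // Before promotion: *p is loaded and stored on every iteration.
    for (int i = 0; i < n; ++i)
      *p += a[i];

    // After promotion: the location lives in a scalar inside the loop and is
    // written back once in the loop exit block.
    int t = *p;
    for (int i = 0; i < n; ++i)
      t += a[i];
    *p = t;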
diff --git a/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/lib/Transforms/Scalar/LoopDataPrefetch.cpp index d09af32a99fd..7f7c6de76450 100644 --- a/lib/Transforms/Scalar/LoopDataPrefetch.cpp +++ b/lib/Transforms/Scalar/LoopDataPrefetch.cpp @@ -18,25 +18,20 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" -#include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/ValueMapper.h" using namespace llvm; @@ -120,9 +115,7 @@ public: AU.addPreserved<LoopInfoWrapperPass>(); AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); AU.addRequired<ScalarEvolutionWrapperPass>(); - // FIXME: For some reason, preserving SE here breaks LSR (even if - // this pass changes nothing). - // AU.addPreserved<ScalarEvolutionWrapperPass>(); + AU.addPreserved<ScalarEvolutionWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); } @@ -329,8 +322,10 @@ bool LoopDataPrefetch::runOnLoop(Loop *L) { ++NumPrefetches; DEBUG(dbgs() << " Access: " << *PtrValue << ", SCEV: " << *LSCEV << "\n"); - ORE->emit(OptimizationRemark(DEBUG_TYPE, "Prefetched", MemI) - << "prefetched memory access"); + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "Prefetched", MemI) + << "prefetched memory access"; + }); MadeChange = true; } diff --git a/lib/Transforms/Scalar/LoopDeletion.cpp b/lib/Transforms/Scalar/LoopDeletion.cpp index ac4dd44a0e90..82604a8842bf 100644 --- a/lib/Transforms/Scalar/LoopDeletion.cpp +++ b/lib/Transforms/Scalar/LoopDeletion.cpp @@ -30,20 +30,12 @@ using namespace llvm; STATISTIC(NumDeleted, "Number of loops deleted"); -/// This function deletes dead loops. The caller of this function needs to -/// guarantee that the loop is infact dead. Here we handle two kinds of dead -/// loop. The first kind (\p isLoopDead) is where only invariant values from -/// within the loop are used outside of it. The second kind (\p -/// isLoopNeverExecuted) is where the loop is provably never executed. We can -/// always remove never executed loops since they will not cause any difference -/// to program behaviour. -/// -/// This also updates the relevant analysis information in \p DT, \p SE, and \p -/// LI. It also updates the loop PM if an updater struct is provided. -// TODO: This function will be used by loop-simplifyCFG as well. So, move this -// to LoopUtils.cpp -static void deleteDeadLoop(Loop *L, DominatorTree &DT, ScalarEvolution &SE, - LoopInfo &LI, LPMUpdater *Updater = nullptr); +enum class LoopDeletionResult { + Unmodified, + Modified, + Deleted, +}; + /// Determines if a loop is dead. 
/// /// This assumes that we've already checked for unique exit and exiting blocks, @@ -144,8 +136,8 @@ static bool isLoopNeverExecuted(Loop *L) { /// \returns true if any changes were made. This may mutate the loop even if it /// is unable to delete it due to hoisting trivially loop invariant /// instructions out of the loop. -static bool deleteLoopIfDead(Loop *L, DominatorTree &DT, ScalarEvolution &SE, - LoopInfo &LI, LPMUpdater *Updater = nullptr) { +static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, + ScalarEvolution &SE, LoopInfo &LI) { assert(L->isLCSSAForm(DT) && "Expected LCSSA!"); // We can only remove the loop if there is a preheader that we can branch from @@ -155,13 +147,13 @@ static bool deleteLoopIfDead(Loop *L, DominatorTree &DT, ScalarEvolution &SE, if (!Preheader || !L->hasDedicatedExits()) { DEBUG(dbgs() << "Deletion requires Loop with preheader and dedicated exits.\n"); - return false; + return LoopDeletionResult::Unmodified; } // We can't remove loops that contain subloops. If the subloops were dead, // they would already have been removed in earlier executions of this pass. if (L->begin() != L->end()) { DEBUG(dbgs() << "Loop contains subloops.\n"); - return false; + return LoopDeletionResult::Unmodified; } @@ -176,9 +168,9 @@ static bool deleteLoopIfDead(Loop *L, DominatorTree &DT, ScalarEvolution &SE, P->setIncomingValue(i, UndefValue::get(P->getType())); BI++; } - deleteDeadLoop(L, DT, SE, LI, Updater); + deleteDeadLoop(L, &DT, &SE, &LI); ++NumDeleted; - return true; + return LoopDeletionResult::Deleted; } // The remaining checks below are for a loop being dead because all statements @@ -192,13 +184,14 @@ static bool deleteLoopIfDead(Loop *L, DominatorTree &DT, ScalarEvolution &SE, // a loop invariant manner. if (!ExitBlock) { DEBUG(dbgs() << "Deletion requires single exit block\n"); - return false; + return LoopDeletionResult::Unmodified; } // Finally, we have to check that the loop really is dead. bool Changed = false; if (!isLoopDead(L, SE, ExitingBlocks, ExitBlock, Changed, Preheader)) { DEBUG(dbgs() << "Loop is not invariant, cannot delete.\n"); - return Changed; + return Changed ? LoopDeletionResult::Modified + : LoopDeletionResult::Unmodified; } // Don't remove loops for which we can't solve the trip count. @@ -206,114 +199,15 @@ static bool deleteLoopIfDead(Loop *L, DominatorTree &DT, ScalarEvolution &SE, const SCEV *S = SE.getMaxBackedgeTakenCount(L); if (isa<SCEVCouldNotCompute>(S)) { DEBUG(dbgs() << "Could not compute SCEV MaxBackedgeTakenCount.\n"); - return Changed; + return Changed ? LoopDeletionResult::Modified + : LoopDeletionResult::Unmodified; } DEBUG(dbgs() << "Loop is invariant, delete it!"); - deleteDeadLoop(L, DT, SE, LI, Updater); + deleteDeadLoop(L, &DT, &SE, &LI); ++NumDeleted; - return true; -} - -static void deleteDeadLoop(Loop *L, DominatorTree &DT, ScalarEvolution &SE, - LoopInfo &LI, LPMUpdater *Updater) { - assert(L->isLCSSAForm(DT) && "Expected LCSSA!"); - auto *Preheader = L->getLoopPreheader(); - assert(Preheader && "Preheader should exist!"); - - // Now that we know the removal is safe, remove the loop by changing the - // branch from the preheader to go to the single exit block. - // - // Because we're deleting a large chunk of code at once, the sequence in which - // we remove things is very important to avoid invalidation issues. - - // If we have an LPM updater, tell it about the loop being removed. - if (Updater) - Updater->markLoopAsDeleted(*L); - - // Tell ScalarEvolution that the loop is deleted. 
Do this before - // deleting the loop so that ScalarEvolution can look at the loop - // to determine what it needs to clean up. - SE.forgetLoop(L); - - auto *ExitBlock = L->getUniqueExitBlock(); - assert(ExitBlock && "Should have a unique exit block!"); - - assert(L->hasDedicatedExits() && "Loop should have dedicated exits!"); - - // Connect the preheader directly to the exit block. - // Even when the loop is never executed, we cannot remove the edge from the - // source block to the exit block. Consider the case where the unexecuted loop - // branches back to an outer loop. If we deleted the loop and removed the edge - // coming to this inner loop, this will break the outer loop structure (by - // deleting the backedge of the outer loop). If the outer loop is indeed a - // non-loop, it will be deleted in a future iteration of loop deletion pass. - Preheader->getTerminator()->replaceUsesOfWith(L->getHeader(), ExitBlock); - - // Rewrite phis in the exit block to get their inputs from the Preheader - // instead of the exiting block. - BasicBlock::iterator BI = ExitBlock->begin(); - while (PHINode *P = dyn_cast<PHINode>(BI)) { - // Set the zero'th element of Phi to be from the preheader and remove all - // other incoming values. Given the loop has dedicated exits, all other - // incoming values must be from the exiting blocks. - int PredIndex = 0; - P->setIncomingBlock(PredIndex, Preheader); - // Removes all incoming values from all other exiting blocks (including - // duplicate values from an exiting block). - // Nuke all entries except the zero'th entry which is the preheader entry. - // NOTE! We need to remove Incoming Values in the reverse order as done - // below, to keep the indices valid for deletion (removeIncomingValues - // updates getNumIncomingValues and shifts all values down into the operand - // being deleted). - for (unsigned i = 0, e = P->getNumIncomingValues() - 1; i != e; ++i) - P->removeIncomingValue(e-i, false); - - assert((P->getNumIncomingValues() == 1 && - P->getIncomingBlock(PredIndex) == Preheader) && - "Should have exactly one value and that's from the preheader!"); - ++BI; - } - - // Update the dominator tree and remove the instructions and blocks that will - // be deleted from the reference counting scheme. - SmallVector<DomTreeNode*, 8> ChildNodes; - for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end(); - LI != LE; ++LI) { - // Move all of the block's children to be children of the Preheader, which - // allows us to remove the domtree entry for the block. - ChildNodes.insert(ChildNodes.begin(), DT[*LI]->begin(), DT[*LI]->end()); - for (DomTreeNode *ChildNode : ChildNodes) { - DT.changeImmediateDominator(ChildNode, DT[Preheader]); - } - - ChildNodes.clear(); - DT.eraseNode(*LI); - - // Remove the block from the reference counting scheme, so that we can - // delete it freely later. - (*LI)->dropAllReferences(); - } - - // Erase the instructions and the blocks without having to worry - // about ordering because we already dropped the references. - // NOTE: This iteration is safe because erasing the block does not remove its - // entry from the loop's block list. We do that in the next section. - for (Loop::block_iterator LI = L->block_begin(), LE = L->block_end(); - LI != LE; ++LI) - (*LI)->eraseFromParent(); - - // Finally, the blocks from loopinfo. This has to happen late because - // otherwise our loop iterators won't work. 
- - SmallPtrSet<BasicBlock *, 8> blocks; - blocks.insert(L->block_begin(), L->block_end()); - for (BasicBlock *BB : blocks) - LI.removeBlock(BB); - - // The last step is to update LoopInfo now that we've eliminated this loop. - LI.markAsRemoved(L); + return LoopDeletionResult::Deleted; } PreservedAnalyses LoopDeletionPass::run(Loop &L, LoopAnalysisManager &AM, @@ -322,9 +216,14 @@ PreservedAnalyses LoopDeletionPass::run(Loop &L, LoopAnalysisManager &AM, DEBUG(dbgs() << "Analyzing Loop for deletion: "); DEBUG(L.dump()); - if (!deleteLoopIfDead(&L, AR.DT, AR.SE, AR.LI, &Updater)) + std::string LoopName = L.getName(); + auto Result = deleteLoopIfDead(&L, AR.DT, AR.SE, AR.LI); + if (Result == LoopDeletionResult::Unmodified) return PreservedAnalyses::all(); + if (Result == LoopDeletionResult::Deleted) + Updater.markLoopAsDeleted(L, LoopName); + return getLoopPassPreservedAnalyses(); } @@ -354,7 +253,7 @@ INITIALIZE_PASS_END(LoopDeletionLegacyPass, "loop-deletion", Pass *llvm::createLoopDeletionPass() { return new LoopDeletionLegacyPass(); } -bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &) { +bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { if (skipLoop(L)) return false; DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); @@ -363,5 +262,11 @@ bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &) { DEBUG(dbgs() << "Analyzing Loop for deletion: "); DEBUG(L->dump()); - return deleteLoopIfDead(L, DT, SE, LI); + + LoopDeletionResult Result = deleteLoopIfDead(L, DT, SE, LI); + + if (Result == LoopDeletionResult::Deleted) + LPM.markLoopAsDeleted(*L); + + return Result != LoopDeletionResult::Unmodified; } diff --git a/lib/Transforms/Scalar/LoopDistribute.cpp b/lib/Transforms/Scalar/LoopDistribute.cpp index 3624bba10345..0d7e3db901cb 100644 --- a/lib/Transforms/Scalar/LoopDistribute.cpp +++ b/lib/Transforms/Scalar/LoopDistribute.cpp @@ -23,32 +23,61 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopDistribute.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DepthFirstIterator.h" #include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" -#include 
"llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include <cassert> +#include <functional> #include <list> +#include <tuple> +#include <utility> + +using namespace llvm; #define LDIST_NAME "loop-distribute" #define DEBUG_TYPE LDIST_NAME -using namespace llvm; - static cl::opt<bool> LDistVerify("loop-distribute-verify", cl::Hidden, cl::desc("Turn on DominatorTree and LoopInfo verification " @@ -81,14 +110,15 @@ static cl::opt<bool> EnableLoopDistribute( STATISTIC(NumLoopsDistributed, "Number of loops distributed"); namespace { + /// \brief Maintains the set of instructions of the loop for a partition before /// cloning. After cloning, it hosts the new loop. class InstPartition { - typedef SmallPtrSet<Instruction *, 8> InstructionSet; + using InstructionSet = SmallPtrSet<Instruction *, 8>; public: InstPartition(Instruction *I, Loop *L, bool DepCycle = false) - : DepCycle(DepCycle), OrigLoop(L), ClonedLoop(nullptr) { + : DepCycle(DepCycle), OrigLoop(L) { Set.insert(I); } @@ -220,7 +250,7 @@ private: /// \brief The cloned loop. If this partition is mapped to the original loop, /// this is null. - Loop *ClonedLoop; + Loop *ClonedLoop = nullptr; /// \brief The blocks of ClonedLoop including the preheader. If this /// partition is mapped to the original loop, this is empty. @@ -235,7 +265,7 @@ private: /// \brief Holds the set of Partitions. It populates them, merges them and then /// clones the loops. class InstPartitionContainer { - typedef DenseMap<Instruction *, int> InstToPartitionIdT; + using InstToPartitionIdT = DenseMap<Instruction *, int>; public: InstPartitionContainer(Loop *L, LoopInfo *LI, DominatorTree *DT) @@ -308,8 +338,8 @@ public: /// /// Return if any partitions were merged. bool mergeToAvoidDuplicatedLoads() { - typedef DenseMap<Instruction *, InstPartition *> LoadToPartitionT; - typedef EquivalenceClasses<InstPartition *> ToBeMergedT; + using LoadToPartitionT = DenseMap<Instruction *, InstPartition *>; + using ToBeMergedT = EquivalenceClasses<InstPartition *>; LoadToPartitionT LoadToPartition; ToBeMergedT ToBeMerged; @@ -511,7 +541,7 @@ public: } private: - typedef std::list<InstPartition> PartitionContainerT; + using PartitionContainerT = std::list<InstPartition>; /// \brief List of partitions. PartitionContainerT PartitionContainer; @@ -552,17 +582,17 @@ private: /// By traversing the memory instructions in program order and accumulating this /// number, we know whether any unsafe dependence crosses over a program point. 
class MemoryInstructionDependences { - typedef MemoryDepChecker::Dependence Dependence; + using Dependence = MemoryDepChecker::Dependence; public: struct Entry { Instruction *Inst; - unsigned NumUnsafeDependencesStartOrEnd; + unsigned NumUnsafeDependencesStartOrEnd = 0; - Entry(Instruction *Inst) : Inst(Inst), NumUnsafeDependencesStartOrEnd(0) {} + Entry(Instruction *Inst) : Inst(Inst) {} }; - typedef SmallVector<Entry, 8> AccessesType; + using AccessesType = SmallVector<Entry, 8>; AccessesType::const_iterator begin() const { return Accesses.begin(); } AccessesType::const_iterator end() const { return Accesses.end(); } @@ -594,7 +624,7 @@ class LoopDistributeForLoop { public: LoopDistributeForLoop(Loop *L, Function *F, LoopInfo *LI, DominatorTree *DT, ScalarEvolution *SE, OptimizationRemarkEmitter *ORE) - : L(L), F(F), LI(LI), LAI(nullptr), DT(DT), SE(SE), ORE(ORE) { + : L(L), F(F), LI(LI), DT(DT), SE(SE), ORE(ORE) { setForced(); } @@ -755,9 +785,11 @@ public: ++NumLoopsDistributed; // Report the success. - ORE->emit(OptimizationRemark(LDIST_NAME, "Distribute", L->getStartLoc(), - L->getHeader()) - << "distributed loop"); + ORE->emit([&]() { + return OptimizationRemark(LDIST_NAME, "Distribute", L->getStartLoc(), + L->getHeader()) + << "distributed loop"; + }); return true; } @@ -769,11 +801,13 @@ public: DEBUG(dbgs() << "Skipping; " << Message << "\n"); // With Rpass-missed report that distribution failed. - ORE->emit( - OptimizationRemarkMissed(LDIST_NAME, "NotDistributed", L->getStartLoc(), - L->getHeader()) - << "loop not distributed: use -Rpass-analysis=loop-distribute for more " - "info"); + ORE->emit([&]() { + return OptimizationRemarkMissed(LDIST_NAME, "NotDistributed", + L->getStartLoc(), L->getHeader()) + << "loop not distributed: use -Rpass-analysis=loop-distribute for " + "more " + "info"; + }); // With Rpass-analysis report why. This is on by default if distribution // was requested explicitly. @@ -857,7 +891,7 @@ private: // Analyses used. LoopInfo *LI; - const LoopAccessInfo *LAI; + const LoopAccessInfo *LAI = nullptr; DominatorTree *DT; ScalarEvolution *SE; OptimizationRemarkEmitter *ORE; @@ -871,6 +905,8 @@ private: Optional<bool> IsForced; }; +} // end anonymous namespace + /// Shared implementation between new and old PMs. static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT, ScalarEvolution *SE, OptimizationRemarkEmitter *ORE, @@ -901,9 +937,13 @@ static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT, return Changed; } +namespace { + /// \brief The pass class. class LoopDistributeLegacy : public FunctionPass { public: + static char ID; + LoopDistributeLegacy() : FunctionPass(ID) { // The default is set by the caller. 
initializeLoopDistributeLegacyPass(*PassRegistry::getPassRegistry()); @@ -934,10 +974,9 @@ public: AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } - - static char ID; }; -} // anonymous namespace + +} // end anonymous namespace PreservedAnalyses LoopDistributePass::run(Function &F, FunctionAnalysisManager &AM) { @@ -956,7 +995,7 @@ PreservedAnalyses LoopDistributePass::run(Function &F, auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); std::function<const LoopAccessInfo &(Loop &)> GetLAA = [&](Loop &L) -> const LoopAccessInfo & { - LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI}; + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI, nullptr}; return LAM.getResult<LoopAccessAnalysis>(L, AR); }; @@ -971,6 +1010,7 @@ PreservedAnalyses LoopDistributePass::run(Function &F, } char LoopDistributeLegacy::ID; + static const char ldist_name[] = "Loop Distribution"; INITIALIZE_PASS_BEGIN(LoopDistributeLegacy, LDIST_NAME, ldist_name, false, @@ -982,6 +1022,4 @@ INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(LoopDistributeLegacy, LDIST_NAME, ldist_name, false, false) -namespace llvm { -FunctionPass *createLoopDistributePass() { return new LoopDistributeLegacy(); } -} +FunctionPass *llvm::createLoopDistributePass() { return new LoopDistributeLegacy(); } diff --git a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp index 4a6a35c0ab1b..21551f0a0825 100644 --- a/lib/Transforms/Scalar/LoopIdiomRecognize.cpp +++ b/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -1,4 +1,4 @@ -//===-- LoopIdiomRecognize.cpp - Loop idiom recognition -------------------===// +//===- LoopIdiomRecognize.cpp - Loop idiom recognition --------------------===// // // The LLVM Compiler Infrastructure // @@ -38,32 +38,64 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopIdiomRecognize.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/BasicAliasAnalysis.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include 
"llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <utility> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "loop-idiom" @@ -80,7 +112,7 @@ static cl::opt<bool> UseLIRCodeSizeHeurs( namespace { class LoopIdiomRecognize { - Loop *CurLoop; + Loop *CurLoop = nullptr; AliasAnalysis *AA; DominatorTree *DT; LoopInfo *LI; @@ -96,20 +128,21 @@ public: TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, const DataLayout *DL) - : CurLoop(nullptr), AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), - DL(DL) {} + : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL) {} bool runOnLoop(Loop *L); private: - typedef SmallVector<StoreInst *, 8> StoreList; - typedef MapVector<Value *, StoreList> StoreListMap; + using StoreList = SmallVector<StoreInst *, 8>; + using StoreListMap = MapVector<Value *, StoreList>; + StoreListMap StoreRefsForMemset; StoreListMap StoreRefsForMemsetPattern; StoreList StoreRefsForMemcpy; bool HasMemset; bool HasMemsetPattern; bool HasMemcpy; + /// Return code for isLegalStore() enum LegalStoreKind { None = 0, @@ -164,6 +197,7 @@ private: class LoopIdiomRecognizeLegacyPass : public LoopPass { public: static char ID; + explicit LoopIdiomRecognizeLegacyPass() : LoopPass(ID) { initializeLoopIdiomRecognizeLegacyPassPass( *PassRegistry::getPassRegistry()); @@ -190,14 +224,16 @@ public: /// This transformation requires natural loop information & requires that /// loop preheaders be inserted into the CFG. - /// void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); getLoopAnalysisUsage(AU); } }; -} // End anonymous namespace. 
+ +} // end anonymous namespace + +char LoopIdiomRecognizeLegacyPass::ID = 0; PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, @@ -211,7 +247,6 @@ PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM, return getLoopPassPreservedAnalyses(); } -char LoopIdiomRecognizeLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(LoopIdiomRecognizeLegacyPass, "loop-idiom", "Recognize loop idioms", false, false) INITIALIZE_PASS_DEPENDENCY(LoopPass) @@ -299,13 +334,6 @@ bool LoopIdiomRecognize::runOnCountableLoop() { return MadeChange; } -static unsigned getStoreSizeInBytes(StoreInst *SI, const DataLayout *DL) { - uint64_t SizeInBits = DL->getTypeSizeInBits(SI->getValueOperand()->getType()); - assert(((SizeInBits & 7) || (SizeInBits >> 32) == 0) && - "Don't overflow unsigned."); - return (unsigned)SizeInBits >> 3; -} - static APInt getStoreStride(const SCEVAddRecExpr *StoreEv) { const SCEVConstant *ConstStride = cast<SCEVConstant>(StoreEv->getOperand(1)); return ConstStride->getAPInt(); @@ -354,7 +382,6 @@ static Constant *getMemSetPatternValue(Value *V, const DataLayout *DL) { LoopIdiomRecognize::LegalStoreKind LoopIdiomRecognize::isLegalStore(StoreInst *SI) { - // Don't touch volatile stores. if (SI->isVolatile()) return LegalStoreKind::None; @@ -424,7 +451,7 @@ LoopIdiomRecognize::isLegalStore(StoreInst *SI) { // Check to see if the stride matches the size of the store. If so, then we // know that every byte is touched in the loop. APInt Stride = getStoreStride(StoreEv); - unsigned StoreSize = getStoreSizeInBytes(SI, DL); + unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType()); if (StoreSize != Stride && StoreSize != -Stride) return LegalStoreKind::None; @@ -563,7 +590,7 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL, const SCEVAddRecExpr *FirstStoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(FirstStorePtr)); APInt FirstStride = getStoreStride(FirstStoreEv); - unsigned FirstStoreSize = getStoreSizeInBytes(SL[i], DL); + unsigned FirstStoreSize = DL->getTypeStoreSize(SL[i]->getValueOperand()->getType()); // See if we can optimize just this store in isolation. if (FirstStride == FirstStoreSize || -FirstStride == FirstStoreSize) { @@ -656,7 +683,7 @@ bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL, break; AdjacentStores.insert(I); - StoreSize += getStoreSizeInBytes(I, DL); + StoreSize += DL->getTypeStoreSize(I->getValueOperand()->getType()); // Move to the next value in the chain. I = ConsecutiveChain[I]; } @@ -761,7 +788,8 @@ mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, ++BI) for (Instruction &I : **BI) if (IgnoredStores.count(&I) == 0 && - (AA.getModRefInfo(&I, StoreLoc) & Access)) + isModOrRefSet( + intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access))) return true; return false; @@ -780,6 +808,41 @@ static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount, return SE->getMinusSCEV(Start, Index); } +/// Compute the number of bytes as a SCEV from the backedge taken count. +/// +/// This also maps the SCEV into the provided type and tries to handle the +/// computation in a way that will fold cleanly. +static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr, + unsigned StoreSize, Loop *CurLoop, + const DataLayout *DL, ScalarEvolution *SE) { + const SCEV *NumBytesS; + // The # stored bytes is (BECount+1)*Size. Expand the trip count out to + // pointer size if it isn't already. 
+ // + // If we're going to need to zero extend the BE count, check if we can add + // one to it prior to zero extending without overflow. Provided this is safe, + // it allows better simplification of the +1. + if (DL->getTypeSizeInBits(BECount->getType()) < + DL->getTypeSizeInBits(IntPtr) && + SE->isLoopEntryGuardedByCond( + CurLoop, ICmpInst::ICMP_NE, BECount, + SE->getNegativeSCEV(SE->getOne(BECount->getType())))) { + NumBytesS = SE->getZeroExtendExpr( + SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW), + IntPtr); + } else { + NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr), + SE->getOne(IntPtr), SCEV::FlagNUW); + } + + // And scale it based on the store size. + if (StoreSize != 1) { + NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize), + SCEV::FlagNUW); + } + return NumBytesS; +} + /// processLoopStridedStore - We see a strided store of some value. If we can /// transform this into a memset or memset_pattern in the loop preheader, do so. bool LoopIdiomRecognize::processLoopStridedStore( @@ -824,8 +887,8 @@ bool LoopIdiomRecognize::processLoopStridedStore( // base pointer and checking the region. Value *BasePtr = Expander.expandCodeFor(Start, DestInt8PtrTy, Preheader->getTerminator()); - if (mayLoopAccessLocation(BasePtr, MRI_ModRef, CurLoop, BECount, StoreSize, - *AA, Stores)) { + if (mayLoopAccessLocation(BasePtr, ModRefInfo::ModRef, CurLoop, BECount, + StoreSize, *AA, Stores)) { Expander.clear(); // If we generated new code for the base pointer, clean up. RecursivelyDeleteTriviallyDeadInstructions(BasePtr, TLI); @@ -837,16 +900,8 @@ bool LoopIdiomRecognize::processLoopStridedStore( // Okay, everything looks good, insert the memset. - // The # stored bytes is (BECount+1)*Size. Expand the trip count out to - // pointer size if it isn't already. - BECount = SE->getTruncateOrZeroExtend(BECount, IntPtr); - const SCEV *NumBytesS = - SE->getAddExpr(BECount, SE->getOne(IntPtr), SCEV::FlagNUW); - if (StoreSize != 1) { - NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize), - SCEV::FlagNUW); - } + getNumBytes(BECount, IntPtr, StoreSize, CurLoop, DL, SE); // TODO: ideally we should still be able to generate memset if SCEV expander // is taught to generate the dependencies at the latest point. @@ -903,7 +958,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, Value *StorePtr = SI->getPointerOperand(); const SCEVAddRecExpr *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); APInt Stride = getStoreStride(StoreEv); - unsigned StoreSize = getStoreSizeInBytes(SI, DL); + unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType()); bool NegStride = StoreSize == -Stride; // The store must be feeding a non-volatile load. @@ -942,7 +997,7 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, SmallPtrSet<Instruction *, 1> Stores; Stores.insert(SI); - if (mayLoopAccessLocation(StoreBasePtr, MRI_ModRef, CurLoop, BECount, + if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount, StoreSize, *AA, Stores)) { Expander.clear(); // If we generated new code for the base pointer, clean up. 
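The zero-extension handling in the getNumBytes() helper above is easy to get wrong, so here is a small standalone illustration (plain C++, not LLVM code) of the wraparound it guards against: folding the +1 into the narrow backedge-taken-count type is only valid when the loop-entry guard proves the count is not the all-ones value.

#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical backedge-taken count that uses the full 32-bit range.
  uint32_t BECount = 0xFFFFFFFFu;

  // Zero-extend first, then add one: the exact trip count, 2^32.
  uint64_t WidenThenAdd = (uint64_t)BECount + 1;

  // Add one in the narrow type first: wraps to 0 before the widening.
  // This is why the patch only uses the narrow add when the loop-entry
  // guard proves BECount != 0xFFFFFFFF.
  uint64_t AddThenWiden = (uint64_t)(uint32_t)(BECount + 1);

  printf("%llu vs %llu\n", (unsigned long long)WidenThenAdd,
         (unsigned long long)AddThenWiden); // 4294967296 vs 0
  return 0;
}

When the guard does hold, performing the add in the narrow type before the zero extension gives SCEV more room to fold the +1 away, which is the "better simplification" the comment refers to.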
@@ -962,8 +1017,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, Value *LoadBasePtr = Expander.expandCodeFor( LdStart, Builder.getInt8PtrTy(LdAS), Preheader->getTerminator()); - if (mayLoopAccessLocation(LoadBasePtr, MRI_Mod, CurLoop, BECount, StoreSize, - *AA, Stores)) { + if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount, + StoreSize, *AA, Stores)) { Expander.clear(); // If we generated new code for the base pointer, clean up. RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI); @@ -976,16 +1031,8 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, // Okay, everything is safe, we can transform this! - // The # stored bytes is (BECount+1)*Size. Expand the trip count out to - // pointer size if it isn't already. - BECount = SE->getTruncateOrZeroExtend(BECount, IntPtrTy); - const SCEV *NumBytesS = - SE->getAddExpr(BECount, SE->getOne(IntPtrTy), SCEV::FlagNUW); - - if (StoreSize != 1) - NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtrTy, StoreSize), - SCEV::FlagNUW); + getNumBytes(BECount, IntPtrTy, StoreSize, CurLoop, DL, SE); Value *NumBytes = Expander.expandCodeFor(NumBytesS, IntPtrTy, Preheader->getTerminator()); @@ -1010,16 +1057,12 @@ bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, if (StoreSize > TTI->getAtomicMemIntrinsicMaxElementSize()) return false; + // Create the call. + // Note that unordered atomic loads/stores are *required* by the spec to + // have an alignment but non-atomic loads/stores may not. NewCall = Builder.CreateElementUnorderedAtomicMemCpy( - StoreBasePtr, LoadBasePtr, NumBytes, StoreSize); - - // Propagate alignment info onto the pointer args. Note that unordered - // atomic loads/stores are *required* by the spec to have an alignment - // but non-atomic loads/stores may not. 
- NewCall->addParamAttr(0, Attribute::getWithAlignment(NewCall->getContext(), - SI->getAlignment())); - NewCall->addParamAttr(1, Attribute::getWithAlignment(NewCall->getContext(), - LI->getAlignment())); + StoreBasePtr, SI->getAlignment(), LoadBasePtr, LI->getAlignment(), + NumBytes, StoreSize); } NewCall->setDebugLoc(SI->getDebugLoc()); @@ -1273,9 +1316,9 @@ static bool detectCTLZIdiom(Loop *CurLoop, PHINode *&PhiX, // step 2: detect instructions corresponding to "x.next = x >> 1" if (!DefX || DefX->getOpcode() != Instruction::AShr) return false; - if (ConstantInt *Shft = dyn_cast<ConstantInt>(DefX->getOperand(1))) - if (!Shft || !Shft->isOne()) - return false; + ConstantInt *Shft = dyn_cast<ConstantInt>(DefX->getOperand(1)); + if (!Shft || !Shft->isOne()) + return false; VarX = DefX->getOperand(0); // step 3: Check the recurrence of variable X @@ -1469,7 +1512,7 @@ static CallInst *createCTLZIntrinsic(IRBuilder<> &IRBuilder, Value *Val, /// PhiX = PHI [InitX, DefX] /// CntInst = CntPhi + 1 /// DefX = PhiX >> 1 -// LOOP_BODY +/// LOOP_BODY /// Br: loop if (DefX != 0) /// Use(CntPhi) or Use(CntInst) /// diff --git a/lib/Transforms/Scalar/LoopInstSimplify.cpp b/lib/Transforms/Scalar/LoopInstSimplify.cpp index af095560cc02..40d468a084d4 100644 --- a/lib/Transforms/Scalar/LoopInstSimplify.cpp +++ b/lib/Transforms/Scalar/LoopInstSimplify.cpp @@ -12,22 +12,33 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopInstSimplify.h" +#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/Support/Debug.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/User.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include <algorithm> +#include <utility> + using namespace llvm; #define DEBUG_TYPE "loop-instsimplify" @@ -45,7 +56,7 @@ static bool SimplifyLoopInst(Loop *L, DominatorTree *DT, LoopInfo *LI, // The bit we are stealing from the pointer represents whether this basic // block is the header of a subloop, in which case we only process its phis. 
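The "bit we are stealing from the pointer" mentioned above is what llvm::PointerIntPair provides: it packs a small integer into the unused low (alignment) bits of a pointer, so the subloop-header flag costs no extra storage per worklist item. A minimal sketch (hypothetical helper name) of building and unpacking such an item:

#include "llvm/ADT/PointerIntPair.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/BasicBlock.h"

using namespace llvm;

// The low bit of the BasicBlock pointer records whether the block is the
// header of a subloop.
using WorklistItem = PointerIntPair<BasicBlock *, 1>;

// Hypothetical drain loop over such a worklist.
static void drain(SmallVectorImpl<WorklistItem> &VisitStack) {
  while (!VisitStack.empty()) {
    WorklistItem Item = VisitStack.pop_back_val();
    BasicBlock *BB = Item.getPointer();
    bool IsSubloopHeader = Item.getInt(); // if set, only the phis are visited
    (void)BB;
    (void)IsSubloopHeader;
  }
}

Items would be pushed as WorklistItem(BB, /*IsSubloopHeader=*/true) or with the flag clear, matching the VisitStack declared in the code that follows.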
- typedef PointerIntPair<BasicBlock *, 1> WorklistItem; + using WorklistItem = PointerIntPair<BasicBlock *, 1>; SmallVector<WorklistItem, 16> VisitStack; SmallPtrSet<BasicBlock *, 32> Visited; @@ -151,9 +162,11 @@ static bool SimplifyLoopInst(Loop *L, DominatorTree *DT, LoopInfo *LI, } namespace { + class LoopInstSimplifyLegacyPass : public LoopPass { public: static char ID; // Pass ID, replacement for typeid + LoopInstSimplifyLegacyPass() : LoopPass(ID) { initializeLoopInstSimplifyLegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -181,7 +194,8 @@ public: getLoopAnalysisUsage(AU); } }; -} + +} // end anonymous namespace PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, LoopStandardAnalysisResults &AR, @@ -195,6 +209,7 @@ PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, } char LoopInstSimplifyLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(LoopInstSimplifyLegacyPass, "loop-instsimplify", "Simplify instructions in loops", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) diff --git a/lib/Transforms/Scalar/LoopInterchange.cpp b/lib/Transforms/Scalar/LoopInterchange.cpp index 2e0d8e0374c0..4f8dafef230a 100644 --- a/lib/Transforms/Scalar/LoopInterchange.cpp +++ b/lib/Transforms/Scalar/LoopInterchange.cpp @@ -1,4 +1,4 @@ -//===- LoopInterchange.cpp - Loop interchange pass------------------------===// +//===- LoopInterchange.cpp - Loop interchange pass-------------------------===// // // The LLVM Compiler Infrastructure // @@ -13,33 +13,38 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AliasAnalysis.h" -#include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/BlockFrequencyInfo.h" -#include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/DependenceAnalysis.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopIterator.h" -#include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/Analysis/TargetTransformInfo.h" -#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/InstIterator.h" -#include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Module.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include <cassert> +#include <utility> +#include <vector> using namespace llvm; @@ -51,10 +56,12 @@ static cl::opt<int> LoopInterchangeCostThreshold( namespace { -typedef SmallVector<Loop *, 8> LoopVector; +using LoopVector = SmallVector<Loop *, 8>; // TODO: Check if we can use a sparse matrix here. 
-typedef std::vector<std::vector<char>> CharMatrix; +using CharMatrix = std::vector<std::vector<char>>; + +} // end anonymous namespace // Maximum number of dependencies that can be handled in the dependency matrix. static const unsigned MaxMemInstrCount = 100; @@ -62,14 +69,11 @@ static const unsigned MaxMemInstrCount = 100; // Maximum loop depth supported. static const unsigned MaxLoopNestDepth = 10; -struct LoopInterchange; - #ifdef DUMP_DEP_MATRICIES -void printDepMatrix(CharMatrix &DepMatrix) { - for (auto I = DepMatrix.begin(), E = DepMatrix.end(); I != E; ++I) { - std::vector<char> Vec = *I; - for (auto II = Vec.begin(), EE = Vec.end(); II != EE; ++II) - DEBUG(dbgs() << *II << " "); +static void printDepMatrix(CharMatrix &DepMatrix) { + for (auto &Row : DepMatrix) { + for (auto D : Row) + DEBUG(dbgs() << D << " "); DEBUG(dbgs() << "\n"); } } @@ -77,25 +81,24 @@ void printDepMatrix(CharMatrix &DepMatrix) { static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, Loop *L, DependenceInfo *DI) { - typedef SmallVector<Value *, 16> ValueVector; + using ValueVector = SmallVector<Value *, 16>; + ValueVector MemInstr; // For each block. - for (Loop::block_iterator BB = L->block_begin(), BE = L->block_end(); - BB != BE; ++BB) { + for (BasicBlock *BB : L->blocks()) { // Scan the BB and collect legal loads and stores. - for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end(); I != E; - ++I) { + for (Instruction &I : *BB) { if (!isa<Instruction>(I)) return false; - if (LoadInst *Ld = dyn_cast<LoadInst>(I)) { + if (auto *Ld = dyn_cast<LoadInst>(&I)) { if (!Ld->isSimple()) return false; - MemInstr.push_back(&*I); - } else if (StoreInst *St = dyn_cast<StoreInst>(I)) { + MemInstr.push_back(&I); + } else if (auto *St = dyn_cast<StoreInst>(&I)) { if (!St->isSimple()) return false; - MemInstr.push_back(&*I); + MemInstr.push_back(&I); } } } @@ -171,7 +174,7 @@ static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level, } // We don't have a DepMatrix to check legality return false. - if (DepMatrix.size() == 0) + if (DepMatrix.empty()) return false; return true; } @@ -216,7 +219,6 @@ static bool containsNoDependence(CharMatrix &DepMatrix, unsigned Row, static bool validDepInterchange(CharMatrix &DepMatrix, unsigned Row, unsigned OuterLoopId, char InnerDep, char OuterDep) { - if (isOuterMostDepPositive(DepMatrix, Row, OuterLoopId)) return false; @@ -255,7 +257,6 @@ static bool validDepInterchange(CharMatrix &DepMatrix, unsigned Row, static bool isLegalToInterChangeLoops(CharMatrix &DepMatrix, unsigned InnerLoopId, unsigned OuterLoopId) { - unsigned NumRows = DepMatrix.size(); // For each row check if it is valid to interchange. for (unsigned Row = 0; Row < NumRows; ++Row) { @@ -270,7 +271,6 @@ static bool isLegalToInterChangeLoops(CharMatrix &DepMatrix, } static void populateWorklist(Loop &L, SmallVector<LoopVector, 8> &V) { - DEBUG(dbgs() << "Calling populateWorklist on Func: " << L.getHeader()->getParent()->getName() << " Loop: %" << L.getHeader()->getName() << '\n'); @@ -320,6 +320,8 @@ static PHINode *getInductionVariable(Loop *L, ScalarEvolution *SE) { return nullptr; } +namespace { + /// LoopInterchangeLegality checks if it is legal to interchange the loop. 
class LoopInterchangeLegality { public: @@ -327,11 +329,12 @@ public: LoopInfo *LI, DominatorTree *DT, bool PreserveLCSSA, OptimizationRemarkEmitter *ORE) : OuterLoop(Outer), InnerLoop(Inner), SE(SE), LI(LI), DT(DT), - PreserveLCSSA(PreserveLCSSA), ORE(ORE), InnerLoopHasReduction(false) {} + PreserveLCSSA(PreserveLCSSA), ORE(ORE) {} /// Check if the loops can be interchanged. bool canInterchangeLoops(unsigned InnerLoopId, unsigned OuterLoopId, CharMatrix &DepMatrix); + /// Check if the loop structure is understood. We do not handle triangular /// loops for now. bool isLoopStructureUnderstood(PHINode *InnerInductionVar); @@ -348,6 +351,7 @@ private: bool findInductionAndReductions(Loop *L, SmallVector<PHINode *, 8> &Inductions, SmallVector<PHINode *, 8> &Reductions); + Loop *OuterLoop; Loop *InnerLoop; @@ -355,10 +359,11 @@ private: LoopInfo *LI; DominatorTree *DT; bool PreserveLCSSA; + /// Interface to emit optimization remarks. OptimizationRemarkEmitter *ORE; - bool InnerLoopHasReduction; + bool InnerLoopHasReduction = false; }; /// LoopInterchangeProfitability checks if it is profitable to interchange the @@ -381,6 +386,7 @@ private: /// Scev analysis. ScalarEvolution *SE; + /// Interface to emit optimization remarks. OptimizationRemarkEmitter *ORE; }; @@ -415,6 +421,7 @@ private: /// Scev analysis. ScalarEvolution *SE; + LoopInfo *LI; DominatorTree *DT; BasicBlock *LoopExit; @@ -424,16 +431,16 @@ private: // Main LoopInterchange Pass. struct LoopInterchange : public FunctionPass { static char ID; - ScalarEvolution *SE; - LoopInfo *LI; - DependenceInfo *DI; - DominatorTree *DT; + ScalarEvolution *SE = nullptr; + LoopInfo *LI = nullptr; + DependenceInfo *DI = nullptr; + DominatorTree *DT = nullptr; bool PreserveLCSSA; + /// Interface to emit optimization remarks. 
OptimizationRemarkEmitter *ORE; - LoopInterchange() - : FunctionPass(ID), SE(nullptr), LI(nullptr), DI(nullptr), DT(nullptr) { + LoopInterchange() : FunctionPass(ID) { initializeLoopInterchangePass(*PassRegistry::getPassRegistry()); } @@ -501,7 +508,6 @@ struct LoopInterchange : public FunctionPass { } bool processLoopList(LoopVector LoopList, Function &F) { - bool Changed = false; unsigned LoopNestDepth = LoopList.size(); if (LoopNestDepth < 2) { @@ -580,7 +586,6 @@ struct LoopInterchange : public FunctionPass { bool processLoop(LoopVector LoopList, unsigned InnerLoopId, unsigned OuterLoopId, BasicBlock *LoopNestExit, std::vector<std::vector<char>> &DependencyMatrix) { - DEBUG(dbgs() << "Processing Inner Loop Id = " << InnerLoopId << " and OuterLoopId = " << OuterLoopId << "\n"); Loop *InnerLoop = LoopList[InnerLoopId]; @@ -599,10 +604,12 @@ struct LoopInterchange : public FunctionPass { return false; } - ORE->emit(OptimizationRemark(DEBUG_TYPE, "Interchanged", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Loop interchanged with enclosing loop."); + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "Interchanged", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Loop interchanged with enclosing loop."; + }); LoopInterchangeTransform LIT(OuterLoop, InnerLoop, SE, LI, DT, LoopNestExit, LIL.hasInnerLoopReduction()); @@ -612,9 +619,10 @@ struct LoopInterchange : public FunctionPass { } }; -} // end of namespace +} // end anonymous namespace + bool LoopInterchangeLegality::areAllUsesReductions(Instruction *Ins, Loop *L) { - return none_of(Ins->users(), [=](User *U) -> bool { + return llvm::none_of(Ins->users(), [=](User *U) -> bool { auto *UserIns = dyn_cast<PHINode>(U); RecurrenceDescriptor RD; return !UserIns || !RecurrenceDescriptor::isReductionPHI(UserIns, L, RD); @@ -664,11 +672,9 @@ bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) { if (!OuterLoopHeaderBI) return false; - for (unsigned i = 0, e = OuterLoopHeaderBI->getNumSuccessors(); i < e; ++i) { - if (OuterLoopHeaderBI->getSuccessor(i) != InnerLoopPreHeader && - OuterLoopHeaderBI->getSuccessor(i) != OuterLoopLatch) + for (BasicBlock *Succ : OuterLoopHeaderBI->successors()) + if (Succ != InnerLoopPreHeader && Succ != OuterLoopLatch) return false; - } DEBUG(dbgs() << "Checking instructions in Loop header and Loop latch\n"); // We do not have any basic block in between now make sure the outer header @@ -682,10 +688,8 @@ bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) { return true; } - bool LoopInterchangeLegality::isLoopStructureUnderstood( PHINode *InnerInduction) { - unsigned Num = InnerInduction->getNumOperands(); BasicBlock *InnerLoopPreheader = InnerLoop->getLoopPreheader(); for (unsigned i = 0; i < Num; ++i) { @@ -750,12 +754,12 @@ static bool containsSafePHI(BasicBlock *Block, bool isOuterLoopExitBlock) { static BasicBlock *getLoopLatchExitBlock(BasicBlock *LatchBlock, BasicBlock *LoopHeader) { if (BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator())) { - unsigned Num = BI->getNumSuccessors(); - assert(Num == 2); - for (unsigned i = 0; i < Num; ++i) { - if (BI->getSuccessor(i) == LoopHeader) + assert(BI->getNumSuccessors() == 2 && + "Branch leaving loop latch must have 2 successors"); + for (BasicBlock *Succ : BI->successors()) { + if (Succ == LoopHeader) continue; - return BI->getSuccessor(i); + return Succ; } } return nullptr; @@ -764,7 +768,6 @@ static BasicBlock *getLoopLatchExitBlock(BasicBlock *LatchBlock, // This 
function indicates the current limitations in the transform as a result // of which we do not proceed. bool LoopInterchangeLegality::currentLimitations() { - BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); BasicBlock *InnerLoopLatch = InnerLoop->getLoopLatch(); @@ -777,12 +780,13 @@ bool LoopInterchangeLegality::currentLimitations() { if (!findInductionAndReductions(InnerLoop, Inductions, Reductions)) { DEBUG(dbgs() << "Only inner loops with induction or reduction PHI nodes " << "are supported currently.\n"); - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "UnsupportedPHIInner", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Only inner loops with induction or reduction PHI nodes can be" - " interchange currently."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedPHIInner", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Only inner loops with induction or reduction PHI nodes can be" + " interchange currently."; + }); return true; } @@ -790,12 +794,13 @@ bool LoopInterchangeLegality::currentLimitations() { if (Inductions.size() != 1) { DEBUG(dbgs() << "We currently only support loops with 1 induction variable." << "Failed to interchange due to current limitation\n"); - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "MultiInductionInner", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Only inner loops with 1 induction variable can be " - "interchanged currently."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "MultiInductionInner", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Only inner loops with 1 induction variable can be " + "interchanged currently."; + }); return true; } if (Reductions.size() > 0) @@ -806,12 +811,13 @@ bool LoopInterchangeLegality::currentLimitations() { if (!findInductionAndReductions(OuterLoop, Inductions, Reductions)) { DEBUG(dbgs() << "Only outer loops with induction or reduction PHI nodes " << "are supported currently.\n"); - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "UnsupportedPHIOuter", - OuterLoop->getStartLoc(), - OuterLoop->getHeader()) - << "Only outer loops with induction or reduction PHI nodes can be" - " interchanged currently."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedPHIOuter", + OuterLoop->getStartLoc(), + OuterLoop->getHeader()) + << "Only outer loops with induction or reduction PHI nodes can be" + " interchanged currently."; + }); return true; } @@ -820,35 +826,38 @@ bool LoopInterchangeLegality::currentLimitations() { if (!Reductions.empty()) { DEBUG(dbgs() << "Outer loops with reductions are not supported " << "currently.\n"); - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "ReductionsOuter", - OuterLoop->getStartLoc(), - OuterLoop->getHeader()) - << "Outer loops with reductions cannot be interchangeed " - "currently."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "ReductionsOuter", + OuterLoop->getStartLoc(), + OuterLoop->getHeader()) + << "Outer loops with reductions cannot be interchangeed " + "currently."; + }); return true; } // TODO: Currently we handle only loops with 1 induction variable. 
if (Inductions.size() != 1) { DEBUG(dbgs() << "Loops with more than 1 induction variables are not " << "supported currently.\n"); - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "MultiIndutionOuter", - OuterLoop->getStartLoc(), - OuterLoop->getHeader()) - << "Only outer loops with 1 induction variable can be " - "interchanged currently."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "MultiIndutionOuter", + OuterLoop->getStartLoc(), + OuterLoop->getHeader()) + << "Only outer loops with 1 induction variable can be " + "interchanged currently."; + }); return true; } // TODO: Triangular loops are not handled for now. if (!isLoopStructureUnderstood(InnerInductionVar)) { DEBUG(dbgs() << "Loop structure not understood by pass\n"); - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "UnsupportedStructureInner", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Inner loop structure not understood currently."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedStructureInner", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Inner loop structure not understood currently."; + }); return true; } @@ -857,24 +866,26 @@ bool LoopInterchangeLegality::currentLimitations() { getLoopLatchExitBlock(OuterLoopLatch, OuterLoopHeader); if (!LoopExitBlock || !containsSafePHI(LoopExitBlock, true)) { DEBUG(dbgs() << "Can only handle LCSSA PHIs in outer loops currently.\n"); - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "NoLCSSAPHIOuter", - OuterLoop->getStartLoc(), - OuterLoop->getHeader()) - << "Only outer loops with LCSSA PHIs can be interchange " - "currently."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NoLCSSAPHIOuter", + OuterLoop->getStartLoc(), + OuterLoop->getHeader()) + << "Only outer loops with LCSSA PHIs can be interchange " + "currently."; + }); return true; } LoopExitBlock = getLoopLatchExitBlock(InnerLoopLatch, InnerLoopHeader); if (!LoopExitBlock || !containsSafePHI(LoopExitBlock, false)) { DEBUG(dbgs() << "Can only handle LCSSA PHIs in inner loops currently.\n"); - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "NoLCSSAPHIOuterInner", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Only inner loops with LCSSA PHIs can be interchange " - "currently."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NoLCSSAPHIOuterInner", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Only inner loops with LCSSA PHIs can be interchange " + "currently."; + }); return true; } @@ -899,11 +910,12 @@ bool LoopInterchangeLegality::currentLimitations() { if (!InnerIndexVarInc) { DEBUG(dbgs() << "Did not find an instruction to increment the induction " << "variable.\n"); - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "NoIncrementInInner", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "The inner loop does not increment the induction variable."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NoIncrementInInner", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "The inner loop does not increment the induction variable."; + }); return true; } @@ -912,8 +924,9 @@ bool LoopInterchangeLegality::currentLimitations() { // instruction. 
bool FoundInduction = false; - for (const Instruction &I : reverse(*InnerLoopLatch)) { - if (isa<BranchInst>(I) || isa<CmpInst>(I) || isa<TruncInst>(I)) + for (const Instruction &I : llvm::reverse(*InnerLoopLatch)) { + if (isa<BranchInst>(I) || isa<CmpInst>(I) || isa<TruncInst>(I) || + isa<ZExtInst>(I)) continue; // We found an instruction. If this is not induction variable then it is not @@ -921,12 +934,13 @@ bool LoopInterchangeLegality::currentLimitations() { if (!I.isIdenticalTo(InnerIndexVarInc)) { DEBUG(dbgs() << "Found unsupported instructions between induction " << "variable increment and branch.\n"); - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "UnsupportedInsBetweenInduction", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Found unsupported instruction between induction variable " - "increment and branch."); + ORE->emit([&]() { + return OptimizationRemarkMissed( + DEBUG_TYPE, "UnsupportedInsBetweenInduction", + InnerLoop->getStartLoc(), InnerLoop->getHeader()) + << "Found unsupported instruction between induction variable " + "increment and branch."; + }); return true; } @@ -937,11 +951,12 @@ bool LoopInterchangeLegality::currentLimitations() { // current limitation. if (!FoundInduction) { DEBUG(dbgs() << "Did not find the induction variable.\n"); - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "NoIndutionVariable", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Did not find the induction variable."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NoIndutionVariable", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Did not find the induction variable."; + }); return true; } return false; @@ -950,19 +965,31 @@ bool LoopInterchangeLegality::currentLimitations() { bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId, unsigned OuterLoopId, CharMatrix &DepMatrix) { - if (!isLegalToInterChangeLoops(DepMatrix, InnerLoopId, OuterLoopId)) { DEBUG(dbgs() << "Failed interchange InnerLoopId = " << InnerLoopId << " and OuterLoopId = " << OuterLoopId << " due to dependence\n"); - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "Dependence", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Cannot interchange loops due to dependences."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "Dependence", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Cannot interchange loops due to dependences."; + }); return false; } + // Check if outer and inner loop contain legal instructions only. + for (auto *BB : OuterLoop->blocks()) + for (Instruction &I : *BB) + if (CallInst *CI = dyn_cast<CallInst>(&I)) { + // readnone functions do not prevent interchanging. + if (CI->doesNotReadMemory()) + continue; + DEBUG(dbgs() << "Loops with call instructions cannot be interchanged " + << "safely."); + return false; + } + // Create unique Preheaders if we already do not have one. BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader(); BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); @@ -995,12 +1022,13 @@ bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId, // Check if the loops are tightly nested. 
if (!tightlyNested(OuterLoop, InnerLoop)) { DEBUG(dbgs() << "Loops not tightly nested\n"); - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "NotTightlyNested", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Cannot interchange loops because they are not tightly " - "nested."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "NotTightlyNested", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Cannot interchange loops because they are not tightly " + "nested."; + }); return false; } @@ -1010,9 +1038,8 @@ bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId, int LoopInterchangeProfitability::getInstrOrderCost() { unsigned GoodOrder, BadOrder; BadOrder = GoodOrder = 0; - for (auto BI = InnerLoop->block_begin(), BE = InnerLoop->block_end(); - BI != BE; ++BI) { - for (Instruction &Ins : **BI) { + for (BasicBlock *BB : InnerLoop->blocks()) { + for (Instruction &Ins : *BB) { if (const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(&Ins)) { unsigned NumOp = GEP->getNumOperands(); bool FoundInnerInduction = false; @@ -1064,12 +1091,11 @@ static bool isProfitableForVectorization(unsigned InnerLoopId, // TODO: Improve this heuristic to catch more cases. // If the inner loop is loop independent or doesn't carry any dependency it is // profitable to move this to outer position. - unsigned Row = DepMatrix.size(); - for (unsigned i = 0; i < Row; ++i) { - if (DepMatrix[i][InnerLoopId] != 'S' && DepMatrix[i][InnerLoopId] != 'I') + for (auto &Row : DepMatrix) { + if (Row[InnerLoopId] != 'S' && Row[InnerLoopId] != 'I') return false; // TODO: We need to improve this heuristic. - if (DepMatrix[i][OuterLoopId] != '=') + if (Row[OuterLoopId] != '=') return false; } // If outer loop has dependence and inner loop is loop independent then it is @@ -1080,7 +1106,6 @@ static bool isProfitableForVectorization(unsigned InnerLoopId, bool LoopInterchangeProfitability::isProfitable(unsigned InnerLoopId, unsigned OuterLoopId, CharMatrix &DepMatrix) { - // TODO: Add better profitability checks. 
// e.g // 1) Construct dependency matrix and move the one with no loop carried dep @@ -1099,14 +1124,15 @@ bool LoopInterchangeProfitability::isProfitable(unsigned InnerLoopId, if (isProfitableForVectorization(InnerLoopId, OuterLoopId, DepMatrix)) return true; - ORE->emit(OptimizationRemarkMissed(DEBUG_TYPE, - "InterchangeNotProfitable", - InnerLoop->getStartLoc(), - InnerLoop->getHeader()) - << "Interchanging loops is too costly (cost=" - << ore::NV("Cost", Cost) << ", threshold=" - << ore::NV("Threshold", LoopInterchangeCostThreshold) << - ") and it does not improve parallelism."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "InterchangeNotProfitable", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Interchanging loops is too costly (cost=" + << ore::NV("Cost", Cost) << ", threshold=" + << ore::NV("Threshold", LoopInterchangeCostThreshold) + << ") and it does not improve parallelism."; + }); return false; } @@ -1145,7 +1171,7 @@ bool LoopInterchangeTransform::transform() { bool Transformed = false; Instruction *InnerIndexVar; - if (InnerLoop->getSubLoops().size() == 0) { + if (InnerLoop->getSubLoops().empty()) { BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); DEBUG(dbgs() << "Calling Split Inner Loop\n"); PHINode *InductionPHI = getInductionVariable(InnerLoop, SE); @@ -1159,7 +1185,11 @@ bool LoopInterchangeTransform::transform() { else InnerIndexVar = dyn_cast<Instruction>(InductionPHI->getIncomingValue(0)); - // + // Ensure that InductionPHI is the first Phi node as required by + // splitInnerLoopHeader + if (&InductionPHI->getParent()->front() != InductionPHI) + InductionPHI->moveBefore(&InductionPHI->getParent()->front()); + // Split at the place were the induction variable is // incremented/decremented. // TODO: This splitting logic may not work always. Fix this. @@ -1188,13 +1218,12 @@ void LoopInterchangeTransform::splitInnerLoopLatch(Instruction *Inc) { } void LoopInterchangeTransform::splitInnerLoopHeader() { - // Split the inner loop header out. Here make sure that the reduction PHI's // stay in the innerloop body. BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); if (InnerLoopHasReduction) { - // FIXME: Check if the induction PHI will always be the first PHI. + // Note: The induction PHI must be the first PHI for this to work BasicBlock *New = InnerLoopHeader->splitBasicBlock( ++(InnerLoopHeader->begin()), InnerLoopHeader->getName() + ".split"); if (LI) @@ -1244,7 +1273,6 @@ void LoopInterchangeTransform::updateIncomingBlock(BasicBlock *CurrBlock, } bool LoopInterchangeTransform::adjustLoopBranches() { - DEBUG(dbgs() << "adjustLoopBranches called\n"); // Adjust the loop preheader BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); @@ -1352,8 +1380,8 @@ bool LoopInterchangeTransform::adjustLoopBranches() { return true; } -void LoopInterchangeTransform::adjustLoopPreheaders() { +void LoopInterchangeTransform::adjustLoopPreheaders() { // We have interchanged the preheaders so we need to interchange the data in // the preheader as well. // This is because the content of inner preheader was previously executed @@ -1373,7 +1401,6 @@ void LoopInterchangeTransform::adjustLoopPreheaders() { } bool LoopInterchangeTransform::adjustLoopLinks() { - // Adjust all branches in the inner and outer loop. 
bool Changed = adjustLoopBranches(); if (Changed) @@ -1382,6 +1409,7 @@ bool LoopInterchangeTransform::adjustLoopLinks() { } char LoopInterchange::ID = 0; + INITIALIZE_PASS_BEGIN(LoopInterchange, "loop-interchange", "Interchanges loops for cache reuse", false, false) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) diff --git a/lib/Transforms/Scalar/LoopLoadElimination.cpp b/lib/Transforms/Scalar/LoopLoadElimination.cpp index 20b37c4b70e6..dfa5ec1f354d 100644 --- a/lib/Transforms/Scalar/LoopLoadElimination.cpp +++ b/lib/Transforms/Scalar/LoopLoadElimination.cpp @@ -28,22 +28,29 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/IR/Value.h" #include "llvm/Pass.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/LoopVersioning.h" #include <algorithm> @@ -53,11 +60,11 @@ #include <tuple> #include <utility> +using namespace llvm; + #define LLE_OPTION "loop-load-elim" #define DEBUG_TYPE LLE_OPTION -using namespace llvm; - static cl::opt<unsigned> CheckPerElim( "runtime-check-per-loop-load-elim", cl::Hidden, cl::desc("Max number of memchecks allowed per eliminated load on average"), @@ -127,10 +134,12 @@ struct StoreToLoadForwardingCandidate { #endif }; +} // end anonymous namespace + /// \brief Check if the store dominates all latches, so as long as there is no /// intervening store this value will be loaded in the next iteration. -bool doesStoreDominatesAllLatches(BasicBlock *StoreBlock, Loop *L, - DominatorTree *DT) { +static bool doesStoreDominatesAllLatches(BasicBlock *StoreBlock, Loop *L, + DominatorTree *DT) { SmallVector<BasicBlock *, 8> Latches; L->getLoopLatches(Latches); return llvm::all_of(Latches, [&](const BasicBlock *Latch) { @@ -143,6 +152,8 @@ static bool isLoadConditional(LoadInst *Load, Loop *L) { return Load->getParent() != L->getHeader(); } +namespace { + /// \brief The per-loop class that does most of the work. class LoadEliminationForLoop { public: @@ -241,8 +252,8 @@ public: std::forward_list<StoreToLoadForwardingCandidate> &Candidates) { // If Store is nullptr it means that we have multiple stores forwarding to // this store. - typedef DenseMap<LoadInst *, const StoreToLoadForwardingCandidate *> - LoadToSingleCandT; + using LoadToSingleCandT = + DenseMap<LoadInst *, const StoreToLoadForwardingCandidate *>; LoadToSingleCandT LoadToSingleCand; for (const auto &Cand : Candidates) { @@ -393,7 +404,6 @@ public: void propagateStoredValueToLoadUsers(const StoreToLoadForwardingCandidate &Cand, SCEVExpander &SEE) { - // // loop: // %x = load %gep_i // = ... 
%x @@ -431,6 +441,7 @@ public: bool processLoop() { DEBUG(dbgs() << "\nIn \"" << L->getHeader()->getParent()->getName() << "\" checking " << *L << "\n"); + // Look for store-to-load forwarding cases across the // backedge. E.g.: // @@ -558,6 +569,8 @@ private: PredicatedScalarEvolution PSE; }; +} // end anonymous namespace + static bool eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, function_ref<const LoopAccessInfo &(Loop &)> GetLAI) { @@ -584,10 +597,14 @@ eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, return Changed; } +namespace { + /// \brief The pass. Most of the work is delegated to the per-loop /// LoadEliminationForLoop class. class LoopLoadElimination : public FunctionPass { public: + static char ID; + LoopLoadElimination() : FunctionPass(ID) { initializeLoopLoadEliminationPass(*PassRegistry::getPassRegistry()); } @@ -616,13 +633,12 @@ public: AU.addPreserved<DominatorTreeWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } - - static char ID; }; } // end anonymous namespace char LoopLoadElimination::ID; + static const char LLE_name[] = "Loop Load Elimination"; INITIALIZE_PASS_BEGIN(LoopLoadElimination, LLE_OPTION, LLE_name, false, false) @@ -633,9 +649,7 @@ INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopSimplify) INITIALIZE_PASS_END(LoopLoadElimination, LLE_OPTION, LLE_name, false, false) -namespace llvm { - -FunctionPass *createLoopLoadEliminationPass() { +FunctionPass *llvm::createLoopLoadEliminationPass() { return new LoopLoadElimination(); } @@ -652,7 +666,8 @@ PreservedAnalyses LoopLoadEliminationPass::run(Function &F, auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); bool Changed = eliminateLoadsAcrossLoops( F, LI, DT, [&](Loop &L) -> const LoopAccessInfo & { - LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI}; + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, + SE, TLI, TTI, nullptr}; return LAM.getResult<LoopAccessAnalysis>(L, AR); }); @@ -662,5 +677,3 @@ PreservedAnalyses LoopLoadEliminationPass::run(Function &F, PreservedAnalyses PA; return PA; } - -} // end namespace llvm diff --git a/lib/Transforms/Scalar/LoopPredication.cpp b/lib/Transforms/Scalar/LoopPredication.cpp index 9b12ba180444..2e4c7b19e476 100644 --- a/lib/Transforms/Scalar/LoopPredication.cpp +++ b/lib/Transforms/Scalar/LoopPredication.cpp @@ -34,6 +34,143 @@ // else // deoptimize // +// It's tempting to rely on SCEV here, but it has proven to be problematic. +// Generally the facts SCEV provides about the increment step of add +// recurrences are true if the backedge of the loop is taken, which implicitly +// assumes that the guard doesn't fail. Using these facts to optimize the +// guard results in a circular logic where the guard is optimized under the +// assumption that it never fails. +// +// For example, in the loop below the induction variable will be marked as nuw +// basing on the guard. Basing on nuw the guard predicate will be considered +// monotonic. Given a monotonic condition it's tempting to replace the induction +// variable in the condition with its value on the last iteration. But this +// transformation is not correct, e.g. e = 4, b = 5 breaks the loop. +// +// for (int i = b; i != e; i++) +// guard(i u< len) +// +// One of the ways to reason about this problem is to use an inductive proof +// approach. 
Given the loop: +// +// if (B(0)) { +// do { +// I = PHI(0, I.INC) +// I.INC = I + Step +// guard(G(I)); +// } while (B(I)); +// } +// +// where B(x) and G(x) are predicates that map integers to booleans, we want a +// loop invariant expression M such the following program has the same semantics +// as the above: +// +// if (B(0)) { +// do { +// I = PHI(0, I.INC) +// I.INC = I + Step +// guard(G(0) && M); +// } while (B(I)); +// } +// +// One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step) +// +// Informal proof that the transformation above is correct: +// +// By the definition of guards we can rewrite the guard condition to: +// G(I) && G(0) && M +// +// Let's prove that for each iteration of the loop: +// G(0) && M => G(I) +// And the condition above can be simplified to G(Start) && M. +// +// Induction base. +// G(0) && M => G(0) +// +// Induction step. Assuming G(0) && M => G(I) on the subsequent +// iteration: +// +// B(I) is true because it's the backedge condition. +// G(I) is true because the backedge is guarded by this condition. +// +// So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step). +// +// Note that we can use anything stronger than M, i.e. any condition which +// implies M. +// +// When S = 1 (i.e. forward iterating loop), the transformation is supported +// when: +// * The loop has a single latch with the condition of the form: +// B(X) = latchStart + X <pred> latchLimit, +// where <pred> is u<, u<=, s<, or s<=. +// * The guard condition is of the form +// G(X) = guardStart + X u< guardLimit +// +// For the ult latch comparison case M is: +// forall X . guardStart + X u< guardLimit && latchStart + X <u latchLimit => +// guardStart + X + 1 u< guardLimit +// +// The only way the antecedent can be true and the consequent can be false is +// if +// X == guardLimit - 1 - guardStart +// (and guardLimit is non-zero, but we won't use this latter fact). +// If X == guardLimit - 1 - guardStart then the second half of the antecedent is +// latchStart + guardLimit - 1 - guardStart u< latchLimit +// and its negation is +// latchStart + guardLimit - 1 - guardStart u>= latchLimit +// +// In other words, if +// latchLimit u<= latchStart + guardLimit - 1 - guardStart +// then: +// (the ranges below are written in ConstantRange notation, where [A, B) is the +// set for (I = A; I != B; I++ /*maywrap*/) yield(I);) +// +// forall X . guardStart + X u< guardLimit && +// latchStart + X u< latchLimit => +// guardStart + X + 1 u< guardLimit +// == forall X . guardStart + X u< guardLimit && +// latchStart + X u< latchStart + guardLimit - 1 - guardStart => +// guardStart + X + 1 u< guardLimit +// == forall X . (guardStart + X) in [0, guardLimit) && +// (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) => +// (guardStart + X + 1) in [0, guardLimit) +// == forall X . 
X in [-guardStart, guardLimit - guardStart) && +// X in [-latchStart, guardLimit - 1 - guardStart) => +// X in [-guardStart - 1, guardLimit - guardStart - 1) +// == true +// +// So the widened condition is: +// guardStart u< guardLimit && +// latchStart + guardLimit - 1 - guardStart u>= latchLimit +// Similarly for ule condition the widened condition is: +// guardStart u< guardLimit && +// latchStart + guardLimit - 1 - guardStart u> latchLimit +// For slt condition the widened condition is: +// guardStart u< guardLimit && +// latchStart + guardLimit - 1 - guardStart s>= latchLimit +// For sle condition the widened condition is: +// guardStart u< guardLimit && +// latchStart + guardLimit - 1 - guardStart s> latchLimit +// +// When S = -1 (i.e. reverse iterating loop), the transformation is supported +// when: +// * The loop has a single latch with the condition of the form: +// B(X) = X <pred> latchLimit, where <pred> is u> or s>. +// * The guard condition is of the form +// G(X) = X - 1 u< guardLimit +// +// For the ugt latch comparison case M is: +// forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit +// +// The only way the antecedent can be true and the consequent can be false is if +// X == 1. +// If X == 1 then the second half of the antecedent is +// 1 u> latchLimit, and its negation is latchLimit u>= 1. +// +// So the widened condition is: +// guardStart u< guardLimit && latchLimit u>= 1. +// Similarly for sgt condition the widened condition is: +// guardStart u< guardLimit && latchLimit s>= 1. //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopPredication.h" @@ -56,6 +193,11 @@ using namespace llvm; +static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation", + cl::Hidden, cl::init(true)); + +static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop", + cl::Hidden, cl::init(true)); namespace { class LoopPredication { /// Represents an induction variable check: @@ -68,6 +210,10 @@ class LoopPredication { const SCEV *Limit) : Pred(Pred), IV(IV), Limit(Limit) {} LoopICmp() {} + void dump() { + dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV + << ", Limit = " << *Limit << "\n"; + } }; ScalarEvolution *SE; @@ -75,17 +221,51 @@ class LoopPredication { Loop *L; const DataLayout *DL; BasicBlock *Preheader; + LoopICmp LatchCheck; - Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI); + bool isSupportedStep(const SCEV* Step); + Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI) { + return parseLoopICmp(ICI->getPredicate(), ICI->getOperand(0), + ICI->getOperand(1)); + } + Optional<LoopICmp> parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS, + Value *RHS); + + Optional<LoopICmp> parseLoopLatchICmp(); + bool CanExpand(const SCEV* S); Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, Instruction *InsertAt); Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, IRBuilder<> &Builder); + Optional<Value *> widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, + LoopICmp RangeCheck, + SCEVExpander &Expander, + IRBuilder<> &Builder); + Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, + LoopICmp RangeCheck, + SCEVExpander &Expander, + IRBuilder<> &Builder); bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); + // When the IV type is wider than the range operand type, we can still do loop + // predication, by generating SCEVs for 
the range and latch that are of the + // same type. We achieve this by generating a SCEV truncate expression for the + // latch IV. This is done iff truncation of the IV is a safe operation, + // without loss of information. + // Another way to achieve this is by generating a wider type SCEV for the + // range check operand, however, this needs a more involved check that + // operands do not overflow. This can lead to loss of information when the + // range operand is of the form: add i32 %offset, %iv. We need to prove that + // sext(x + y) is same as sext(x) + sext(y). + // This function returns true if we can safely represent the IV type in + // the RangeCheckType without loss of information. + bool isSafeToTruncateWideIVType(Type *RangeCheckType); + // Return the loopLatchCheck corresponding to the RangeCheckType if safe to do + // so. + Optional<LoopICmp> generateLoopLatchCheck(Type *RangeCheckType); public: LoopPredication(ScalarEvolution *SE) : SE(SE){}; bool runOnLoop(Loop *L); @@ -135,11 +315,8 @@ PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM, } Optional<LoopPredication::LoopICmp> -LoopPredication::parseLoopICmp(ICmpInst *ICI) { - ICmpInst::Predicate Pred = ICI->getPredicate(); - - Value *LHS = ICI->getOperand(0); - Value *RHS = ICI->getOperand(1); +LoopPredication::parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS, + Value *RHS) { const SCEV *LHSS = SE->getSCEV(LHS); if (isa<SCEVCouldNotCompute>(LHSS)) return None; @@ -165,13 +342,146 @@ Value *LoopPredication::expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder, ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, Instruction *InsertAt) { + // TODO: we can check isLoopEntryGuardedByCond before emitting the check + Type *Ty = LHS->getType(); assert(Ty == RHS->getType() && "expandCheck operands have different types?"); + + if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS)) + return Builder.getTrue(); + Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt); Value *RHSV = Expander.expandCodeFor(RHS, Ty, InsertAt); return Builder.CreateICmp(Pred, LHSV, RHSV); } +Optional<LoopPredication::LoopICmp> +LoopPredication::generateLoopLatchCheck(Type *RangeCheckType) { + + auto *LatchType = LatchCheck.IV->getType(); + if (RangeCheckType == LatchType) + return LatchCheck; + // For now, bail out if latch type is narrower than range type. + if (DL->getTypeSizeInBits(LatchType) < DL->getTypeSizeInBits(RangeCheckType)) + return None; + if (!isSafeToTruncateWideIVType(RangeCheckType)) + return None; + // We can now safely identify the truncated version of the IV and limit for + // RangeCheckType. 
+ LoopICmp NewLatchCheck; + NewLatchCheck.Pred = LatchCheck.Pred; + NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>( + SE->getTruncateExpr(LatchCheck.IV, RangeCheckType)); + if (!NewLatchCheck.IV) + return None; + NewLatchCheck.Limit = SE->getTruncateExpr(LatchCheck.Limit, RangeCheckType); + DEBUG(dbgs() << "IV of type: " << *LatchType + << "can be represented as range check type:" << *RangeCheckType + << "\n"); + DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n"); + DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n"); + return NewLatchCheck; +} + +bool LoopPredication::isSupportedStep(const SCEV* Step) { + return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop); +} + +bool LoopPredication::CanExpand(const SCEV* S) { + return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE); +} + +Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop( + LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck, + SCEVExpander &Expander, IRBuilder<> &Builder) { + auto *Ty = RangeCheck.IV->getType(); + // Generate the widened condition for the forward loop: + // guardStart u< guardLimit && + // latchLimit <pred> guardLimit - 1 - guardStart + latchStart + // where <pred> depends on the latch condition predicate. See the file + // header comment for the reasoning. + // guardLimit - guardStart + latchStart - 1 + const SCEV *GuardStart = RangeCheck.IV->getStart(); + const SCEV *GuardLimit = RangeCheck.Limit; + const SCEV *LatchStart = LatchCheck.IV->getStart(); + const SCEV *LatchLimit = LatchCheck.Limit; + + // guardLimit - guardStart + latchStart - 1 + const SCEV *RHS = + SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart), + SE->getMinusSCEV(LatchStart, SE->getOne(Ty))); + if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) || + !CanExpand(LatchLimit) || !CanExpand(RHS)) { + DEBUG(dbgs() << "Can't expand limit check!\n"); + return None; + } + ICmpInst::Predicate LimitCheckPred; + switch (LatchCheck.Pred) { + case ICmpInst::ICMP_ULT: + LimitCheckPred = ICmpInst::ICMP_ULE; + break; + case ICmpInst::ICMP_ULE: + LimitCheckPred = ICmpInst::ICMP_ULT; + break; + case ICmpInst::ICMP_SLT: + LimitCheckPred = ICmpInst::ICMP_SLE; + break; + case ICmpInst::ICMP_SLE: + LimitCheckPred = ICmpInst::ICMP_SLT; + break; + default: + llvm_unreachable("Unsupported loop latch!"); + } + + DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n"); + DEBUG(dbgs() << "RHS: " << *RHS << "\n"); + DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n"); + + Instruction *InsertAt = Preheader->getTerminator(); + auto *LimitCheck = + expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, RHS, InsertAt); + auto *FirstIterationCheck = expandCheck(Expander, Builder, RangeCheck.Pred, + GuardStart, GuardLimit, InsertAt); + return Builder.CreateAnd(FirstIterationCheck, LimitCheck); +} + +Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop( + LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck, + SCEVExpander &Expander, IRBuilder<> &Builder) { + auto *Ty = RangeCheck.IV->getType(); + const SCEV *GuardStart = RangeCheck.IV->getStart(); + const SCEV *GuardLimit = RangeCheck.Limit; + const SCEV *LatchLimit = LatchCheck.Limit; + if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) || + !CanExpand(LatchLimit)) { + DEBUG(dbgs() << "Can't expand limit check!\n"); + return None; + } + // The decrement of the latch check IV should be the same as the + // rangeCheckIV. 
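A hand-worked instance of widenICmpRangeCheckIncrementingLoop above, using ad hoc names rather than LLVM API: for a forward loop whose latch checks i u< n and whose guard checks i u< len, the loop-invariant condition placed in the preheader reduces to 0 u< len && n u<= len - 1, matching the derivation in the file header comment.

#include <cstdint>

// Illustrative only. Evaluates the widened preheader condition for the
// step == 1 case with a u< latch predicate, mirroring the LimitCheckPred
// mapping above (ICMP_ULT -> ICMP_ULE). Unsigned arithmetic wraps, matching
// the "maywrap" ranges in the header comment.
static bool widenedForwardCheck(uint64_t GuardStart, uint64_t GuardLimit,
                                uint64_t LatchStart, uint64_t LatchLimit) {
  bool FirstIterationCheck = GuardStart < GuardLimit;      // guardStart u< guardLimit
  uint64_t RHS = GuardLimit - GuardStart + LatchStart - 1; // as computed above
  bool LimitCheck = LatchLimit <= RHS;                     // latchLimit u<= RHS
  return FirstIterationCheck && LimitCheck;
}

// For:  for (i = 0; i u< n; ++i) guard(i u< len);
// widenedForwardCheck(0, len, 0, n) == (0 u< len) && (n u<= len - 1),
// i.e. the guard can be widened exactly when the trip count never steps the
// induction variable past len.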
+ auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE); + if (RangeCheck.IV != PostDecLatchCheckIV) { + DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: " + << *PostDecLatchCheckIV + << " and RangeCheckIV: " << *RangeCheck.IV << "\n"); + return None; + } + + // Generate the widened condition for CountDownLoop: + // guardStart u< guardLimit && + // latchLimit <pred> 1. + // See the header comment for reasoning of the checks. + Instruction *InsertAt = Preheader->getTerminator(); + auto LimitCheckPred = ICmpInst::isSigned(LatchCheck.Pred) + ? ICmpInst::ICMP_SGE + : ICmpInst::ICMP_UGE; + auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT, + GuardStart, GuardLimit, InsertAt); + auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, + SE->getOne(Ty), InsertAt); + return Builder.CreateAnd(FirstIterationCheck, LimitCheck); +} + /// If ICI can be widened to a loop invariant condition emits the loop /// invariant condition in the loop preheader and return it, otherwise /// returns None. @@ -181,51 +491,62 @@ Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI, DEBUG(dbgs() << "Analyzing ICmpInst condition:\n"); DEBUG(ICI->dump()); + // parseLoopStructure guarantees that the latch condition is: + // ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=. + // We are looking for the range checks of the form: + // i u< guardLimit auto RangeCheck = parseLoopICmp(ICI); if (!RangeCheck) { DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); return None; } - - ICmpInst::Predicate Pred = RangeCheck->Pred; - const SCEVAddRecExpr *IndexAR = RangeCheck->IV; - const SCEV *RHSS = RangeCheck->Limit; - - auto CanExpand = [this](const SCEV *S) { - return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE); - }; - if (!CanExpand(RHSS)) + DEBUG(dbgs() << "Guard check:\n"); + DEBUG(RangeCheck->dump()); + if (RangeCheck->Pred != ICmpInst::ICMP_ULT) { + DEBUG(dbgs() << "Unsupported range check predicate(" << RangeCheck->Pred + << ")!\n"); return None; - - DEBUG(dbgs() << "IndexAR: "); - DEBUG(IndexAR->dump()); - - bool IsIncreasing = false; - if (!SE->isMonotonicPredicate(IndexAR, Pred, IsIncreasing)) + } + auto *RangeCheckIV = RangeCheck->IV; + if (!RangeCheckIV->isAffine()) { + DEBUG(dbgs() << "Range check IV is not affine!\n"); return None; - - // If the predicate is increasing the condition can change from false to true - // as the loop progresses, in this case take the value on the first iteration - // for the widened check. Otherwise the condition can change from true to - // false as the loop progresses, so take the value on the last iteration. - const SCEV *NewLHSS = IsIncreasing - ? IndexAR->getStart() - : SE->getSCEVAtScope(IndexAR, L->getParentLoop()); - if (NewLHSS == IndexAR) { - DEBUG(dbgs() << "Can't compute NewLHSS!\n"); + } + auto *Step = RangeCheckIV->getStepRecurrence(*SE); + // We cannot just compare with latch IV step because the latch and range IVs + // may have different types. + if (!isSupportedStep(Step)) { + DEBUG(dbgs() << "Range check and latch have IVs different steps!\n"); return None; } - - DEBUG(dbgs() << "NewLHSS: "); - DEBUG(NewLHSS->dump()); - - if (!CanExpand(NewLHSS)) + auto *Ty = RangeCheckIV->getType(); + auto CurrLatchCheckOpt = generateLoopLatchCheck(Ty); + if (!CurrLatchCheckOpt) { + DEBUG(dbgs() << "Failed to generate a loop latch check " + "corresponding to range type: " + << *Ty << "\n"); return None; + } - DEBUG(dbgs() << "NewLHSS is loop invariant and safe to expand. 
Expand!\n"); + LoopICmp CurrLatchCheck = *CurrLatchCheckOpt; + // At this point, the range and latch step should have the same type, but need + // not have the same value (we support both 1 and -1 steps). + assert(Step->getType() == + CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() && + "Range and latch steps should be of same type!"); + if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) { + DEBUG(dbgs() << "Range and latch have different step values!\n"); + return None; + } - Instruction *InsertAt = Preheader->getTerminator(); - return expandCheck(Expander, Builder, Pred, NewLHSS, RHSS, InsertAt); + if (Step->isOne()) + return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck, + Expander, Builder); + else { + assert(Step->isAllOnesValue() && "Step should be -1!"); + return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck, + Expander, Builder); + } } bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, @@ -288,6 +609,97 @@ bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard, return true; } +Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() { + using namespace PatternMatch; + + BasicBlock *LoopLatch = L->getLoopLatch(); + if (!LoopLatch) { + DEBUG(dbgs() << "The loop doesn't have a single latch!\n"); + return None; + } + + ICmpInst::Predicate Pred; + Value *LHS, *RHS; + BasicBlock *TrueDest, *FalseDest; + + if (!match(LoopLatch->getTerminator(), + m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), TrueDest, + FalseDest))) { + DEBUG(dbgs() << "Failed to match the latch terminator!\n"); + return None; + } + assert((TrueDest == L->getHeader() || FalseDest == L->getHeader()) && + "One of the latch's destinations must be the header"); + if (TrueDest != L->getHeader()) + Pred = ICmpInst::getInversePredicate(Pred); + + auto Result = parseLoopICmp(Pred, LHS, RHS); + if (!Result) { + DEBUG(dbgs() << "Failed to parse the loop latch condition!\n"); + return None; + } + + // Check affine first, so if it's not we don't try to compute the step + // recurrence. + if (!Result->IV->isAffine()) { + DEBUG(dbgs() << "The induction variable is not affine!\n"); + return None; + } + + auto *Step = Result->IV->getStepRecurrence(*SE); + if (!isSupportedStep(Step)) { + DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n"); + return None; + } + + auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) { + if (Step->isOne()) { + return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT && + Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE; + } else { + assert(Step->isAllOnesValue() && "Step should be -1!"); + return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT; + } + }; + + if (IsUnsupportedPredicate(Step, Result->Pred)) { + DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred + << ")!\n"); + return None; + } + return Result; +} + +// Returns true if its safe to truncate the IV to RangeCheckType. +bool LoopPredication::isSafeToTruncateWideIVType(Type *RangeCheckType) { + if (!EnableIVTruncation) + return false; + assert(DL->getTypeSizeInBits(LatchCheck.IV->getType()) > + DL->getTypeSizeInBits(RangeCheckType) && + "Expected latch check IV type to be larger than range check operand " + "type!"); + // The start and end values of the IV should be known. This is to guarantee + // that truncating the wide type will not lose information. 
+ auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit); + auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart()); + if (!Limit || !Start) + return false; + // This check makes sure that the IV does not change sign during loop + // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE, + // LatchEnd = 2, rangeCheckType = i32. If it's not a monotonic predicate, the + // IV wraps around, and the truncation of the IV would lose the range of + // iterations between 2^32 and 2^64. + bool Increasing; + if (!SE->isMonotonicPredicate(LatchCheck.IV, LatchCheck.Pred, Increasing)) + return false; + // The active bits should be less than the bits in the RangeCheckType. This + // guarantees that truncating the latch check to RangeCheckType is a safe + // operation. + auto RangeCheckTypeBitSize = DL->getTypeSizeInBits(RangeCheckType); + return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize && + Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize; +} + bool LoopPredication::runOnLoop(Loop *Loop) { L = Loop; @@ -308,6 +720,14 @@ bool LoopPredication::runOnLoop(Loop *Loop) { if (!Preheader) return false; + auto LatchCheckOpt = parseLoopLatchICmp(); + if (!LatchCheckOpt) + return false; + LatchCheck = *LatchCheckOpt; + + DEBUG(dbgs() << "Latch check:\n"); + DEBUG(LatchCheck.dump()); + // Collect all the guards into a vector and process later, so as not // to invalidate the instruction iterator. SmallVector<IntrinsicInst *, 4> Guards; diff --git a/lib/Transforms/Scalar/LoopRerollPass.cpp b/lib/Transforms/Scalar/LoopRerollPass.cpp index fc0216e76a5b..d1a54b877950 100644 --- a/lib/Transforms/Scalar/LoopRerollPass.cpp +++ b/lib/Transforms/Scalar/LoopRerollPass.cpp @@ -1,4 +1,4 @@ -//===-- LoopReroll.cpp - Loop rerolling pass ------------------------------===// +//===- LoopReroll.cpp - Loop rerolling pass -------------------------------===// // // The LLVM Compiler Infrastructure // @@ -11,22 +11,42 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/APInt.h" #include "llvm/ADT/BitVector.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -34,6 +54,13 @@ #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include 
<cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <iterator> +#include <map> +#include <utility> using namespace llvm; @@ -127,6 +154,7 @@ NumToleratedFailedMatches("reroll-num-tolerated-failed-matches", cl::init(400), // br %cmp, header, exit namespace { + enum IterationLimits { /// The maximum number of iterations that we'll try and reroll. IL_MaxRerollIterations = 32, @@ -139,6 +167,7 @@ namespace { class LoopReroll : public LoopPass { public: static char ID; // Pass ID, replacement for typeid + LoopReroll() : LoopPass(ID) { initializeLoopRerollPass(*PassRegistry::getPassRegistry()); } @@ -158,11 +187,12 @@ namespace { DominatorTree *DT; bool PreserveLCSSA; - typedef SmallVector<Instruction *, 16> SmallInstructionVector; - typedef SmallSet<Instruction *, 16> SmallInstructionSet; + using SmallInstructionVector = SmallVector<Instruction *, 16>; + using SmallInstructionSet = SmallSet<Instruction *, 16>; // Map between induction variable and its increment DenseMap<Instruction *, int64_t> IVToIncMap; + // For loop with multiple induction variable, remember the one used only to // control the loop. Instruction *LoopControlIV; @@ -171,8 +201,7 @@ namespace { // representing a reduction. Only the last value may be used outside the // loop. struct SimpleLoopReduction { - SimpleLoopReduction(Instruction *P, Loop *L) - : Valid(false), Instructions(1, P) { + SimpleLoopReduction(Instruction *P, Loop *L) : Instructions(1, P) { assert(isa<PHINode>(P) && "First reduction instruction must be a PHI"); add(L); } @@ -204,8 +233,8 @@ namespace { return Instructions.size()-1; } - typedef SmallInstructionVector::iterator iterator; - typedef SmallInstructionVector::const_iterator const_iterator; + using iterator = SmallInstructionVector::iterator; + using const_iterator = SmallInstructionVector::const_iterator; iterator begin() { assert(Valid && "Using invalid reduction"); @@ -221,7 +250,7 @@ namespace { const_iterator end() const { return Instructions.end(); } protected: - bool Valid; + bool Valid = false; SmallInstructionVector Instructions; void add(Loop *L); @@ -230,7 +259,7 @@ namespace { // The set of all reductions, and state tracking of possible reductions // during loop instruction processing. struct ReductionTracker { - typedef SmallVector<SimpleLoopReduction, 16> SmallReductionVector; + using SmallReductionVector = SmallVector<SimpleLoopReduction, 16>; // Add a new possible reduction. void addSLR(SimpleLoopReduction &SLR) { PossibleReds.push_back(SLR); } @@ -342,6 +371,7 @@ namespace { struct DAGRootSet { Instruction *BaseInst; SmallInstructionVector Roots; + // The instructions between IV and BaseInst (but not including BaseInst). SmallInstructionSet SubsumedInsts; }; @@ -361,15 +391,17 @@ namespace { /// Stage 1: Find all the DAG roots for the induction variable. bool findRoots(); + /// Stage 2: Validate if the found roots are valid. bool validate(ReductionTracker &Reductions); + /// Stage 3: Assuming validate() returned true, perform the /// replacement. /// @param IterCount The maximum iteration count of L. void replace(const SCEV *IterCount); protected: - typedef MapVector<Instruction*, BitVector> UsesTy; + using UsesTy = MapVector<Instruction *, BitVector>; void findRootsRecursive(Instruction *IVU, SmallInstructionSet SubsumedInsts); @@ -412,22 +444,29 @@ namespace { // The loop induction variable. Instruction *IV; + // Loop step amount. 
int64_t Inc; + // Loop reroll count; if Inc == 1, this records the scaling applied // to the indvar: a[i*2+0] = ...; a[i*2+1] = ... ; // If Inc is not 1, Scale = Inc. uint64_t Scale; + // The roots themselves. SmallVector<DAGRootSet,16> RootSets; + // All increment instructions for IV. SmallInstructionVector LoopIncs; + // Map of all instructions in the loop (in order) to the iterations // they are used in (or specially, IL_All for instructions // used in the loop increment mechanism). UsesTy Uses; + // Map between induction variable and its increment DenseMap<Instruction *, int64_t> &IVToIncMap; + Instruction *LoopControlIV; }; @@ -446,9 +485,11 @@ namespace { bool reroll(Instruction *IV, Loop *L, BasicBlock *Header, const SCEV *IterCount, ReductionTracker &Reductions); }; -} + +} // end anonymous namespace char LoopReroll::ID = 0; + INITIALIZE_PASS_BEGIN(LoopReroll, "loop-reroll", "Reroll loops", false, false) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) @@ -1069,7 +1110,6 @@ bool LoopReroll::DAGRootTracker::collectUsedInstructions(SmallInstructionSet &Po } return true; - } /// Get the next instruction in "In" that is a member of set Val. @@ -1124,7 +1164,7 @@ static bool isIgnorableInst(const Instruction *I) { switch (II->getIntrinsicID()) { default: return false; - case llvm::Intrinsic::annotation: + case Intrinsic::annotation: case Intrinsic::ptr_annotation: case Intrinsic::var_annotation: // TODO: the following intrinsics may also be whitelisted: @@ -1407,8 +1447,8 @@ bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { BaseIt = nextInstr(0, Uses, Visited); RootIt = nextInstr(Iter, Uses, Visited); } - assert (BaseIt == Uses.end() && RootIt == Uses.end() && - "Mismatched set sizes!"); + assert(BaseIt == Uses.end() && RootIt == Uses.end() && + "Mismatched set sizes!"); } DEBUG(dbgs() << "LRR: Matched all iteration increments for " << diff --git a/lib/Transforms/Scalar/LoopRotation.cpp b/lib/Transforms/Scalar/LoopRotation.cpp index 3506ac343d59..a91f53ba663f 100644 --- a/lib/Transforms/Scalar/LoopRotation.cpp +++ b/lib/Transforms/Scalar/LoopRotation.cpp @@ -25,6 +25,7 @@ #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IntrinsicInst.h" @@ -141,37 +142,29 @@ static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader, // Replace MetadataAsValue(ValueAsMetadata(OrigHeaderVal)) uses in debug // intrinsics. - LLVMContext &C = OrigHeader->getContext(); - if (auto *VAM = ValueAsMetadata::getIfExists(OrigHeaderVal)) { - if (auto *MAV = MetadataAsValue::getIfExists(C, VAM)) { - for (auto UI = MAV->use_begin(), E = MAV->use_end(); UI != E;) { - // Grab the use before incrementing the iterator. Otherwise, altering - // the Use will invalidate the iterator. - Use &U = *UI++; - DbgInfoIntrinsic *UserInst = dyn_cast<DbgInfoIntrinsic>(U.getUser()); - if (!UserInst) - continue; - - // The original users in the OrigHeader are already using the original - // definitions. - BasicBlock *UserBB = UserInst->getParent(); - if (UserBB == OrigHeader) - continue; - - // Users in the OrigPreHeader need to use the value to which the - // original definitions are mapped and anything else can be handled by - // the SSAUpdater. To avoid adding PHINodes, check if the value is - // available in UserBB, if not substitute undef. 
- Value *NewVal; - if (UserBB == OrigPreheader) - NewVal = OrigPreHeaderVal; - else if (SSA.HasValueForBlock(UserBB)) - NewVal = SSA.GetValueInMiddleOfBlock(UserBB); - else - NewVal = UndefValue::get(OrigHeaderVal->getType()); - U = MetadataAsValue::get(C, ValueAsMetadata::get(NewVal)); - } - } + SmallVector<DbgValueInst *, 1> DbgValues; + llvm::findDbgValues(DbgValues, OrigHeaderVal); + for (auto &DbgValue : DbgValues) { + // The original users in the OrigHeader are already using the original + // definitions. + BasicBlock *UserBB = DbgValue->getParent(); + if (UserBB == OrigHeader) + continue; + + // Users in the OrigPreHeader need to use the value to which the + // original definitions are mapped and anything else can be handled by + // the SSAUpdater. To avoid adding PHINodes, check if the value is + // available in UserBB, if not substitute undef. + Value *NewVal; + if (UserBB == OrigPreheader) + NewVal = OrigPreHeaderVal; + else if (SSA.HasValueForBlock(UserBB)) + NewVal = SSA.GetValueInMiddleOfBlock(UserBB); + else + NewVal = UndefValue::get(OrigHeaderVal->getType()); + DbgValue->setOperand(0, + MetadataAsValue::get(OrigHeaderVal->getContext(), + ValueAsMetadata::get(NewVal))); } } } @@ -315,6 +308,22 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { // For the rest of the instructions, either hoist to the OrigPreheader if // possible or create a clone in the OldPreHeader if not. TerminatorInst *LoopEntryBranch = OrigPreheader->getTerminator(); + + // Record all debug intrinsics preceding LoopEntryBranch to avoid duplication. + using DbgIntrinsicHash = + std::pair<std::pair<Value *, DILocalVariable *>, DIExpression *>; + auto makeHash = [](DbgInfoIntrinsic *D) -> DbgIntrinsicHash { + return {{D->getVariableLocation(), D->getVariable()}, D->getExpression()}; + }; + SmallDenseSet<DbgIntrinsicHash, 8> DbgIntrinsics; + for (auto I = std::next(OrigPreheader->rbegin()), E = OrigPreheader->rend(); + I != E; ++I) { + if (auto *DII = dyn_cast<DbgInfoIntrinsic>(&*I)) + DbgIntrinsics.insert(makeHash(DII)); + else + break; + } + while (I != E) { Instruction *Inst = &*I++; @@ -338,6 +347,13 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { RemapInstruction(C, ValueMap, RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + // Avoid inserting the same intrinsic twice. + if (auto *DII = dyn_cast<DbgInfoIntrinsic>(C)) + if (DbgIntrinsics.count(makeHash(DII))) { + C->deleteValue(); + continue; + } + // With the operands remapped, see if the instruction constant folds or is // otherwise simplifyable. This commonly occurs because the entry from PHI // nodes allows icmps and other instructions to fold. @@ -395,6 +411,17 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { L->moveToHeader(NewHeader); assert(L->getHeader() == NewHeader && "Latch block is our new header"); + // Inform DT about changes to the CFG. + if (DT) { + // The OrigPreheader branches to the NewHeader and Exit now. Then, inform + // the DT about the removed edge to the OrigHeader (that got removed). + SmallVector<DominatorTree::UpdateType, 3> Updates; + Updates.push_back({DominatorTree::Insert, OrigPreheader, Exit}); + Updates.push_back({DominatorTree::Insert, OrigPreheader, NewHeader}); + Updates.push_back({DominatorTree::Delete, OrigPreheader, OrigHeader}); + DT->applyUpdates(Updates); + } + // At this point, we've finished our major CFG changes. 
As part of cloning // the loop into the preheader we've simplified instructions and the // duplicated conditional branch may now be branching on a constant. If it is @@ -408,26 +435,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { PHBI->getSuccessor(cast<ConstantInt>(PHBI->getCondition())->isZero()) != NewHeader) { // The conditional branch can't be folded, handle the general case. - // Update DominatorTree to reflect the CFG change we just made. Then split - // edges as necessary to preserve LoopSimplify form. - if (DT) { - // Everything that was dominated by the old loop header is now dominated - // by the original loop preheader. Conceptually the header was merged - // into the preheader, even though we reuse the actual block as a new - // loop latch. - DomTreeNode *OrigHeaderNode = DT->getNode(OrigHeader); - SmallVector<DomTreeNode *, 8> HeaderChildren(OrigHeaderNode->begin(), - OrigHeaderNode->end()); - DomTreeNode *OrigPreheaderNode = DT->getNode(OrigPreheader); - for (unsigned I = 0, E = HeaderChildren.size(); I != E; ++I) - DT->changeImmediateDominator(HeaderChildren[I], OrigPreheaderNode); - - assert(DT->getNode(Exit)->getIDom() == OrigPreheaderNode); - assert(DT->getNode(NewHeader)->getIDom() == OrigPreheaderNode); - - // Update OrigHeader to be dominated by the new header block. - DT->changeImmediateDominator(OrigHeader, OrigLatch); - } + // Split edges as necessary to preserve LoopSimplify form. // Right now OrigPreHeader has two successors, NewHeader and ExitBlock, and // thus is not a preheader anymore. @@ -467,52 +475,7 @@ bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { PHBI->eraseFromParent(); // With our CFG finalized, update DomTree if it is available. - if (DT) { - // Update OrigHeader to be dominated by the new header block. - DT->changeImmediateDominator(NewHeader, OrigPreheader); - DT->changeImmediateDominator(OrigHeader, OrigLatch); - - // Brute force incremental dominator tree update. Call - // findNearestCommonDominator on all CFG predecessors of each child of the - // original header. - DomTreeNode *OrigHeaderNode = DT->getNode(OrigHeader); - SmallVector<DomTreeNode *, 8> HeaderChildren(OrigHeaderNode->begin(), - OrigHeaderNode->end()); - bool Changed; - do { - Changed = false; - for (unsigned I = 0, E = HeaderChildren.size(); I != E; ++I) { - DomTreeNode *Node = HeaderChildren[I]; - BasicBlock *BB = Node->getBlock(); - - BasicBlock *NearestDom = nullptr; - for (BasicBlock *Pred : predecessors(BB)) { - // Consider only reachable basic blocks. - if (!DT->getNode(Pred)) - continue; - - if (!NearestDom) { - NearestDom = Pred; - continue; - } - - NearestDom = DT->findNearestCommonDominator(NearestDom, Pred); - assert(NearestDom && "No NearestCommonDominator found"); - } - - assert(NearestDom && "Nearest dominator not found"); - - // Remember if this changes the DomTree. - if (Node->getIDom()->getBlock() != NearestDom) { - DT->changeImmediateDominator(BB, NearestDom); - Changed = true; - } - } - - // If the dominator changed, this may have an effect on other - // predecessors, continue until we reach a fixpoint. 
- } while (Changed); - } + if (DT) DT->deleteEdge(OrigPreheader, Exit); } assert(L->getLoopPreheader() && "Invalid loop preheader after loop rotation"); @@ -671,7 +634,7 @@ bool LoopRotate::processLoop(Loop *L) { if ((MadeChange || SimplifiedLatch) && LoopMD) L->setLoopID(LoopMD); - return MadeChange; + return MadeChange || SimplifiedLatch; } LoopRotatePass::LoopRotatePass(bool EnableHeaderDuplication) diff --git a/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/lib/Transforms/Scalar/LoopStrengthReduce.cpp index 3638da118cb7..953854c8b7b7 100644 --- a/lib/Transforms/Scalar/LoopStrengthReduce.cpp +++ b/lib/Transforms/Scalar/LoopStrengthReduce.cpp @@ -65,7 +65,9 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/IVUsers.h" +#include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" @@ -80,13 +82,18 @@ #include "llvm/IR/Dominators.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/OperandTraits.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" @@ -98,7 +105,6 @@ #include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/LoopPassManager.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> @@ -107,8 +113,8 @@ #include <cstdint> #include <cstdlib> #include <iterator> +#include <limits> #include <map> -#include <tuple> #include <utility> using namespace llvm; @@ -131,7 +137,7 @@ static cl::opt<bool> EnablePhiElim( // The flag adds instruction count to solutions cost comparision. static cl::opt<bool> InsnsCost( - "lsr-insns-cost", cl::Hidden, cl::init(false), + "lsr-insns-cost", cl::Hidden, cl::init(true), cl::desc("Add instruction count to a LSR cost model")); // Flag to choose how to narrow complex lsr solution @@ -160,15 +166,14 @@ namespace { struct MemAccessTy { /// Used in situations where the accessed memory type is unknown. - static const unsigned UnknownAddressSpace = ~0u; + static const unsigned UnknownAddressSpace = + std::numeric_limits<unsigned>::max(); - Type *MemTy; - unsigned AddrSpace; + Type *MemTy = nullptr; + unsigned AddrSpace = UnknownAddressSpace; - MemAccessTy() : MemTy(nullptr), AddrSpace(UnknownAddressSpace) {} - - MemAccessTy(Type *Ty, unsigned AS) : - MemTy(Ty), AddrSpace(AS) {} + MemAccessTy() = default; + MemAccessTy(Type *Ty, unsigned AS) : MemTy(Ty), AddrSpace(AS) {} bool operator==(MemAccessTy Other) const { return MemTy == Other.MemTy && AddrSpace == Other.AddrSpace; @@ -195,11 +200,11 @@ public: } // end anonymous namespace +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void RegSortData::print(raw_ostream &OS) const { OS << "[NumUses=" << UsedByIndices.count() << ']'; } -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void RegSortData::dump() const { print(errs()); errs() << '\n'; } @@ -209,7 +214,7 @@ namespace { /// Map register candidates to information about how they are used. 
class RegUseTracker { - typedef DenseMap<const SCEV *, RegSortData> RegUsesTy; + using RegUsesTy = DenseMap<const SCEV *, RegSortData>; RegUsesTy RegUsesMap; SmallVector<const SCEV *, 16> RegSequence; @@ -225,8 +230,9 @@ public: void clear(); - typedef SmallVectorImpl<const SCEV *>::iterator iterator; - typedef SmallVectorImpl<const SCEV *>::const_iterator const_iterator; + using iterator = SmallVectorImpl<const SCEV *>::iterator; + using const_iterator = SmallVectorImpl<const SCEV *>::const_iterator; + iterator begin() { return RegSequence.begin(); } iterator end() { return RegSequence.end(); } const_iterator begin() const { return RegSequence.begin(); } @@ -299,16 +305,16 @@ namespace { /// satisfying a use. It may include broken-out immediates and scaled registers. struct Formula { /// Global base address used for complex addressing. - GlobalValue *BaseGV; + GlobalValue *BaseGV = nullptr; /// Base offset for complex addressing. - int64_t BaseOffset; + int64_t BaseOffset = 0; /// Whether any complex addressing has a base register. - bool HasBaseReg; + bool HasBaseReg = false; /// The scale of any complex addressing. - int64_t Scale; + int64_t Scale = 0; /// The list of "base" registers for this use. When this is non-empty. The /// canonical representation of a formula is @@ -328,16 +334,14 @@ struct Formula { /// The 'scaled' register for this use. This should be non-null when Scale is /// not zero. - const SCEV *ScaledReg; + const SCEV *ScaledReg = nullptr; /// An additional constant offset which added near the use. This requires a /// temporary register, but the offset itself can live in an add immediate /// field rather than a register. - int64_t UnfoldedOffset; + int64_t UnfoldedOffset = 0; - Formula() - : BaseGV(nullptr), BaseOffset(0), HasBaseReg(false), Scale(0), - ScaledReg(nullptr), UnfoldedOffset(0) {} + Formula() = default; void initialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE); @@ -562,6 +566,7 @@ bool Formula::hasRegsUsedByUsesOtherThan(size_t LUIdx, return false; } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void Formula::print(raw_ostream &OS) const { bool First = true; if (BaseGV) { @@ -598,7 +603,6 @@ void Formula::print(raw_ostream &OS) const { } } -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void Formula::dump() const { print(errs()); errs() << '\n'; } @@ -773,7 +777,8 @@ static GlobalValue *ExtractSymbol(const SCEV *&S, ScalarEvolution &SE) { /// Returns true if the specified instruction is using the specified value as an /// address. -static bool isAddressUse(Instruction *Inst, Value *OperandVal) { +static bool isAddressUse(const TargetTransformInfo &TTI, + Instruction *Inst, Value *OperandVal) { bool isAddress = isa<LoadInst>(Inst); if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { if (SI->getPointerOperand() == OperandVal) @@ -782,11 +787,24 @@ static bool isAddressUse(Instruction *Inst, Value *OperandVal) { // Addressing modes can also be folded into prefetches and a variety // of intrinsics. 
switch (II->getIntrinsicID()) { - default: break; - case Intrinsic::prefetch: - if (II->getArgOperand(0) == OperandVal) + case Intrinsic::memset: + case Intrinsic::prefetch: + if (II->getArgOperand(0) == OperandVal) + isAddress = true; + break; + case Intrinsic::memmove: + case Intrinsic::memcpy: + if (II->getArgOperand(0) == OperandVal || + II->getArgOperand(1) == OperandVal) + isAddress = true; + break; + default: { + MemIntrinsicInfo IntrInfo; + if (TTI.getTgtMemIntrinsic(II, IntrInfo)) { + if (IntrInfo.PtrVal == OperandVal) isAddress = true; - break; + } + } } } else if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(Inst)) { if (RMW->getPointerOperand() == OperandVal) @@ -799,7 +817,8 @@ static bool isAddressUse(Instruction *Inst, Value *OperandVal) { } /// Return the type of the memory being accessed. -static MemAccessTy getAccessType(const Instruction *Inst) { +static MemAccessTy getAccessType(const TargetTransformInfo &TTI, + Instruction *Inst) { MemAccessTy AccessTy(Inst->getType(), MemAccessTy::UnknownAddressSpace); if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) { AccessTy.MemTy = SI->getOperand(0)->getType(); @@ -810,6 +829,21 @@ static MemAccessTy getAccessType(const Instruction *Inst) { AccessTy.AddrSpace = RMW->getPointerAddressSpace(); } else if (const AtomicCmpXchgInst *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) { AccessTy.AddrSpace = CmpX->getPointerAddressSpace(); + } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + switch (II->getIntrinsicID()) { + case Intrinsic::prefetch: + AccessTy.AddrSpace = II->getArgOperand(0)->getType()->getPointerAddressSpace(); + break; + default: { + MemIntrinsicInfo IntrInfo; + if (TTI.getTgtMemIntrinsic(II, IntrInfo) && IntrInfo.PtrVal) { + AccessTy.AddrSpace + = IntrInfo.PtrVal->getType()->getPointerAddressSpace(); + } + + break; + } + } } // All pointers have the same requirements, so canonicalize them to an @@ -948,6 +982,7 @@ class LSRUse; /// accurate cost model. static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, const LSRUse &LU, const Formula &F); + // Get the cost of the scaling factor used in F for LU. static unsigned getScalingFactorCost(const TargetTransformInfo &TTI, const LSRUse &LU, const Formula &F, @@ -1013,30 +1048,30 @@ private: ScalarEvolution &SE, DominatorTree &DT, SmallPtrSetImpl<const SCEV *> *LoserRegs); }; - + /// An operand value in an instruction which is to be replaced with some /// equivalent, possibly strength-reduced, replacement. struct LSRFixup { /// The instruction which will be updated. - Instruction *UserInst; + Instruction *UserInst = nullptr; /// The operand of the instruction which will be replaced. The operand may be /// used more than once; every instance will be replaced. - Value *OperandValToReplace; + Value *OperandValToReplace = nullptr; /// If this user is to use the post-incremented value of an induction - /// variable, this variable is non-null and holds the loop associated with the + /// variable, this set is non-empty and holds the loops associated with the /// induction variable. PostIncLoopSet PostIncLoops; /// A constant offset to be added to the LSRUse expression. This allows /// multiple fixups to share the same LSRUse with different offsets, for /// example in an unrolled loop. 
- int64_t Offset; + int64_t Offset = 0; - bool isUseFullyOutsideLoop(const Loop *L) const; + LSRFixup() = default; - LSRFixup(); + bool isUseFullyOutsideLoop(const Loop *L) const; void print(raw_ostream &OS) const; void dump() const; @@ -1086,7 +1121,7 @@ public: // TODO: Add a generic icmp too? }; - typedef PointerIntPair<const SCEV *, 2, KindType> SCEVUseKindPair; + using SCEVUseKindPair = PointerIntPair<const SCEV *, 2, KindType>; KindType Kind; MemAccessTy AccessTy; @@ -1095,25 +1130,25 @@ public: SmallVector<LSRFixup, 8> Fixups; /// Keep track of the min and max offsets of the fixups. - int64_t MinOffset; - int64_t MaxOffset; + int64_t MinOffset = std::numeric_limits<int64_t>::max(); + int64_t MaxOffset = std::numeric_limits<int64_t>::min(); /// This records whether all of the fixups using this LSRUse are outside of /// the loop, in which case some special-case heuristics may be used. - bool AllFixupsOutsideLoop; + bool AllFixupsOutsideLoop = true; /// RigidFormula is set to true to guarantee that this use will be associated /// with a single formula--the one that initially matched. Some SCEV /// expressions cannot be expanded. This allows LSR to consider the registers /// used by those expressions without the need to expand them later after /// changing the formula. - bool RigidFormula; + bool RigidFormula = false; /// This records the widest use type for any fixup using this /// LSRUse. FindUseWithSimilarFormula can't consider uses with different max /// fixup widths to be equivalent, because the narrower one may be relying on /// the implicit truncation to truncate away bogus bits. - Type *WidestFixupType; + Type *WidestFixupType = nullptr; /// A list of ways to build a value that can satisfy this user. After the /// list is populated, one of these is selected heuristically and used to @@ -1123,10 +1158,7 @@ public: /// The set of register candidates used by all formulae in this LSRUse. SmallPtrSet<const SCEV *, 4> Regs; - LSRUse(KindType K, MemAccessTy AT) - : Kind(K), AccessTy(AT), MinOffset(INT64_MAX), MaxOffset(INT64_MIN), - AllFixupsOutsideLoop(true), RigidFormula(false), - WidestFixupType(nullptr) {} + LSRUse(KindType K, MemAccessTy AT) : Kind(K), AccessTy(AT) {} LSRFixup &getNewFixup() { Fixups.push_back(LSRFixup()); @@ -1140,7 +1172,7 @@ public: if (f.Offset < MinOffset) MinOffset = f.Offset; } - + bool HasFormulaWithSameRegs(const Formula &F) const; float getNotSelectedProbability(const SCEV *Reg) const; bool InsertFormula(const Formula &F, const Loop &L); @@ -1153,6 +1185,12 @@ public: } // end anonymous namespace +static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, + LSRUse::KindType Kind, MemAccessTy AccessTy, + GlobalValue *BaseGV, int64_t BaseOffset, + bool HasBaseReg, int64_t Scale, + Instruction *Fixup = nullptr); + /// Tally up interesting quantities from the given register. void Cost::RateRegister(const SCEV *Reg, SmallPtrSetImpl<const SCEV *> &Regs, @@ -1280,8 +1318,9 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, // Check with target if this offset with this instruction is // specifically not supported. - if ((isa<LoadInst>(Fixup.UserInst) || isa<StoreInst>(Fixup.UserInst)) && - !TTI.isFoldableMemAccessOffset(Fixup.UserInst, Offset)) + if (LU.Kind == LSRUse::Address && Offset != 0 && + !isAMCompletelyFolded(TTI, LSRUse::Address, LU.AccessTy, F.BaseGV, + Offset, F.HasBaseReg, F.Scale, Fixup.UserInst)) C.NumBaseAdds++; } @@ -1325,14 +1364,14 @@ void Cost::RateFormula(const TargetTransformInfo &TTI, /// Set this cost to a losing value. 
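[Editor's note] LSRUse now seeds MinOffset with the largest and MaxOffset with the smallest int64_t, so the very first pushFixup above is guaranteed to tighten both bounds. The running min/max idiom in isolation, with names of my own:

#include <cstdint>
#include <iostream>
#include <limits>

struct OffsetRange {
  int64_t Min = std::numeric_limits<int64_t>::max(); // any real offset is smaller
  int64_t Max = std::numeric_limits<int64_t>::min(); // any real offset is larger

  void push(int64_t Offset) {
    if (Offset > Max) Max = Offset;
    if (Offset < Min) Min = Offset;
  }
};

int main() {
  OffsetRange R;
  const int64_t Offsets[] = {16, -8, 4};
  for (int64_t O : Offsets)
    R.push(O);
  std::cout << R.Min << " " << R.Max << "\n"; // prints: -8 16
}
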
void Cost::Lose() { - C.Insns = ~0u; - C.NumRegs = ~0u; - C.AddRecCost = ~0u; - C.NumIVMuls = ~0u; - C.NumBaseAdds = ~0u; - C.ImmCost = ~0u; - C.SetupCost = ~0u; - C.ScaleCost = ~0u; + C.Insns = std::numeric_limits<unsigned>::max(); + C.NumRegs = std::numeric_limits<unsigned>::max(); + C.AddRecCost = std::numeric_limits<unsigned>::max(); + C.NumIVMuls = std::numeric_limits<unsigned>::max(); + C.NumBaseAdds = std::numeric_limits<unsigned>::max(); + C.ImmCost = std::numeric_limits<unsigned>::max(); + C.SetupCost = std::numeric_limits<unsigned>::max(); + C.ScaleCost = std::numeric_limits<unsigned>::max(); } /// Choose the lower cost. @@ -1343,6 +1382,7 @@ bool Cost::isLess(Cost &Other, const TargetTransformInfo &TTI) { return TTI.isLSRCostLess(C, Other.C); } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void Cost::print(raw_ostream &OS) const { if (InsnsCost) OS << C.Insns << " instruction" << (C.Insns == 1 ? " " : "s "); @@ -1363,16 +1403,11 @@ void Cost::print(raw_ostream &OS) const { OS << ", plus " << C.SetupCost << " setup cost"; } -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void Cost::dump() const { print(errs()); errs() << '\n'; } #endif -LSRFixup::LSRFixup() - : UserInst(nullptr), OperandValToReplace(nullptr), - Offset(0) {} - /// Test whether this fixup always uses its value outside of the given loop. bool LSRFixup::isUseFullyOutsideLoop(const Loop *L) const { // PHI nodes use their value in their incoming blocks. @@ -1387,6 +1422,7 @@ bool LSRFixup::isUseFullyOutsideLoop(const Loop *L) const { return !L->contains(UserInst); } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void LSRFixup::print(raw_ostream &OS) const { OS << "UserInst="; // Store is common and interesting enough to be worth special-casing. @@ -1410,7 +1446,6 @@ void LSRFixup::print(raw_ostream &OS) const { OS << ", Offset=" << Offset; } -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void LSRFixup::dump() const { print(errs()); errs() << '\n'; } @@ -1493,6 +1528,7 @@ void LSRUse::RecomputeRegs(size_t LUIdx, RegUseTracker &RegUses) { RegUses.dropRegister(S, LUIdx); } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void LSRUse::print(raw_ostream &OS) const { OS << "LSR Use: Kind="; switch (Kind) { @@ -1526,7 +1562,6 @@ void LSRUse::print(raw_ostream &OS) const { OS << ", widest fixup type: " << *WidestFixupType; } -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void LSRUse::dump() const { print(errs()); errs() << '\n'; } @@ -1535,11 +1570,12 @@ LLVM_DUMP_METHOD void LSRUse::dump() const { static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, LSRUse::KindType Kind, MemAccessTy AccessTy, GlobalValue *BaseGV, int64_t BaseOffset, - bool HasBaseReg, int64_t Scale) { + bool HasBaseReg, int64_t Scale, + Instruction *Fixup/*= nullptr*/) { switch (Kind) { case LSRUse::Address: return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, BaseOffset, - HasBaseReg, Scale, AccessTy.AddrSpace); + HasBaseReg, Scale, AccessTy.AddrSpace, Fixup); case LSRUse::ICmpZero: // There's not even a target hook for querying whether it would be legal to @@ -1564,7 +1600,8 @@ static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, // ICmpZero -1*ScaleReg + BaseOffset => ICmp ScaleReg, BaseOffset // Offs is the ICmp immediate. if (Scale == 0) - // The cast does the right thing with INT64_MIN. + // The cast does the right thing with + // std::numeric_limits<int64_t>::min(). 
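[Editor's note] The cast in the statement that follows is load-bearing: negating INT64_MIN as a signed value is undefined behaviour, but negating it after converting to uint64_t wraps to 2^63, which converts back to INT64_MIN on the two's-complement targets LLVM cares about; every other offset simply comes out negated. A tiny standalone illustration (variable names are mine):

#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  int64_t BaseOffset = std::numeric_limits<int64_t>::min();
  // int64_t Bad = -BaseOffset;               // signed overflow: undefined behaviour
  int64_t Negated = -(uint64_t)BaseOffset;    // well defined: wraps to 2^63, then back
  std::cout << (Negated == BaseOffset) << "\n"; // 1: INT64_MIN is its own "negation"
}
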
BaseOffset = -(uint64_t)BaseOffset; return TTI.isLegalICmpImmediate(BaseOffset); } @@ -1645,6 +1682,16 @@ static bool isLegalUse(const TargetTransformInfo &TTI, int64_t MinOffset, static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, const LSRUse &LU, const Formula &F) { + // Target may want to look at the user instructions. + if (LU.Kind == LSRUse::Address && TTI.LSRWithInstrQueries()) { + for (const LSRFixup &Fixup : LU.Fixups) + if (!isAMCompletelyFolded(TTI, LSRUse::Address, LU.AccessTy, F.BaseGV, + (F.BaseOffset + Fixup.Offset), F.HasBaseReg, + F.Scale, Fixup.UserInst)) + return false; + return true; + } + return isAMCompletelyFolded(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F.BaseGV, F.BaseOffset, F.HasBaseReg, F.Scale); @@ -1752,22 +1799,21 @@ struct IVInc { Value* IVOperand; const SCEV *IncExpr; - IVInc(Instruction *U, Value *O, const SCEV *E): - UserInst(U), IVOperand(O), IncExpr(E) {} + IVInc(Instruction *U, Value *O, const SCEV *E) + : UserInst(U), IVOperand(O), IncExpr(E) {} }; // The list of IV increments in program order. We typically add the head of a // chain without finding subsequent links. struct IVChain { - SmallVector<IVInc,1> Incs; - const SCEV *ExprBase; - - IVChain() : ExprBase(nullptr) {} + SmallVector<IVInc, 1> Incs; + const SCEV *ExprBase = nullptr; + IVChain() = default; IVChain(const IVInc &Head, const SCEV *Base) - : Incs(1, Head), ExprBase(Base) {} + : Incs(1, Head), ExprBase(Base) {} - typedef SmallVectorImpl<IVInc>::const_iterator const_iterator; + using const_iterator = SmallVectorImpl<IVInc>::const_iterator; // Return the first increment in the chain. const_iterator begin() const { @@ -1809,13 +1855,13 @@ class LSRInstance { LoopInfo &LI; const TargetTransformInfo &TTI; Loop *const L; - bool Changed; + bool Changed = false; /// This is the insert position that the current loop's induction variable /// increment should be placed. In simple loops, this is the latch block's /// terminator. But in more complicated cases, this is a position which will /// dominate all the in-loop post-increment users. - Instruction *IVIncInsertPos; + Instruction *IVIncInsertPos = nullptr; /// Interesting factors between use strides. /// @@ -1861,7 +1907,7 @@ class LSRInstance { void CollectFixupsAndInitialFormulae(); // Support for sharing of LSRUses between LSRFixups. - typedef DenseMap<LSRUse::SCEVUseKindPair, size_t> UseMapTy; + using UseMapTy = DenseMap<LSRUse::SCEVUseKindPair, size_t>; UseMapTy UseMap; bool reconcileNewOffset(LSRUse &LU, int64_t NewOffset, bool HasBaseReg, @@ -2002,6 +2048,14 @@ void LSRInstance::OptimizeShadowIV() { if (!PH) continue; if (PH->getNumIncomingValues() != 2) continue; + // If the calculation in integers overflows, the result in FP type will + // differ. So we only can do this transformation if we are guaranteed to not + // deal with overflowing values + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(PH)); + if (!AR) continue; + if (IsSigned && !AR->hasNoSignedWrap()) continue; + if (!IsSigned && !AR->hasNoUnsignedWrap()) continue; + Type *SrcTy = PH->getType(); int Mantissa = DestTy->getFPMantissaWidth(); if (Mantissa == -1) continue; @@ -2094,7 +2148,7 @@ bool LSRInstance::FindIVUserForCond(ICmpInst *Cond, IVStrideUse *&CondUse) { /// unfortunately this can come up even for loops where the user didn't use /// a C do-while loop. 
For example, seemingly well-behaved top-test loops /// will commonly be lowered like this: -// +/// /// if (n > 0) { /// i = 0; /// do { @@ -2128,7 +2182,6 @@ bool LSRInstance::FindIVUserForCond(ICmpInst *Cond, IVStrideUse *&CondUse) { /// This function solves this problem by detecting this type of loop and /// rewriting their conditions from ICMP_NE back to ICMP_SLT, and deleting /// the instructions for the maximum computation. -/// ICmpInst *LSRInstance::OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse) { // Check that the loop matches the pattern we're looking for. if (Cond->getPredicate() != CmpInst::ICMP_EQ && @@ -2268,7 +2321,6 @@ LSRInstance::OptimizeLoopTermCond() { // Otherwise treat this as a rotated loop. for (BasicBlock *ExitingBlock : ExitingBlocks) { - // Get the terminating condition for the loop if possible. If we // can, we want to change it to use a post-incremented version of its // induction variable, to allow coalescing the live ranges for the IV into @@ -2333,7 +2385,7 @@ LSRInstance::OptimizeLoopTermCond() { C->getValue().isMinSignedValue()) goto decline_post_inc; // Check for possible scaled-address reuse. - MemAccessTy AccessTy = getAccessType(UI->getUser()); + MemAccessTy AccessTy = getAccessType(TTI, UI->getUser()); int64_t Scale = C->getSExtValue(); if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr, /*BaseOffset=*/0, @@ -2941,7 +2993,7 @@ void LSRInstance::CollectChains() { // consider leaf IV Users. This effectively rediscovers a portion of // IVUsers analysis but in program order this time. if (SE.isSCEVable(I.getType()) && !isa<SCEVUnknown>(SE.getSCEV(&I))) - continue; + continue; // Remove this instruction from any NearUsers set it may be in. for (unsigned ChainIdx = 0, NChains = IVChainVec.size(); @@ -3003,13 +3055,13 @@ void LSRInstance::FinalizeChain(IVChain &Chain) { static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst, Value *Operand, const TargetTransformInfo &TTI) { const SCEVConstant *IncConst = dyn_cast<SCEVConstant>(IncExpr); - if (!IncConst || !isAddressUse(UserInst, Operand)) + if (!IncConst || !isAddressUse(TTI, UserInst, Operand)) return false; if (IncConst->getAPInt().getMinSignedBits() > 64) return false; - MemAccessTy AccessTy = getAccessType(UserInst); + MemAccessTy AccessTy = getAccessType(TTI, UserInst); int64_t IncOffset = IncConst->getValue()->getSExtValue(); if (!isAlwaysFoldable(TTI, LSRUse::Address, AccessTy, /*BaseGV=*/nullptr, IncOffset, /*HaseBaseReg=*/false)) @@ -3136,14 +3188,14 @@ void LSRInstance::CollectFixupsAndInitialFormulae() { LSRUse::KindType Kind = LSRUse::Basic; MemAccessTy AccessTy; - if (isAddressUse(UserInst, U.getOperandValToReplace())) { + if (isAddressUse(TTI, UserInst, U.getOperandValToReplace())) { Kind = LSRUse::Address; - AccessTy = getAccessType(UserInst); + AccessTy = getAccessType(TTI, UserInst); } const SCEV *S = IU.getExpr(U); PostIncLoopSet TmpPostIncLoops = U.getPostIncLoops(); - + // Equality (== and !=) ICmps are special. We can rewrite (i == N) as // (N - i == 0), and this allows (N - i) to be the expression that we work // with rather than just N or i, so we can consider the register @@ -3432,7 +3484,6 @@ void LSRInstance::GenerateReassociationsImpl(LSRUse &LU, unsigned LUIdx, for (SmallVectorImpl<const SCEV *>::const_iterator J = AddOps.begin(), JE = AddOps.end(); J != JE; ++J) { - // Loop-variant "unknown" values are uninteresting; we won't be able to // do anything meaningful with them. 
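[Editor's note] The OptimizeMax comment above is easier to follow next to concrete source: plain C++ versions of the two loop shapes it describes, the loop as written and the guarded, rotated do-while the compiler produces (illustrative only, not generated code):

// As written: a top-test loop with trip count n.
void zero_top_test(int *p, int n) {
  for (int i = 0; i < n; ++i)
    p[i] = 0;
}

// After guarding and rotation: the body runs at least once under the guard,
// and the exit test becomes an inequality that indvars may rewrite against a
// max()-style trip count; OptimizeMax tries to undo exactly that rewrite.
void zero_rotated(int *p, int n) {
  if (n > 0) {
    int i = 0;
    do {
      p[i] = 0;
      ++i;
    } while (i != n);
  }
}
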
if (isa<SCEVUnknown>(*J) && !SE.isLoopInvariant(*J, L)) @@ -3654,12 +3705,18 @@ void LSRInstance::GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx, // Don't do this if there is more than one offset. if (LU.MinOffset != LU.MaxOffset) return; + // Check if transformation is valid. It is illegal to multiply pointer. + if (Base.ScaledReg && Base.ScaledReg->getType()->isPointerTy()) + return; + for (const SCEV *BaseReg : Base.BaseRegs) + if (BaseReg->getType()->isPointerTy()) + return; assert(!Base.BaseGV && "ICmpZero use is not legal!"); // Check each interesting stride. for (int64_t Factor : Factors) { // Check that the multiplication doesn't overflow. - if (Base.BaseOffset == INT64_MIN && Factor == -1) + if (Base.BaseOffset == std::numeric_limits<int64_t>::min() && Factor == -1) continue; int64_t NewBaseOffset = (uint64_t)Base.BaseOffset * Factor; if (NewBaseOffset / Factor != Base.BaseOffset) @@ -3671,7 +3728,7 @@ void LSRInstance::GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx, // Check that multiplying with the use offset doesn't overflow. int64_t Offset = LU.MinOffset; - if (Offset == INT64_MIN && Factor == -1) + if (Offset == std::numeric_limits<int64_t>::min() && Factor == -1) continue; Offset = (uint64_t)Offset * Factor; if (Offset / Factor != LU.MinOffset) @@ -3709,7 +3766,8 @@ void LSRInstance::GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx, // Check that multiplying with the unfolded offset doesn't overflow. if (F.UnfoldedOffset != 0) { - if (F.UnfoldedOffset == INT64_MIN && Factor == -1) + if (F.UnfoldedOffset == std::numeric_limits<int64_t>::min() && + Factor == -1) continue; F.UnfoldedOffset = (uint64_t)F.UnfoldedOffset * Factor; if (F.UnfoldedOffset / Factor != Base.UnfoldedOffset) @@ -3833,7 +3891,7 @@ struct WorkItem { const SCEV *OrigReg; WorkItem(size_t LI, int64_t I, const SCEV *R) - : LUIdx(LI), Imm(I), OrigReg(R) {} + : LUIdx(LI), Imm(I), OrigReg(R) {} void print(raw_ostream &OS) const; void dump() const; @@ -3841,12 +3899,12 @@ struct WorkItem { } // end anonymous namespace +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void WorkItem::print(raw_ostream &OS) const { OS << "in formulae referencing " << *OrigReg << " in use " << LUIdx << " , add offset " << Imm; } -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void WorkItem::dump() const { print(errs()); errs() << '\n'; } @@ -3856,7 +3914,8 @@ LLVM_DUMP_METHOD void WorkItem::dump() const { /// opportunities between them. void LSRInstance::GenerateCrossUseConstantOffsets() { // Group the registers by their value without any added constant offset. - typedef std::map<int64_t, const SCEV *> ImmMapTy; + using ImmMapTy = std::map<int64_t, const SCEV *>; + DenseMap<const SCEV *, ImmMapTy> Map; DenseMap<const SCEV *, SmallBitVector> UsedByIndicesMap; SmallVector<const SCEV *, 8> Sequence; @@ -4060,8 +4119,9 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() { // Collect the best formula for each unique set of shared registers. This // is reset for each use. - typedef DenseMap<SmallVector<const SCEV *, 4>, size_t, UniquifierDenseMapInfo> - BestFormulaeTy; + using BestFormulaeTy = + DenseMap<SmallVector<const SCEV *, 4>, size_t, UniquifierDenseMapInfo>; + BestFormulaeTy BestFormulae; for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { @@ -4148,7 +4208,7 @@ void LSRInstance::FilterOutUndesirableDedicatedRegisters() { } // This is a rough guess that seems to work fairly well. 
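[Editor's note] GenerateICmpZeroScales guards every scaled value the same way: special-case INT64_MIN * -1 (whose result is unrepresentable), do the multiply in uint64_t so wraparound is not undefined, then divide back to detect that wraparound. The guard on its own, with hypothetical names:

#include <cstdint>
#include <iostream>
#include <limits>

// Returns false when Base * Factor would overflow int64_t, using the same two
// checks as the hunk above.
static bool scaleNoOverflow(int64_t Base, int64_t Factor, int64_t &Result) {
  if (Base == std::numeric_limits<int64_t>::min() && Factor == -1)
    return false;                            // -INT64_MIN is not representable
  int64_t Scaled = (uint64_t)Base * Factor;  // wraps instead of invoking UB
  if (Factor != 0 && Scaled / Factor != Base)
    return false;                            // wrapped: product doesn't divide back
  Result = Scaled;
  return true;
}

int main() {
  int64_t R;
  std::cout << scaleNoOverflow(int64_t(1) << 20, 8, R) << " "                        // 1
            << scaleNoOverflow(std::numeric_limits<int64_t>::max(), 2, R) << "\n";   // 0
}
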
-static const size_t ComplexityLimit = UINT16_MAX; +static const size_t ComplexityLimit = std::numeric_limits<uint16_t>::max(); /// Estimate the worst-case number of solutions the solver might have to /// consider. It almost never considers this many solutions because it prune the @@ -4267,7 +4327,7 @@ void LSRInstance::NarrowSearchSpaceByCollapsingUnrolledCode() { LUThatHas->pushFixup(Fixup); DEBUG(dbgs() << "New fixup has offset " << Fixup.Offset << '\n'); } - + // Delete formulae from the new use which are no longer legal. bool Any = false; for (size_t i = 0, e = LUThatHas->Formulae.size(); i != e; ++i) { @@ -4332,7 +4392,8 @@ void LSRInstance::NarrowSearchSpaceByFilterFormulaWithSameScaledReg() { "from the Formulae with the same Scale and ScaledReg.\n"); // Map the "Scale * ScaledReg" pair to the best formula of current LSRUse. - typedef DenseMap<std::pair<const SCEV *, int64_t>, size_t> BestFormulaeTy; + using BestFormulaeTy = DenseMap<std::pair<const SCEV *, int64_t>, size_t>; + BestFormulaeTy BestFormulae; #ifndef NDEBUG bool ChangedFormulae = false; @@ -4454,7 +4515,6 @@ void LSRInstance::NarrowSearchSpaceByFilterFormulaWithSameScaledReg() { /// Use3: /// reg(c) + reg(b) + reg({0,+,1}) 1 + 1/3 + 4/9 -- to be deleted /// reg(c) + reg({b,+,1}) 1 + 2/3 - void LSRInstance::NarrowSearchSpaceByDeletingCostlyFormulas() { if (EstimateSearchSpaceComplexity() < ComplexityLimit) return; @@ -4549,7 +4609,6 @@ void LSRInstance::NarrowSearchSpaceByDeletingCostlyFormulas() { print_uses(dbgs())); } - /// Pick a register which seems likely to be profitable, and then in any use /// which has any reference to that register, delete all formulae which do not /// reference that register. @@ -5196,8 +5255,7 @@ void LSRInstance::ImplementSolution( LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, DominatorTree &DT, LoopInfo &LI, const TargetTransformInfo &TTI) - : IU(IU), SE(SE), DT(DT), LI(LI), TTI(TTI), L(L), Changed(false), - IVIncInsertPos(nullptr) { + : IU(IU), SE(SE), DT(DT), LI(LI), TTI(TTI), L(L) { // If LoopSimplify form is not available, stay out of trouble. 
if (!L->isLoopSimplifyForm()) return; @@ -5302,6 +5360,7 @@ LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, ImplementSolution(Solution); } +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) void LSRInstance::print_factors_and_types(raw_ostream &OS) const { if (Factors.empty() && Types.empty()) return; @@ -5352,7 +5411,6 @@ void LSRInstance::print(raw_ostream &OS) const { print_uses(OS); } -#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) LLVM_DUMP_METHOD void LSRInstance::dump() const { print(errs()); errs() << '\n'; } @@ -5448,6 +5506,7 @@ PreservedAnalyses LoopStrengthReducePass::run(Loop &L, LoopAnalysisManager &AM, } char LoopStrengthReduce::ID = 0; + INITIALIZE_PASS_BEGIN(LoopStrengthReduce, "loop-reduce", "Loop Strength Reduction", false, false) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) diff --git a/lib/Transforms/Scalar/LoopUnrollPass.cpp b/lib/Transforms/Scalar/LoopUnrollPass.cpp index 530a68424d5c..7b1d6446a24a 100644 --- a/lib/Transforms/Scalar/LoopUnrollPass.cpp +++ b/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -1,4 +1,4 @@ -//===-- LoopUnroll.cpp - Loop unroller pass -------------------------------===// +//===- LoopUnroll.cpp - Loop unroller pass --------------------------------===// // // The LLVM Compiler Infrastructure // @@ -13,29 +13,55 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/LoopUnrollPass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/CodeMetrics.h" -#include "llvm/Analysis/GlobalsModRef.h" -#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/LoopUnrollAnalyzer.h" -#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/IR/DataLayout.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Metadata.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/UnrollLoop.h" -#include <climits> +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <limits> +#include <string> +#include <tuple> #include <utility> using namespace llvm; @@ -79,6 +105,10 @@ static cl::opt<unsigned> UnrollFullMaxCount( 
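[Editor's note] The testing knobs this pass grows just below (-unroll-peel-count, -unroll-remainder) follow the standard cl::opt pattern, where getNumOccurrences() distinguishes "left at its default" from "explicitly passed on the command line". A sketch with a made-up flag name, assuming it is built as part of a tool linked against LLVM's Support library:

#include "llvm/Support/CommandLine.h"

using namespace llvm;

// Hypothetical testing knob, mirroring the style of -unroll-peel-count.
static cl::opt<unsigned> ExamplePeelCount(
    "example-peel-count", cl::Hidden, cl::init(0),
    cl::desc("Set the peeling count, for testing purposes"));

unsigned pickPeelCount(unsigned Computed) {
  // Only let the flag win if the user actually passed it; otherwise keep the
  // value computed elsewhere (TTI hooks, pass parameters, ...).
  if (ExamplePeelCount.getNumOccurrences() > 0)
    return ExamplePeelCount;
  return Computed;
}
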
cl::desc( "Set the max unroll count for full unrolling, for testing purposes")); +static cl::opt<unsigned> UnrollPeelCount( + "unroll-peel-count", cl::Hidden, + cl::desc("Set the unroll peeling count, for testing purposes")); + static cl::opt<bool> UnrollAllowPartial("unroll-allow-partial", cl::Hidden, cl::desc("Allows loops to be partially unrolled until " @@ -114,6 +144,10 @@ static cl::opt<bool> cl::desc("Allows loops to be peeled when the dynamic " "trip count is known to be low.")); +static cl::opt<bool> UnrollUnrollRemainder( + "unroll-remainder", cl::Hidden, + cl::desc("Allow the loop remainder to be unrolled.")); + // This option isn't ever intended to be enabled, it serves to allow // experiments to check the assumptions about when this kind of revisit is // necessary. @@ -126,7 +160,7 @@ static cl::opt<bool> UnrollRevisitChildLoops( /// A magic value for use with the Threshold parameter to indicate /// that the loop unroll should be performed regardless of how much /// code expansion would result. -static const unsigned NoThreshold = UINT_MAX; +static const unsigned NoThreshold = std::numeric_limits<unsigned>::max(); /// Gather the various unrolling parameters based on the defaults, compiler /// flags, TTI overrides and user specified parameters. @@ -134,7 +168,7 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( Loop *L, ScalarEvolution &SE, const TargetTransformInfo &TTI, int OptLevel, Optional<unsigned> UserThreshold, Optional<unsigned> UserCount, Optional<bool> UserAllowPartial, Optional<bool> UserRuntime, - Optional<bool> UserUpperBound) { + Optional<bool> UserUpperBound, Optional<bool> UserAllowPeeling) { TargetTransformInfo::UnrollingPreferences UP; // Set up the defaults @@ -146,12 +180,13 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( UP.Count = 0; UP.PeelCount = 0; UP.DefaultUnrollRuntimeCount = 8; - UP.MaxCount = UINT_MAX; - UP.FullUnrollMaxCount = UINT_MAX; + UP.MaxCount = std::numeric_limits<unsigned>::max(); + UP.FullUnrollMaxCount = std::numeric_limits<unsigned>::max(); UP.BEInsns = 2; UP.Partial = false; UP.Runtime = false; UP.AllowRemainder = true; + UP.UnrollRemainder = false; UP.AllowExpensiveTripCount = false; UP.Force = false; UP.UpperBound = false; @@ -177,6 +212,8 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( UP.MaxCount = UnrollMaxCount; if (UnrollFullMaxCount.getNumOccurrences() > 0) UP.FullUnrollMaxCount = UnrollFullMaxCount; + if (UnrollPeelCount.getNumOccurrences() > 0) + UP.PeelCount = UnrollPeelCount; if (UnrollAllowPartial.getNumOccurrences() > 0) UP.Partial = UnrollAllowPartial; if (UnrollAllowRemainder.getNumOccurrences() > 0) @@ -187,6 +224,8 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( UP.UpperBound = false; if (UnrollAllowPeeling.getNumOccurrences() > 0) UP.AllowPeeling = UnrollAllowPeeling; + if (UnrollUnrollRemainder.getNumOccurrences() > 0) + UP.UnrollRemainder = UnrollUnrollRemainder; // Apply user values provided by argument if (UserThreshold.hasValue()) { @@ -201,11 +240,14 @@ static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( UP.Runtime = *UserRuntime; if (UserUpperBound.hasValue()) UP.UpperBound = *UserUpperBound; + if (UserAllowPeeling.hasValue()) + UP.AllowPeeling = *UserAllowPeeling; return UP; } namespace { + /// A struct to densely store the state of an instruction after unrolling at /// each iteration. 
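[Editor's note] gatherUnrollingPreferences layers its sources: baked-in defaults, then target hooks, then command-line flags, and finally explicit per-caller values, which only apply when the caller actually made a choice. A reduced sketch of that last step using llvm::Optional, with the middle layers elided and the struct invented for illustration:

#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"

using namespace llvm;

struct Prefs {
  bool AllowPeeling = true;   // 1. baseline default
};

Prefs gatherPrefs(Optional<bool> UserAllowPeeling) {
  Prefs P;
  // 2./3. target hooks and cl::opt overrides would be applied here ...
  if (UserAllowPeeling.hasValue())
    P.AllowPeeling = *UserAllowPeeling;   // 4. an explicit caller request wins
  return P;
}

// gatherPrefs(None)  -> keeps whatever the earlier layers decided
// gatherPrefs(false) -> caller forced peeling off (e.g. huge working set size)
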
/// @@ -221,25 +263,27 @@ struct UnrolledInstState { /// Hashing and equality testing for a set of the instruction states. struct UnrolledInstStateKeyInfo { - typedef DenseMapInfo<Instruction *> PtrInfo; - typedef DenseMapInfo<std::pair<Instruction *, int>> PairInfo; + using PtrInfo = DenseMapInfo<Instruction *>; + using PairInfo = DenseMapInfo<std::pair<Instruction *, int>>; + static inline UnrolledInstState getEmptyKey() { return {PtrInfo::getEmptyKey(), 0, 0, 0}; } + static inline UnrolledInstState getTombstoneKey() { return {PtrInfo::getTombstoneKey(), 0, 0, 0}; } + static inline unsigned getHashValue(const UnrolledInstState &S) { return PairInfo::getHashValue({S.I, S.Iteration}); } + static inline bool isEqual(const UnrolledInstState &LHS, const UnrolledInstState &RHS) { return PairInfo::isEqual({LHS.I, LHS.Iteration}, {RHS.I, RHS.Iteration}); } }; -} -namespace { struct EstimatedUnrollCost { /// \brief The estimated cost after unrolling. unsigned UnrolledCost; @@ -248,7 +292,8 @@ struct EstimatedUnrollCost { /// rolled form. unsigned RolledDynamicCost; }; -} + +} // end anonymous namespace /// \brief Figure out if the loop is worth full unrolling. /// @@ -270,7 +315,8 @@ analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT, // We want to be able to scale offsets by the trip count and add more offsets // to them without checking for overflows, and we already don't want to // analyze *massive* trip counts, so we force the max to be reasonably small. - assert(UnrollMaxIterationsCountToAnalyze < (INT_MAX / 2) && + assert(UnrollMaxIterationsCountToAnalyze < + (unsigned)(std::numeric_limits<int>::max() / 2) && "The unroll iterations max is too large!"); // Only analyze inner loops. We can't properly estimate cost of nested loops @@ -633,43 +679,6 @@ static unsigned UnrollCountPragmaValue(const Loop *L) { return 0; } -// Remove existing unroll metadata and add unroll disable metadata to -// indicate the loop has already been unrolled. This prevents a loop -// from being unrolled more than is directed by a pragma if the loop -// unrolling pass is run more than once (which it generally is). -static void SetLoopAlreadyUnrolled(Loop *L) { - MDNode *LoopID = L->getLoopID(); - // First remove any existing loop unrolling metadata. - SmallVector<Metadata *, 4> MDs; - // Reserve first location for self reference to the LoopID metadata node. - MDs.push_back(nullptr); - - if (LoopID) { - for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { - bool IsUnrollMetadata = false; - MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); - if (MD) { - const MDString *S = dyn_cast<MDString>(MD->getOperand(0)); - IsUnrollMetadata = S && S->getString().startswith("llvm.loop.unroll."); - } - if (!IsUnrollMetadata) - MDs.push_back(LoopID->getOperand(i)); - } - } - - // Add unroll(disable) metadata to disable future unrolling. - LLVMContext &Context = L->getHeader()->getContext(); - SmallVector<Metadata *, 1> DisableOperands; - DisableOperands.push_back(MDString::get(Context, "llvm.loop.unroll.disable")); - MDNode *DisableNode = MDNode::get(Context, DisableOperands); - MDs.push_back(DisableNode); - - MDNode *NewLoopID = MDNode::get(Context, MDs); - // Set operand 0 to refer to the loop id itself. - NewLoopID->replaceOperandWith(0, NewLoopID); - L->setLoopID(NewLoopID); -} - // Computes the boosting factor for complete unrolling. // If fully unrolling the loop would save a lot of RolledDynamicCost, it would // be beneficial to fully unroll the loop even if unrolledcost is large. 
We @@ -677,7 +686,7 @@ static void SetLoopAlreadyUnrolled(Loop *L) { // the unroll threshold. static unsigned getFullUnrollBoostingFactor(const EstimatedUnrollCost &Cost, unsigned MaxPercentThresholdBoost) { - if (Cost.RolledDynamicCost >= UINT_MAX / 100) + if (Cost.RolledDynamicCost >= std::numeric_limits<unsigned>::max() / 100) return 100; else if (Cost.UnrolledCost != 0) // The boosting factor is RolledDynamicCost / UnrolledCost @@ -826,11 +835,14 @@ static bool computeUnrollCount( } if (UP.Count < 2) { if (PragmaEnableUnroll) - ORE->emit( - OptimizationRemarkMissed(DEBUG_TYPE, "UnrollAsDirectedTooLarge", - L->getStartLoc(), L->getHeader()) - << "Unable to unroll loop as directed by unroll(enable) pragma " - "because unrolled size is too large."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, + "UnrollAsDirectedTooLarge", + L->getStartLoc(), L->getHeader()) + << "Unable to unroll loop as directed by unroll(enable) " + "pragma " + "because unrolled size is too large."; + }); UP.Count = 0; } } else { @@ -840,22 +852,27 @@ static bool computeUnrollCount( UP.Count = UP.MaxCount; if ((PragmaFullUnroll || PragmaEnableUnroll) && TripCount && UP.Count != TripCount) - ORE->emit( - OptimizationRemarkMissed(DEBUG_TYPE, "FullUnrollAsDirectedTooLarge", - L->getStartLoc(), L->getHeader()) - << "Unable to fully unroll loop as directed by unroll pragma because " - "unrolled size is too large."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, + "FullUnrollAsDirectedTooLarge", + L->getStartLoc(), L->getHeader()) + << "Unable to fully unroll loop as directed by unroll pragma " + "because " + "unrolled size is too large."; + }); return ExplicitUnroll; } assert(TripCount == 0 && "All cases when TripCount is constant should be covered here."); if (PragmaFullUnroll) - ORE->emit( - OptimizationRemarkMissed(DEBUG_TYPE, - "CantFullUnrollAsDirectedRuntimeTripCount", - L->getStartLoc(), L->getHeader()) - << "Unable to fully unroll loop as directed by unroll(full) pragma " - "because loop has a runtime trip count."); + ORE->emit([&]() { + return OptimizationRemarkMissed( + DEBUG_TYPE, "CantFullUnrollAsDirectedRuntimeTripCount", + L->getStartLoc(), L->getHeader()) + << "Unable to fully unroll loop as directed by unroll(full) " + "pragma " + "because loop has a runtime trip count."; + }); // 6th priority is runtime unrolling. // Don't unroll a runtime trip count loop when it is disabled. @@ -904,19 +921,23 @@ static bool computeUnrollCount( "multiple, " << TripMultiple << ". Reducing unroll count from " << OrigCount << " to " << UP.Count << ".\n"); + using namespace ore; + if (PragmaCount > 0 && !UP.AllowRemainder) - ORE->emit( - OptimizationRemarkMissed(DEBUG_TYPE, - "DifferentUnrollCountFromDirected", - L->getStartLoc(), L->getHeader()) - << "Unable to unroll loop the number of times directed by " - "unroll_count pragma because remainder loop is restricted " - "(that could architecture specific or because the loop " - "contains a convergent instruction) and so must have an unroll " - "count that divides the loop trip multiple of " - << NV("TripMultiple", TripMultiple) << ". 
Unrolling instead " - << NV("UnrollCount", UP.Count) << " time(s)."); + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, + "DifferentUnrollCountFromDirected", + L->getStartLoc(), L->getHeader()) + << "Unable to unroll loop the number of times directed by " + "unroll_count pragma because remainder loop is restricted " + "(that could architecture specific or because the loop " + "contains a convergent instruction) and so must have an " + "unroll " + "count that divides the loop trip multiple of " + << NV("TripMultiple", TripMultiple) << ". Unrolling instead " + << NV("UnrollCount", UP.Count) << " time(s)."; + }); } if (UP.Count > UP.MaxCount) @@ -927,23 +948,21 @@ static bool computeUnrollCount( return ExplicitUnroll; } -static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, - ScalarEvolution &SE, const TargetTransformInfo &TTI, - AssumptionCache &AC, OptimizationRemarkEmitter &ORE, - bool PreserveLCSSA, int OptLevel, - Optional<unsigned> ProvidedCount, - Optional<unsigned> ProvidedThreshold, - Optional<bool> ProvidedAllowPartial, - Optional<bool> ProvidedRuntime, - Optional<bool> ProvidedUpperBound) { +static LoopUnrollResult tryToUnrollLoop( + Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, + const TargetTransformInfo &TTI, AssumptionCache &AC, + OptimizationRemarkEmitter &ORE, bool PreserveLCSSA, int OptLevel, + Optional<unsigned> ProvidedCount, Optional<unsigned> ProvidedThreshold, + Optional<bool> ProvidedAllowPartial, Optional<bool> ProvidedRuntime, + Optional<bool> ProvidedUpperBound, Optional<bool> ProvidedAllowPeeling) { DEBUG(dbgs() << "Loop Unroll: F[" << L->getHeader()->getParent()->getName() << "] Loop %" << L->getHeader()->getName() << "\n"); - if (HasUnrollDisablePragma(L)) - return false; - if (!L->isLoopSimplifyForm()) { + if (HasUnrollDisablePragma(L)) + return LoopUnrollResult::Unmodified; + if (!L->isLoopSimplifyForm()) { DEBUG( dbgs() << " Not unrolling loop which is not in loop-simplify form.\n"); - return false; + return LoopUnrollResult::Unmodified; } unsigned NumInlineCandidates; @@ -951,21 +970,22 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, bool Convergent; TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( L, SE, TTI, OptLevel, ProvidedThreshold, ProvidedCount, - ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound); + ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound, + ProvidedAllowPeeling); // Exit early if unrolling is disabled. if (UP.Threshold == 0 && (!UP.Partial || UP.PartialThreshold == 0)) - return false; + return LoopUnrollResult::Unmodified; unsigned LoopSize = ApproximateLoopSize( L, NumInlineCandidates, NotDuplicatable, Convergent, TTI, &AC, UP.BEInsns); DEBUG(dbgs() << " Loop Size = " << LoopSize << "\n"); if (NotDuplicatable) { DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable" << " instructions.\n"); - return false; + return LoopUnrollResult::Unmodified; } if (NumInlineCandidates != 0) { DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n"); - return false; + return LoopUnrollResult::Unmodified; } // Find trip count and trip multiple if count is not available @@ -1024,41 +1044,35 @@ static bool tryToUnrollLoop(Loop *L, DominatorTree &DT, LoopInfo *LI, computeUnrollCount(L, TTI, DT, LI, SE, &ORE, TripCount, MaxTripCount, TripMultiple, LoopSize, UP, UseUpperBound); if (!UP.Count) - return false; + return LoopUnrollResult::Unmodified; // Unroll factor (Count) must be less or equal to TripCount. 
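[Editor's note] The remark changes in this hunk all have the same shape: ORE->emit now takes a lambda, so the OptimizationRemarkMissed object and its string formatting are only built when remarks are actually enabled for the function. A reduced sketch of the idiom; the pass name, remark name and message below are invented:

#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
#include "llvm/IR/DiagnosticInfo.h"

using namespace llvm;

static void reportUnrollFailure(OptimizationRemarkEmitter &ORE, Loop *L,
                                unsigned TripMultiple) {
  // The lambda runs only if remarks are enabled, so the streaming and the
  // ore::NV() formatting below cost nothing in the common case.
  ORE.emit([&]() {
    return OptimizationRemarkMissed("example-pass", "ExampleRemark",
                                    L->getStartLoc(), L->getHeader())
           << "unroll count must divide the trip multiple of "
           << ore::NV("TripMultiple", TripMultiple);
  });
}
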
if (TripCount && UP.Count > TripCount) UP.Count = TripCount; // Unroll the loop. - if (!UnrollLoop(L, UP.Count, TripCount, UP.Force, UP.Runtime, - UP.AllowExpensiveTripCount, UseUpperBound, MaxOrZero, - TripMultiple, UP.PeelCount, LI, &SE, &DT, &AC, &ORE, - PreserveLCSSA)) - return false; + LoopUnrollResult UnrollResult = UnrollLoop( + L, UP.Count, TripCount, UP.Force, UP.Runtime, UP.AllowExpensiveTripCount, + UseUpperBound, MaxOrZero, TripMultiple, UP.PeelCount, UP.UnrollRemainder, + LI, &SE, &DT, &AC, &ORE, PreserveLCSSA); + if (UnrollResult == LoopUnrollResult::Unmodified) + return LoopUnrollResult::Unmodified; // If loop has an unroll count pragma or unrolled by explicitly set count // mark loop as unrolled to prevent unrolling beyond that requested. // If the loop was peeled, we already "used up" the profile information // we had, so we don't want to unroll or peel again. - if (IsCountSetExplicitly || UP.PeelCount) - SetLoopAlreadyUnrolled(L); + if (UnrollResult != LoopUnrollResult::FullyUnrolled && + (IsCountSetExplicitly || UP.PeelCount)) + L->setLoopAlreadyUnrolled(); - return true; + return UnrollResult; } namespace { + class LoopUnroll : public LoopPass { public: static char ID; // Pass ID, replacement for typeid - LoopUnroll(int OptLevel = 2, Optional<unsigned> Threshold = None, - Optional<unsigned> Count = None, - Optional<bool> AllowPartial = None, Optional<bool> Runtime = None, - Optional<bool> UpperBound = None) - : LoopPass(ID), OptLevel(OptLevel), ProvidedCount(std::move(Count)), - ProvidedThreshold(Threshold), ProvidedAllowPartial(AllowPartial), - ProvidedRuntime(Runtime), ProvidedUpperBound(UpperBound) { - initializeLoopUnrollPass(*PassRegistry::getPassRegistry()); - } int OptLevel; Optional<unsigned> ProvidedCount; @@ -1066,8 +1080,21 @@ public: Optional<bool> ProvidedAllowPartial; Optional<bool> ProvidedRuntime; Optional<bool> ProvidedUpperBound; + Optional<bool> ProvidedAllowPeeling; - bool runOnLoop(Loop *L, LPPassManager &) override { + LoopUnroll(int OptLevel = 2, Optional<unsigned> Threshold = None, + Optional<unsigned> Count = None, + Optional<bool> AllowPartial = None, Optional<bool> Runtime = None, + Optional<bool> UpperBound = None, + Optional<bool> AllowPeeling = None) + : LoopPass(ID), OptLevel(OptLevel), ProvidedCount(std::move(Count)), + ProvidedThreshold(Threshold), ProvidedAllowPartial(AllowPartial), + ProvidedRuntime(Runtime), ProvidedUpperBound(UpperBound), + ProvidedAllowPeeling(AllowPeeling) { + initializeLoopUnrollPass(*PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { if (skipLoop(L)) return false; @@ -1085,15 +1112,19 @@ public: OptimizationRemarkEmitter ORE(&F); bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); - return tryToUnrollLoop(L, DT, LI, SE, TTI, AC, ORE, PreserveLCSSA, OptLevel, - ProvidedCount, ProvidedThreshold, - ProvidedAllowPartial, ProvidedRuntime, - ProvidedUpperBound); + LoopUnrollResult Result = tryToUnrollLoop( + L, DT, LI, SE, TTI, AC, ORE, PreserveLCSSA, OptLevel, ProvidedCount, + ProvidedThreshold, ProvidedAllowPartial, ProvidedRuntime, + ProvidedUpperBound, ProvidedAllowPeeling); + + if (Result == LoopUnrollResult::FullyUnrolled) + LPM.markLoopAsDeleted(*L); + + return Result != LoopUnrollResult::Unmodified; } /// This transformation requires natural loop information & requires that /// loop preheaders be inserted into the CFG... 
- /// void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfoWrapperPass>(); @@ -1102,9 +1133,11 @@ public: getLoopAnalysisUsage(AU); } }; -} + +} // end anonymous namespace char LoopUnroll::ID = 0; + INITIALIZE_PASS_BEGIN(LoopUnroll, "loop-unroll", "Unroll loops", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_DEPENDENCY(LoopPass) @@ -1112,8 +1145,8 @@ INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(LoopUnroll, "loop-unroll", "Unroll loops", false, false) Pass *llvm::createLoopUnrollPass(int OptLevel, int Threshold, int Count, - int AllowPartial, int Runtime, - int UpperBound) { + int AllowPartial, int Runtime, int UpperBound, + int AllowPeeling) { // TODO: It would make more sense for this function to take the optionals // directly, but that's dangerous since it would silently break out of tree // callers. @@ -1122,16 +1155,17 @@ Pass *llvm::createLoopUnrollPass(int OptLevel, int Threshold, int Count, Count == -1 ? None : Optional<unsigned>(Count), AllowPartial == -1 ? None : Optional<bool>(AllowPartial), Runtime == -1 ? None : Optional<bool>(Runtime), - UpperBound == -1 ? None : Optional<bool>(UpperBound)); + UpperBound == -1 ? None : Optional<bool>(UpperBound), + AllowPeeling == -1 ? None : Optional<bool>(AllowPeeling)); } Pass *llvm::createSimpleLoopUnrollPass(int OptLevel) { - return llvm::createLoopUnrollPass(OptLevel, -1, -1, 0, 0, 0); + return createLoopUnrollPass(OptLevel, -1, -1, 0, 0, 0, 0); } -PreservedAnalyses LoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM, - LoopStandardAnalysisResults &AR, - LPMUpdater &Updater) { +PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &Updater) { const auto &FAM = AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); Function *F = L.getHeader()->getParent(); @@ -1139,8 +1173,9 @@ PreservedAnalyses LoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM, auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F); // FIXME: This should probably be optional rather than required. if (!ORE) - report_fatal_error("LoopUnrollPass: OptimizationRemarkEmitterAnalysis not " - "cached at a higher level"); + report_fatal_error( + "LoopFullUnrollPass: OptimizationRemarkEmitterAnalysis not " + "cached at a higher level"); // Keep track of the previous loop structure so we can identify new loops // created by unrolling. @@ -1151,17 +1186,14 @@ PreservedAnalyses LoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM, else OldLoops.insert(AR.LI.begin(), AR.LI.end()); - // The API here is quite complex to call, but there are only two interesting - // states we support: partial and full (or "simple") unrolling. However, to - // enable these things we actually pass "None" in for the optional to avoid - // providing an explicit choice. 
- Optional<bool> AllowPartialParam, RuntimeParam, UpperBoundParam; - if (!AllowPartialUnrolling) - AllowPartialParam = RuntimeParam = UpperBoundParam = false; - bool Changed = tryToUnrollLoop( - &L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, *ORE, - /*PreserveLCSSA*/ true, OptLevel, /*Count*/ None, - /*Threshold*/ None, AllowPartialParam, RuntimeParam, UpperBoundParam); + std::string LoopName = L.getName(); + + bool Changed = + tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, *ORE, + /*PreserveLCSSA*/ true, OptLevel, /*Count*/ None, + /*Threshold*/ None, /*AllowPartial*/ false, + /*Runtime*/ false, /*UpperBound*/ false, + /*AllowPeeling*/ false) != LoopUnrollResult::Unmodified; if (!Changed) return PreservedAnalyses::all(); @@ -1172,17 +1204,13 @@ PreservedAnalyses LoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM, #endif // Unrolling can do several things to introduce new loops into a loop nest: - // - Partial unrolling clones child loops within the current loop. If it - // uses a remainder, then it can also create any number of sibling loops. // - Full unrolling clones child loops within the current loop but then // removes the current loop making all of the children appear to be new // sibling loops. - // - Loop peeling can directly introduce new sibling loops by peeling one - // iteration. // - // When a new loop appears as a sibling loop, either from peeling an - // iteration or fully unrolling, its nesting structure has fundamentally - // changed and we want to revisit it to reflect that. + // When a new loop appears as a sibling loop after fully unrolling, + // its nesting structure has fundamentally changed and we want to revisit + // it to reflect that. // // When unrolling has removed the current loop, we need to tell the // infrastructure that it is gone. @@ -1209,13 +1237,11 @@ PreservedAnalyses LoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM, Updater.addSiblingLoops(SibLoops); if (!IsCurrentLoopValid) { - Updater.markLoopAsDeleted(L); + Updater.markLoopAsDeleted(L, LoopName); } else { // We can only walk child loops if the current loop remained valid. if (UnrollRevisitChildLoops) { - // Walk *all* of the child loops. This is a highly speculative mode - // anyways so look for any simplifications that arose from partial - // unrolling or peeling off of iterations. + // Walk *all* of the child loops. SmallVector<Loop *, 4> ChildLoops(L.begin(), L.end()); Updater.addChildLoops(ChildLoops); } @@ -1223,3 +1249,105 @@ PreservedAnalyses LoopUnrollPass::run(Loop &L, LoopAnalysisManager &AM, return getLoopPassPreservedAnalyses(); } + +template <typename RangeT> +static SmallVector<Loop *, 8> appendLoopsToWorklist(RangeT &&Loops) { + SmallVector<Loop *, 8> Worklist; + // We use an internal worklist to build up the preorder traversal without + // recursion. 
+ SmallVector<Loop *, 4> PreOrderLoops, PreOrderWorklist; + + for (Loop *RootL : Loops) { + assert(PreOrderLoops.empty() && "Must start with an empty preorder walk."); + assert(PreOrderWorklist.empty() && + "Must start with an empty preorder walk worklist."); + PreOrderWorklist.push_back(RootL); + do { + Loop *L = PreOrderWorklist.pop_back_val(); + PreOrderWorklist.append(L->begin(), L->end()); + PreOrderLoops.push_back(L); + } while (!PreOrderWorklist.empty()); + + Worklist.append(PreOrderLoops.begin(), PreOrderLoops.end()); + PreOrderLoops.clear(); + } + return Worklist; +} + +PreservedAnalyses LoopUnrollPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + + LoopAnalysisManager *LAM = nullptr; + if (auto *LAMProxy = AM.getCachedResult<LoopAnalysisManagerFunctionProxy>(F)) + LAM = &LAMProxy->getManager(); + + const ModuleAnalysisManager &MAM = + AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager(); + ProfileSummaryInfo *PSI = + MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent()); + + bool Changed = false; + + // The unroller requires loops to be in simplified form, and also needs LCSSA. + // Since simplification may add new inner loops, it has to run before the + // legality and profitability checks. This means running the loop unroller + // will simplify all loops, regardless of whether anything end up being + // unrolled. + for (auto &L : LI) { + Changed |= simplifyLoop(L, &DT, &LI, &SE, &AC, false /* PreserveLCSSA */); + Changed |= formLCSSARecursively(*L, DT, &LI, &SE); + } + + SmallVector<Loop *, 8> Worklist = appendLoopsToWorklist(LI); + + while (!Worklist.empty()) { + // Because the LoopInfo stores the loops in RPO, we walk the worklist + // from back to front so that we work forward across the CFG, which + // for unrolling is only needed to get optimization remarks emitted in + // a forward order. + Loop &L = *Worklist.pop_back_val(); +#ifndef NDEBUG + Loop *ParentL = L.getParentLoop(); +#endif + + // The API here is quite complex to call, but there are only two interesting + // states we support: partial and full (or "simple") unrolling. However, to + // enable these things we actually pass "None" in for the optional to avoid + // providing an explicit choice. + Optional<bool> AllowPartialParam, RuntimeParam, UpperBoundParam, + AllowPeeling; + // Check if the profile summary indicates that the profiled application + // has a huge working set size, in which case we disable peeling to avoid + // bloating it further. + if (PSI && PSI->hasHugeWorkingSetSize()) + AllowPeeling = false; + std::string LoopName = L.getName(); + LoopUnrollResult Result = + tryToUnrollLoop(&L, DT, &LI, SE, TTI, AC, ORE, + /*PreserveLCSSA*/ true, OptLevel, /*Count*/ None, + /*Threshold*/ None, AllowPartialParam, RuntimeParam, + UpperBoundParam, AllowPeeling); + Changed |= Result != LoopUnrollResult::Unmodified; + + // The parent must not be damaged by unrolling! +#ifndef NDEBUG + if (Result != LoopUnrollResult::Unmodified && ParentL) + ParentL->verifyLoop(); +#endif + + // Clear any cached analysis results for L if we removed it completely. 
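[Editor's note] appendLoopsToWorklist above builds a preorder walk of each loop nest without recursion, by popping a node from an explicit worklist and pushing its children. The same shape on a plain tree, fully self-contained:

#include <iostream>
#include <vector>

struct Node {
  int Id;
  std::vector<Node *> Children;
};

// Iterative preorder via an explicit stack: pop a node, record it, push its
// children. Siblings come back out in reverse push order because the
// worklist is LIFO, but each node is still visited before its descendants.
std::vector<Node *> preorder(Node *Root) {
  std::vector<Node *> Out;
  std::vector<Node *> Worklist = {Root};
  while (!Worklist.empty()) {
    Node *N = Worklist.back();
    Worklist.pop_back();
    Out.push_back(N);
    Worklist.insert(Worklist.end(), N->Children.begin(), N->Children.end());
  }
  return Out;
}

int main() {
  Node C{3, {}}, B{2, {}}, A{1, {&B, &C}};
  for (Node *N : preorder(&A))
    std::cout << N->Id << " ";   // prints: 1 3 2
}
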
+ if (LAM && Result == LoopUnrollResult::FullyUnrolled) + LAM->clear(L, LoopName); + } + + if (!Changed) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} diff --git a/lib/Transforms/Scalar/LoopUnswitch.cpp b/lib/Transforms/Scalar/LoopUnswitch.cpp index d0c96fa627a4..bd468338a1d0 100644 --- a/lib/Transforms/Scalar/LoopUnswitch.cpp +++ b/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -1,4 +1,4 @@ -//===-- LoopUnswitch.cpp - Hoist loop-invariant conditionals in loop ------===// +//===- LoopUnswitch.cpp - Hoist loop-invariant conditionals in loop -------===// // // The LLVM Compiler Infrastructure // @@ -26,30 +26,40 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/BlockFrequencyInfo.h" -#include "llvm/Analysis/BlockFrequencyInfoImpl.h" -#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/DivergenceAnalysis.h" -#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" -#include "llvm/Support/BranchProbability.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -58,9 +68,15 @@ #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> +#include <cassert> #include <map> #include <set> +#include <tuple> +#include <utility> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "loop-unswitch" @@ -82,11 +98,9 @@ Threshold("loop-unswitch-threshold", cl::desc("Max loop size to unswitch"), namespace { class LUAnalysisCache { - - typedef DenseMap<const SwitchInst*, SmallPtrSet<const Value *, 8> > - UnswitchedValsMap; - - typedef UnswitchedValsMap::iterator UnswitchedValsIt; + using UnswitchedValsMap = + DenseMap<const SwitchInst *, SmallPtrSet<const Value *, 8>>; + using UnswitchedValsIt = UnswitchedValsMap::iterator; struct LoopProperties { unsigned CanBeUnswitchedCount; @@ -97,12 +111,12 @@ namespace { // Here we use std::map instead of DenseMap, since we need to keep valid // LoopProperties pointer for current loop for better performance. 
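[Editor's note] The retained comment above explains why LUAnalysisCache keeps a std::map rather than a DenseMap: it holds a raw pointer to the current loop's properties across later insertions. A self-contained illustration of that stability guarantee (a DenseMap may rehash and move its entries, so the same trick would not be safe there):

#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<int, std::string> Props;
  std::string *Current = &Props[1];   // keep a pointer into the map
  for (int I = 2; I < 1000; ++I)
    Props[I] = "filler";              // many later insertions...
  *Current = "still valid";           // ...never invalidate node-based std::map entries
  std::cout << Props[1] << "\n";      // prints: still valid
}
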
- typedef std::map<const Loop*, LoopProperties> LoopPropsMap; - typedef LoopPropsMap::iterator LoopPropsMapIt; + using LoopPropsMap = std::map<const Loop *, LoopProperties>; + using LoopPropsMapIt = LoopPropsMap::iterator; LoopPropsMap LoopsProperties; - UnswitchedValsMap *CurLoopInstructions; - LoopProperties *CurrentLoopProperties; + UnswitchedValsMap *CurLoopInstructions = nullptr; + LoopProperties *CurrentLoopProperties = nullptr; // A loop unswitching with an estimated cost above this threshold // is not performed. MaxSize is turned into unswitching quota for @@ -121,9 +135,7 @@ namespace { unsigned MaxSize; public: - LUAnalysisCache() - : CurLoopInstructions(nullptr), CurrentLoopProperties(nullptr), - MaxSize(Threshold) {} + LUAnalysisCache() : MaxSize(Threshold) {} // Analyze loop. Check its size, calculate is it possible to unswitch // it. Returns true if we can unswitch this loop. @@ -164,12 +176,12 @@ namespace { LUAnalysisCache BranchesInfo; bool OptimizeForSize; - bool redoLoop; + bool redoLoop = false; - Loop *currentLoop; - DominatorTree *DT; - BasicBlock *loopHeader; - BasicBlock *loopPreheader; + Loop *currentLoop = nullptr; + DominatorTree *DT = nullptr; + BasicBlock *loopHeader = nullptr; + BasicBlock *loopPreheader = nullptr; bool SanitizeMemory; LoopSafetyInfo SafetyInfo; @@ -185,16 +197,17 @@ namespace { public: static char ID; // Pass ID, replacement for typeid - explicit LoopUnswitch(bool Os = false, bool hasBranchDivergence = false) : - LoopPass(ID), OptimizeForSize(Os), redoLoop(false), - currentLoop(nullptr), DT(nullptr), loopHeader(nullptr), - loopPreheader(nullptr), hasBranchDivergence(hasBranchDivergence) { + + explicit LoopUnswitch(bool Os = false, bool hasBranchDivergence = false) + : LoopPass(ID), OptimizeForSize(Os), + hasBranchDivergence(hasBranchDivergence) { initializeLoopUnswitchPass(*PassRegistry::getPassRegistry()); - } + } bool runOnLoop(Loop *L, LPPassManager &LPM) override; bool processCurrentLoop(); bool isUnreachableDueToPreviousUnswitching(BasicBlock *); + /// This transformation requires natural loop information & requires that /// loop preheaders be inserted into the CFG. /// @@ -207,7 +220,6 @@ namespace { } private: - void releaseMemory() override { BranchesInfo.forgetLoop(currentLoop); } @@ -237,7 +249,7 @@ namespace { void EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, BasicBlock *TrueDest, BasicBlock *FalseDest, - Instruction *InsertPt, + BranchInst *OldBranch, TerminatorInst *TI); void SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L); @@ -247,13 +259,13 @@ namespace { Value *SimplifyInstructionWithNotEqual(Instruction *Inst, Value *Invariant, Constant *Val); }; -} + +} // end anonymous namespace // Analyze loop. Check its size, calculate is it possible to unswitch // it. Returns true if we can unswitch this loop. bool LUAnalysisCache::countLoop(const Loop *L, const TargetTransformInfo &TTI, AssumptionCache *AC) { - LoopPropsMapIt PropsIt; bool Inserted; std::tie(PropsIt, Inserted) = @@ -302,7 +314,6 @@ bool LUAnalysisCache::countLoop(const Loop *L, const TargetTransformInfo &TTI, // Clean all data related to given loop. void LUAnalysisCache::forgetLoop(const Loop *L) { - LoopPropsMapIt LIt = LoopsProperties.find(L); if (LIt != LoopsProperties.end()) { @@ -337,7 +348,6 @@ bool LUAnalysisCache::CostAllowsUnswitching() { // Note, that new loop data is stored inside the VMap. 
void LUAnalysisCache::cloneData(const Loop *NewLoop, const Loop *OldLoop, const ValueToValueMapTy &VMap) { - LoopProperties &NewLoopProps = LoopsProperties[NewLoop]; LoopProperties &OldLoopProps = *CurrentLoopProperties; UnswitchedValsMap &Insts = OldLoopProps.UnswitchedVals; @@ -367,6 +377,7 @@ void LUAnalysisCache::cloneData(const Loop *NewLoop, const Loop *OldLoop, } char LoopUnswitch::ID = 0; + INITIALIZE_PASS_BEGIN(LoopUnswitch, "loop-unswitch", "Unswitch loops", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) @@ -518,9 +529,6 @@ bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) { Changed |= processCurrentLoop(); } while(redoLoop); - // FIXME: Reconstruct dom info, because it is not preserved properly. - if (Changed) - DT->recalculate(*F); return Changed; } @@ -553,6 +561,48 @@ bool LoopUnswitch::isUnreachableDueToPreviousUnswitching(BasicBlock *BB) { return false; } +/// FIXME: Remove this workaround when freeze related patches are done. +/// LoopUnswitch and Equality propagation in GVN have discrepancy about +/// whether branch on undef/poison has undefine behavior. Here it is to +/// rule out some common cases that we found such discrepancy already +/// causing problems. Detail could be found in PR31652. Note if the +/// func returns true, it is unsafe. But if it is false, it doesn't mean +/// it is necessarily safe. +static bool EqualityPropUnSafe(Value &LoopCond) { + ICmpInst *CI = dyn_cast<ICmpInst>(&LoopCond); + if (!CI || !CI->isEquality()) + return false; + + Value *LHS = CI->getOperand(0); + Value *RHS = CI->getOperand(1); + if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS)) + return true; + + auto hasUndefInPHI = [](PHINode &PN) { + for (Value *Opd : PN.incoming_values()) { + if (isa<UndefValue>(Opd)) + return true; + } + return false; + }; + PHINode *LPHI = dyn_cast<PHINode>(LHS); + PHINode *RPHI = dyn_cast<PHINode>(RHS); + if ((LPHI && hasUndefInPHI(*LPHI)) || (RPHI && hasUndefInPHI(*RPHI))) + return true; + + auto hasUndefInSelect = [](SelectInst &SI) { + if (isa<UndefValue>(SI.getTrueValue()) || + isa<UndefValue>(SI.getFalseValue())) + return true; + return false; + }; + SelectInst *LSI = dyn_cast<SelectInst>(LHS); + SelectInst *RSI = dyn_cast<SelectInst>(RHS); + if ((LSI && hasUndefInSelect(*LSI)) || (RSI && hasUndefInSelect(*RSI))) + return true; + return false; +} + /// Do actual work and unswitch loop if possible and profitable. bool LoopUnswitch::processCurrentLoop() { bool Changed = false; @@ -666,7 +716,7 @@ bool LoopUnswitch::processCurrentLoop() { // unswitch on it if we desire. Value *LoopCond = FindLIVLoopCondition(BI->getCondition(), currentLoop, Changed).first; - if (LoopCond && + if (LoopCond && !EqualityPropUnSafe(*LoopCond) && UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) { ++NumBranches; return true; @@ -831,7 +881,7 @@ bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val, /// mapping the blocks with the specified map. static Loop *CloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM, LoopInfo *LI, LPPassManager *LPM) { - Loop &New = *new Loop(); + Loop &New = *LI->AllocateLoop(); if (PL) PL->addChildLoop(&New); else @@ -852,31 +902,59 @@ static Loop *CloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM, } /// Emit a conditional branch on two values if LIC == Val, branch to TrueDst, -/// otherwise branch to FalseDest. Insert the code immediately before InsertPt. +/// otherwise branch to FalseDest. Insert the code immediately before OldBranch +/// and remove (but not erase!) 
it from the function. void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, BasicBlock *TrueDest, BasicBlock *FalseDest, - Instruction *InsertPt, + BranchInst *OldBranch, TerminatorInst *TI) { + assert(OldBranch->isUnconditional() && "Preheader is not split correctly"); // Insert a conditional branch on LIC to the two preheaders. The original // code is the true version and the new code is the false version. Value *BranchVal = LIC; bool Swapped = false; if (!isa<ConstantInt>(Val) || Val->getType() != Type::getInt1Ty(LIC->getContext())) - BranchVal = new ICmpInst(InsertPt, ICmpInst::ICMP_EQ, LIC, Val); + BranchVal = new ICmpInst(OldBranch, ICmpInst::ICMP_EQ, LIC, Val); else if (Val != ConstantInt::getTrue(Val->getContext())) { // We want to enter the new loop when the condition is true. std::swap(TrueDest, FalseDest); Swapped = true; } + // Old branch will be removed, so save its parent and successor to update the + // DomTree. + auto *OldBranchSucc = OldBranch->getSuccessor(0); + auto *OldBranchParent = OldBranch->getParent(); + // Insert the new branch. BranchInst *BI = - IRBuilder<>(InsertPt).CreateCondBr(BranchVal, TrueDest, FalseDest, TI); + IRBuilder<>(OldBranch).CreateCondBr(BranchVal, TrueDest, FalseDest, TI); if (Swapped) BI->swapProfMetadata(); + // Remove the old branch so there is only one branch at the end. This is + // needed to perform DomTree's internal DFS walk on the function's CFG. + OldBranch->removeFromParent(); + + // Inform the DT about the new branch. + if (DT) { + // First, add both successors. + SmallVector<DominatorTree::UpdateType, 3> Updates; + if (TrueDest != OldBranchParent) + Updates.push_back({DominatorTree::Insert, OldBranchParent, TrueDest}); + if (FalseDest != OldBranchParent) + Updates.push_back({DominatorTree::Insert, OldBranchParent, FalseDest}); + // If both of the new successors are different from the old one, inform the + // DT that the edge was deleted. + if (OldBranchSucc != TrueDest && OldBranchSucc != FalseDest) { + Updates.push_back({DominatorTree::Delete, OldBranchParent, OldBranchSucc}); + } + + DT->applyUpdates(Updates); + } + // If either edge is critical, split it. This helps preserve LoopSimplify // form for enclosing loops. auto Options = CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA(); @@ -916,10 +994,14 @@ void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, // Okay, now we have a position to branch from and a position to branch to, // insert the new conditional branch. - EmitPreheaderBranchOnCondition(Cond, Val, NewExit, NewPH, - loopPreheader->getTerminator(), TI); - LPM->deleteSimpleAnalysisValue(loopPreheader->getTerminator(), L); - loopPreheader->getTerminator()->eraseFromParent(); + auto *OldBranch = dyn_cast<BranchInst>(loopPreheader->getTerminator()); + assert(OldBranch && "Failed to split the preheader"); + EmitPreheaderBranchOnCondition(Cond, Val, NewExit, NewPH, OldBranch, TI); + LPM->deleteSimpleAnalysisValue(OldBranch, L); + + // EmitPreheaderBranchOnCondition removed the OldBranch from the function. + // Delete it, as it is no longer needed. + delete OldBranch; // We need to reprocess this loop, it could be unswitched again. redoLoop = true; @@ -1035,6 +1117,9 @@ bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { if (!LoopExitBB || isa<PHINode>(LoopExitBB->begin())) return false; // Can't handle this. 
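The substantive part of the change above is that EmitPreheaderBranchOnCondition now keeps the DominatorTree current through DominatorTree::applyUpdates() instead of relying on the whole-tree DT->recalculate(*F) that was dropped from runOnLoop. A hedged sketch of that incremental-update pattern in isolation; the helper name and surrounding assumptions are illustrative, not part of this patch.

    #include "llvm/ADT/SmallVector.h"
    #include "llvm/IR/BasicBlock.h"
    #include "llvm/IR/Dominators.h"
    #include "llvm/IR/Instructions.h"
    #include <cassert>

    using namespace llvm;

    // Retarget BB's unconditional branch to NewSucc and report the edge
    // changes to the DominatorTree, rather than recalculating it from scratch.
    static void retargetTerminator(BasicBlock *BB, BasicBlock *NewSucc,
                                   DominatorTree &DT) {
      auto *Br = cast<BranchInst>(BB->getTerminator());
      assert(Br->isUnconditional() && "sketch assumes a single successor");
      BasicBlock *OldSucc = Br->getSuccessor(0);
      if (OldSucc == NewSucc)
        return;
      Br->setSuccessor(0, NewSucc);

      // Batch the CFG edge deltas; applyUpdates() is the incremental analogue
      // of DT.recalculate(F).
      SmallVector<DominatorTree::UpdateType, 2> Updates;
      Updates.push_back({DominatorTree::Insert, BB, NewSucc});
      Updates.push_back({DominatorTree::Delete, BB, OldSucc});
      DT.applyUpdates(Updates);
    }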
+ if (EqualityPropUnSafe(*LoopCond)) + return false; + UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB, CurrentTerm); ++NumBranches; @@ -1231,7 +1316,10 @@ void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val, EmitPreheaderBranchOnCondition(LIC, Val, NewBlocks[0], LoopBlocks[0], OldBR, TI); LPM->deleteSimpleAnalysisValue(OldBR, L); - OldBR->eraseFromParent(); + + // The OldBr was replaced by a new one and removed (but not erased) by + // EmitPreheaderBranchOnCondition. It is no longer needed, so delete it. + delete OldBR; LoopProcessWorklist.push_back(NewLoop); redoLoop = true; diff --git a/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/lib/Transforms/Scalar/LoopVersioningLICM.cpp index c23d891b6504..53b25e688e82 100644 --- a/lib/Transforms/Scalar/LoopVersioningLICM.cpp +++ b/lib/Transforms/Scalar/LoopVersioningLICM.cpp @@ -1,4 +1,4 @@ -//===----------- LoopVersioningLICM.cpp - LICM Loop Versioning ------------===// +//===- LoopVersioningLICM.cpp - LICM Loop Versioning ----------------------===// // // The LLVM Compiler Infrastructure // @@ -60,41 +60,41 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/MapVector.h" -#include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AliasSetTracker.h" -#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopAccessAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpander.h" -#include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/Analysis/ValueTracking.h" -#include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" -#include "llvm/IR/PatternMatch.h" -#include "llvm/IR/PredIteratorCache.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/LoopVersioning.h" -#include "llvm/Transforms/Utils/ValueMapper.h" +#include <cassert> +#include <memory> + +using namespace llvm; #define DEBUG_TYPE "loop-versioning-licm" -static const char *LICMVersioningMetaData = "llvm.loop.licm_versioning.disable"; -using namespace llvm; +static const char *LICMVersioningMetaData = "llvm.loop.licm_versioning.disable"; /// Threshold minimum allowed percentage for possible /// invariant instructions in a loop. 
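The LoopDepthThreshold and InvariantThreshold members that the LoopVersioningLICM constructor in the next hunk seeds come from command-line options (LVLoopDepthThreshold, LVInvarThreshold) whose declarations fall outside the hunks shown here. As a rough sketch of that shape only; the flag names and defaults below are placeholders, not the real declarations.

    #include "llvm/Support/CommandLine.h"

    using namespace llvm;

    // Placeholder flags standing in for LVLoopDepthThreshold / LVInvarThreshold.
    static cl::opt<unsigned>
        ExampleMaxDepth("example-versioning-max-depth", cl::init(2), cl::Hidden,
                        cl::desc("Maximum loop nest depth to consider"));
    static cl::opt<float> ExampleInvariantPct(
        "example-versioning-invariant-threshold", cl::init(25), cl::Hidden,
        cl::desc("Minimum percentage of invariant instructions required"));

    namespace {
    struct ExamplePass {
      unsigned LoopDepthThreshold;
      float InvariantThreshold;
      // Seed the thresholds once in the constructor; the rest of the pass
      // only reads the members.
      ExamplePass()
          : LoopDepthThreshold(ExampleMaxDepth),
            InvariantThreshold(ExampleInvariantPct) {}
    };
    } // end anonymous namespace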
@@ -143,9 +143,16 @@ void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *MDString, } namespace { + struct LoopVersioningLICM : public LoopPass { static char ID; + LoopVersioningLICM() + : LoopPass(ID), LoopDepthThreshold(LVLoopDepthThreshold), + InvariantThreshold(LVInvarThreshold) { + initializeLoopVersioningLICMPass(*PassRegistry::getPassRegistry()); + } + bool runOnLoop(Loop *L, LPPassManager &LPM) override; void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -161,13 +168,6 @@ struct LoopVersioningLICM : public LoopPass { AU.addPreserved<GlobalsAAWrapperPass>(); } - LoopVersioningLICM() - : LoopPass(ID), AA(nullptr), SE(nullptr), LAA(nullptr), LAI(nullptr), - CurLoop(nullptr), LoopDepthThreshold(LVLoopDepthThreshold), - InvariantThreshold(LVInvarThreshold), LoadAndStoreCounter(0), - InvariantCounter(0), IsReadOnlyLoop(true) { - initializeLoopVersioningLICMPass(*PassRegistry::getPassRegistry()); - } StringRef getPassName() const override { return "Loop Versioning for LICM"; } void reset() { @@ -191,30 +191,49 @@ struct LoopVersioningLICM : public LoopPass { }; private: - AliasAnalysis *AA; // Current AliasAnalysis information - ScalarEvolution *SE; // Current ScalarEvolution - LoopAccessLegacyAnalysis *LAA; // Current LoopAccessAnalysis - const LoopAccessInfo *LAI; // Current Loop's LoopAccessInfo + // Current AliasAnalysis information + AliasAnalysis *AA = nullptr; + + // Current ScalarEvolution + ScalarEvolution *SE = nullptr; + + // Current LoopAccessAnalysis + LoopAccessLegacyAnalysis *LAA = nullptr; + + // Current Loop's LoopAccessInfo + const LoopAccessInfo *LAI = nullptr; + + // The current loop we are working on. + Loop *CurLoop = nullptr; + + // AliasSet information for the current loop. + std::unique_ptr<AliasSetTracker> CurAST; - Loop *CurLoop; // The current loop we are working on. - std::unique_ptr<AliasSetTracker> - CurAST; // AliasSet information for the current loop. + // Maximum loop nest threshold + unsigned LoopDepthThreshold; - unsigned LoopDepthThreshold; // Maximum loop nest threshold - float InvariantThreshold; // Minimum invariant threshold - unsigned LoadAndStoreCounter; // Counter to track num of load & store - unsigned InvariantCounter; // Counter to track num of invariant - bool IsReadOnlyLoop; // Read only loop marker. + // Minimum invariant threshold + float InvariantThreshold; + + // Counter to track num of load & store + unsigned LoadAndStoreCounter = 0; + + // Counter to track num of invariant + unsigned InvariantCounter = 0; + + // Read only loop marker. + bool IsReadOnlyLoop = true; bool isLegalForVersioning(); bool legalLoopStructure(); bool legalLoopInstructions(); bool legalLoopMemoryAccesses(); bool isLoopAlreadyVisited(); - void setNoAliasToLoop(Loop *); - bool instructionSafeForVersioning(Instruction *); + void setNoAliasToLoop(Loop *VerLoop); + bool instructionSafeForVersioning(Instruction *I); }; -} + +} // end anonymous namespace /// \brief Check loop structure and confirms it's good for LoopVersioningLICM. bool LoopVersioningLICM::legalLoopStructure() { @@ -225,7 +244,7 @@ bool LoopVersioningLICM::legalLoopStructure() { return false; } // Loop should be innermost loop, if not return false. 
- if (CurLoop->getSubLoops().size()) { + if (!CurLoop->getSubLoops().empty()) { DEBUG(dbgs() << " loop is not innermost\n"); return false; } @@ -562,6 +581,7 @@ bool LoopVersioningLICM::runOnLoop(Loop *L, LPPassManager &LPM) { } char LoopVersioningLICM::ID = 0; + INITIALIZE_PASS_BEGIN(LoopVersioningLICM, "loop-versioning-licm", "Loop Versioning For LICM", false, false) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) diff --git a/lib/Transforms/Scalar/LowerAtomic.cpp b/lib/Transforms/Scalar/LowerAtomic.cpp index 6f77c5bd0d07..c165c5ece95c 100644 --- a/lib/Transforms/Scalar/LowerAtomic.cpp +++ b/lib/Transforms/Scalar/LowerAtomic.cpp @@ -15,7 +15,6 @@ #include "llvm/Transforms/Scalar/LowerAtomic.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" using namespace llvm; diff --git a/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/lib/Transforms/Scalar/MemCpyOptimizer.cpp index 7896396f0898..9c870b42a747 100644 --- a/lib/Transforms/Scalar/MemCpyOptimizer.cpp +++ b/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -14,10 +14,12 @@ #include "llvm/Transforms/Scalar/MemCpyOptimizer.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" @@ -25,6 +27,8 @@ #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" @@ -41,6 +45,7 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/Type.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" @@ -54,6 +59,7 @@ #include <algorithm> #include <cassert> #include <cstdint> +#include <utility> using namespace llvm; @@ -225,15 +231,18 @@ bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const { namespace { class MemsetRanges { + using range_iterator = SmallVectorImpl<MemsetRange>::iterator; + /// A sorted list of the memset ranges. SmallVector<MemsetRange, 8> Ranges; - typedef SmallVectorImpl<MemsetRange>::iterator range_iterator; + const DataLayout &DL; public: MemsetRanges(const DataLayout &DL) : DL(DL) {} - typedef SmallVectorImpl<MemsetRange>::const_iterator const_iterator; + using const_iterator = SmallVectorImpl<MemsetRange>::const_iterator; + const_iterator begin() const { return Ranges.begin(); } const_iterator end() const { return Ranges.end(); } bool empty() const { return Ranges.empty(); } @@ -259,7 +268,6 @@ public: void addRange(int64_t Start, int64_t Size, Value *Ptr, unsigned Alignment, Instruction *Inst); - }; } // end anonymous namespace @@ -356,10 +364,10 @@ private: } }; -char MemCpyOptLegacyPass::ID = 0; - } // end anonymous namespace +char MemCpyOptLegacyPass::ID = 0; + /// The public interface to this file... FunctionPass *llvm::createMemCpyOptPass() { return new MemCpyOptLegacyPass(); } @@ -450,7 +458,6 @@ Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, // emit memset's for anything big enough to be worthwhile. 
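tryMergingIntoMemset, which continues below, is the consumer of the MemsetRanges bookkeeping above: runs of adjacent constant stores get folded into a single memset once the range is large enough to be worthwhile. A self-contained illustration of the kind of source pattern this targets; the struct and names are made up for the example.

    #include <cstdint>
    #include <cstdio>

    // Four adjacent 8-byte zero stores; an optimizer tracking them as one
    // contiguous range (offsets [0, 32)) can emit a single 32-byte memset
    // instead of four scalar stores.
    struct Packet {
      std::uint64_t A, B, C, D;
    };

    void clearPacket(Packet &P) {
      P.A = 0;
      P.B = 0;
      P.C = 0;
      P.D = 0;
    }

    int main() {
      Packet P{1, 2, 3, 4};
      clearPacket(P);
      std::printf("%llu\n", static_cast<unsigned long long>(P.A));
    }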
Instruction *AMemSet = nullptr; for (const MemsetRange &Range : Ranges) { - if (Range.TheStores.size() == 1) continue; // If it is profitable to lower this range to memset, do so now. @@ -511,7 +518,7 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P, const LoadInst *LI) { // If the store alias this position, early bail out. MemoryLocation StoreLoc = MemoryLocation::get(SI); - if (AA.getModRefInfo(P, StoreLoc) != MRI_NoModRef) + if (isModOrRefSet(AA.getModRefInfo(P, StoreLoc))) return false; // Keep track of the arguments of all instruction we plan to lift @@ -535,20 +542,20 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P, for (auto I = --SI->getIterator(), E = P->getIterator(); I != E; --I) { auto *C = &*I; - bool MayAlias = AA.getModRefInfo(C) != MRI_NoModRef; + bool MayAlias = isModOrRefSet(AA.getModRefInfo(C, None)); bool NeedLift = false; if (Args.erase(C)) NeedLift = true; else if (MayAlias) { NeedLift = llvm::any_of(MemLocs, [C, &AA](const MemoryLocation &ML) { - return AA.getModRefInfo(C, ML); + return isModOrRefSet(AA.getModRefInfo(C, ML)); }); if (!NeedLift) NeedLift = llvm::any_of(CallSites, [C, &AA](const ImmutableCallSite &CS) { - return AA.getModRefInfo(C, CS); + return isModOrRefSet(AA.getModRefInfo(C, CS)); }); } @@ -558,18 +565,18 @@ static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P, if (MayAlias) { // Since LI is implicitly moved downwards past the lifted instructions, // none of them may modify its source. - if (AA.getModRefInfo(C, LoadLoc) & MRI_Mod) + if (isModSet(AA.getModRefInfo(C, LoadLoc))) return false; else if (auto CS = ImmutableCallSite(C)) { // If we can't lift this before P, it's game over. - if (AA.getModRefInfo(P, CS) != MRI_NoModRef) + if (isModOrRefSet(AA.getModRefInfo(P, CS))) return false; CallSites.push_back(CS); } else if (isa<LoadInst>(C) || isa<StoreInst>(C) || isa<VAArgInst>(C)) { // If we can't lift this before P, it's game over. auto ML = MemoryLocation::get(C); - if (AA.getModRefInfo(P, ML) != MRI_NoModRef) + if (isModOrRefSet(AA.getModRefInfo(P, ML))) return false; MemLocs.push_back(ML); @@ -624,7 +631,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { // of at the store position. Instruction *P = SI; for (auto &I : make_range(++LI->getIterator(), SI->getIterator())) { - if (AA.getModRefInfo(&I, LoadLoc) & MRI_Mod) { + if (isModSet(AA.getModRefInfo(&I, LoadLoc))) { P = &I; break; } @@ -695,7 +702,7 @@ bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) { MemoryLocation StoreLoc = MemoryLocation::get(SI); for (BasicBlock::iterator I = --SI->getIterator(), E = C->getIterator(); I != E; --I) { - if (AA.getModRefInfo(&*I, StoreLoc) != MRI_NoModRef) { + if (isModOrRefSet(AA.getModRefInfo(&*I, StoreLoc))) { C = nullptr; break; } @@ -927,9 +934,9 @@ bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, AliasAnalysis &AA = LookupAliasAnalysis(); ModRefInfo MR = AA.getModRefInfo(C, cpyDest, srcSize); // If necessary, perform additional analysis. 
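The repeated edits in this hunk are an API migration, not a behavioral change: raw comparisons against the ModRefInfo enum ("!= MRI_NoModRef", "& MRI_Mod") become the isModOrRefSet()/isModSet() predicates. A hedged sketch of the new style, with illustrative helper names that are not part of the patch.

    #include "llvm/Analysis/AliasAnalysis.h"
    #include "llvm/Analysis/MemoryLocation.h"
    #include "llvm/IR/Instruction.h"

    using namespace llvm;

    // May I write to Loc? Previously spelled (MR & MRI_Mod) != 0.
    static bool mayWriteTo(AliasAnalysis &AA, Instruction *I,
                           const MemoryLocation &Loc) {
      return isModSet(AA.getModRefInfo(I, Loc));
    }

    // Does I touch Loc at all? Previously spelled MR != MRI_NoModRef.
    static bool mayTouch(AliasAnalysis &AA, Instruction *I,
                         const MemoryLocation &Loc) {
      return isModOrRefSet(AA.getModRefInfo(I, Loc));
    }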
- if (MR != MRI_NoModRef) + if (isModOrRefSet(MR)) MR = AA.callCapturesBefore(C, cpyDest, srcSize, &DT); - if (MR != MRI_NoModRef) + if (isModOrRefSet(MR)) return false; // We can't create address space casts here because we don't know if they're diff --git a/lib/Transforms/Scalar/MergeICmps.cpp b/lib/Transforms/Scalar/MergeICmps.cpp new file mode 100644 index 000000000000..9869a3fb96fa --- /dev/null +++ b/lib/Transforms/Scalar/MergeICmps.cpp @@ -0,0 +1,650 @@ +//===- MergeICmps.cpp - Optimize chains of integer comparisons ------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass turns chains of integer comparisons into memcmp (the memcmp is +// later typically inlined as a chain of efficient hardware comparisons). This +// typically benefits c++ member or nonmember operator==(). +// +// The basic idea is to replace a larger chain of integer comparisons loaded +// from contiguous memory locations into a smaller chain of such integer +// comparisons. Benefits are double: +// - There are less jumps, and therefore less opportunities for mispredictions +// and I-cache misses. +// - Code size is smaller, both because jumps are removed and because the +// encoding of a 2*n byte compare is smaller than that of two n-byte +// compares. + +//===----------------------------------------------------------------------===// + +#include <algorithm> +#include <numeric> +#include <utility> +#include <vector> +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BuildLibCalls.h" + +using namespace llvm; + +namespace { + +#define DEBUG_TYPE "mergeicmps" + +// A BCE atom. +struct BCEAtom { + BCEAtom() : GEP(nullptr), LoadI(nullptr), Offset() {} + + const Value *Base() const { return GEP ? GEP->getPointerOperand() : nullptr; } + + bool operator<(const BCEAtom &O) const { + assert(Base() && "invalid atom"); + assert(O.Base() && "invalid atom"); + // Just ordering by (Base(), Offset) is sufficient. However because this + // means that the ordering will depend on the addresses of the base + // values, which are not reproducible from run to run. To guarantee + // stability, we use the names of the values if they exist; we sort by: + // (Base.getName(), Base(), Offset). + const int NameCmp = Base()->getName().compare(O.Base()->getName()); + if (NameCmp == 0) { + if (Base() == O.Base()) { + return Offset.slt(O.Offset); + } + return Base() < O.Base(); + } + return NameCmp < 0; + } + + GetElementPtrInst *GEP; + LoadInst *LoadI; + APInt Offset; +}; + +// If this value is a load from a constant offset w.r.t. a base address, and +// there are no othe rusers of the load or address, returns the base address and +// the offset. 
+BCEAtom visitICmpLoadOperand(Value *const Val) { + BCEAtom Result; + if (auto *const LoadI = dyn_cast<LoadInst>(Val)) { + DEBUG(dbgs() << "load\n"); + if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) { + DEBUG(dbgs() << "used outside of block\n"); + return {}; + } + if (LoadI->isVolatile()) { + DEBUG(dbgs() << "volatile\n"); + return {}; + } + Value *const Addr = LoadI->getOperand(0); + if (auto *const GEP = dyn_cast<GetElementPtrInst>(Addr)) { + DEBUG(dbgs() << "GEP\n"); + if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) { + DEBUG(dbgs() << "used outside of block\n"); + return {}; + } + const auto &DL = GEP->getModule()->getDataLayout(); + if (!isDereferenceablePointer(GEP, DL)) { + DEBUG(dbgs() << "not dereferenceable\n"); + // We need to make sure that we can do comparison in any order, so we + // require memory to be unconditionnally dereferencable. + return {}; + } + Result.Offset = APInt(DL.getPointerTypeSizeInBits(GEP->getType()), 0); + if (GEP->accumulateConstantOffset(DL, Result.Offset)) { + Result.GEP = GEP; + Result.LoadI = LoadI; + } + } + } + return Result; +} + +// A basic block with a comparison between two BCE atoms. +// Note: the terminology is misleading: the comparison is symmetric, so there +// is no real {l/r}hs. What we want though is to have the same base on the +// left (resp. right), so that we can detect consecutive loads. To ensure this +// we put the smallest atom on the left. +class BCECmpBlock { + public: + BCECmpBlock() {} + + BCECmpBlock(BCEAtom L, BCEAtom R, int SizeBits) + : Lhs_(L), Rhs_(R), SizeBits_(SizeBits) { + if (Rhs_ < Lhs_) std::swap(Rhs_, Lhs_); + } + + bool IsValid() const { + return Lhs_.Base() != nullptr && Rhs_.Base() != nullptr; + } + + // Assert the the block is consistent: If valid, it should also have + // non-null members besides Lhs_ and Rhs_. + void AssertConsistent() const { + if (IsValid()) { + assert(BB); + assert(CmpI); + assert(BranchI); + } + } + + const BCEAtom &Lhs() const { return Lhs_; } + const BCEAtom &Rhs() const { return Rhs_; } + int SizeBits() const { return SizeBits_; } + + // Returns true if the block does other works besides comparison. + bool doesOtherWork() const; + + // The basic block where this comparison happens. + BasicBlock *BB = nullptr; + // The ICMP for this comparison. + ICmpInst *CmpI = nullptr; + // The terminating branch. + BranchInst *BranchI = nullptr; + + private: + BCEAtom Lhs_; + BCEAtom Rhs_; + int SizeBits_ = 0; +}; + +bool BCECmpBlock::doesOtherWork() const { + AssertConsistent(); + // TODO(courbet): Can we allow some other things ? This is very conservative. + // We might be able to get away with anything does does not have any side + // effects outside of the basic block. + // Note: The GEPs and/or loads are not necessarily in the same block. + for (const Instruction &Inst : *BB) { + if (const auto *const GEP = dyn_cast<GetElementPtrInst>(&Inst)) { + if (!(Lhs_.GEP == GEP || Rhs_.GEP == GEP)) return true; + } else if (const auto *const L = dyn_cast<LoadInst>(&Inst)) { + if (!(Lhs_.LoadI == L || Rhs_.LoadI == L)) return true; + } else if (const auto *const C = dyn_cast<ICmpInst>(&Inst)) { + if (C != CmpI) return true; + } else if (const auto *const Br = dyn_cast<BranchInst>(&Inst)) { + if (Br != BranchI) return true; + } else { + return true; + } + } + return false; +} + +// Visit the given comparison. If this is a comparison between two valid +// BCE atoms, returns the comparison. 
+BCECmpBlock visitICmp(const ICmpInst *const CmpI, + const ICmpInst::Predicate ExpectedPredicate) { + if (CmpI->getPredicate() == ExpectedPredicate) { + DEBUG(dbgs() << "cmp " + << (ExpectedPredicate == ICmpInst::ICMP_EQ ? "eq" : "ne") + << "\n"); + auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0)); + if (!Lhs.Base()) return {}; + auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1)); + if (!Rhs.Base()) return {}; + return BCECmpBlock(std::move(Lhs), std::move(Rhs), + CmpI->getOperand(0)->getType()->getScalarSizeInBits()); + } + return {}; +} + +// Visit the given comparison block. If this is a comparison between two valid +// BCE atoms, returns the comparison. +BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, + const BasicBlock *const PhiBlock) { + if (Block->empty()) return {}; + auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator()); + if (!BranchI) return {}; + DEBUG(dbgs() << "branch\n"); + if (BranchI->isUnconditional()) { + // In this case, we expect an incoming value which is the result of the + // comparison. This is the last link in the chain of comparisons (note + // that this does not mean that this is the last incoming value, blocks + // can be reordered). + auto *const CmpI = dyn_cast<ICmpInst>(Val); + if (!CmpI) return {}; + DEBUG(dbgs() << "icmp\n"); + auto Result = visitICmp(CmpI, ICmpInst::ICMP_EQ); + Result.CmpI = CmpI; + Result.BranchI = BranchI; + return Result; + } else { + // In this case, we expect a constant incoming value (the comparison is + // chained). + const auto *const Const = dyn_cast<ConstantInt>(Val); + DEBUG(dbgs() << "const\n"); + if (!Const->isZero()) return {}; + DEBUG(dbgs() << "false\n"); + auto *const CmpI = dyn_cast<ICmpInst>(BranchI->getCondition()); + if (!CmpI) return {}; + DEBUG(dbgs() << "icmp\n"); + assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch"); + BasicBlock *const FalseBlock = BranchI->getSuccessor(1); + auto Result = visitICmp( + CmpI, FalseBlock == PhiBlock ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE); + Result.CmpI = CmpI; + Result.BranchI = BranchI; + return Result; + } + return {}; +} + +// A chain of comparisons. +class BCECmpChain { + public: + BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi); + + int size() const { return Comparisons_.size(); } + +#ifdef MERGEICMPS_DOT_ON + void dump() const; +#endif // MERGEICMPS_DOT_ON + + bool simplify(const TargetLibraryInfo *const TLI); + + private: + static bool IsContiguous(const BCECmpBlock &First, + const BCECmpBlock &Second) { + return First.Lhs().Base() == Second.Lhs().Base() && + First.Rhs().Base() == Second.Rhs().Base() && + First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset && + First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset; + } + + // Merges the given comparison blocks into one memcmp block and update + // branches. Comparisons are assumed to be continguous. If NextBBInChain is + // null, the merged block will link to the phi block. + static void mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, + BasicBlock *const NextBBInChain, PHINode &Phi, + const TargetLibraryInfo *const TLI); + + PHINode &Phi_; + std::vector<BCECmpBlock> Comparisons_; + // The original entry block (before sorting); + BasicBlock *EntryBlock_; +}; + +BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi) + : Phi_(Phi) { + // Now look inside blocks to check for BCE comparisons. 
+ std::vector<BCECmpBlock> Comparisons; + for (BasicBlock *Block : Blocks) { + BCECmpBlock Comparison = visitCmpBlock(Phi.getIncomingValueForBlock(Block), + Block, Phi.getParent()); + Comparison.BB = Block; + if (!Comparison.IsValid()) { + DEBUG(dbgs() << "skip: not a valid BCECmpBlock\n"); + return; + } + if (Comparison.doesOtherWork()) { + DEBUG(dbgs() << "block does extra work besides compare\n"); + if (Comparisons.empty()) { // First block. + // TODO(courbet): The first block can do other things, and we should + // split them apart in a separate block before the comparison chain. + // Right now we just discard it and make the chain shorter. + DEBUG(dbgs() + << "ignoring first block that does extra work besides compare\n"); + continue; + } + // TODO(courbet): Right now we abort the whole chain. We could be + // merging only the blocks that don't do other work and resume the + // chain from there. For example: + // if (a[0] == b[0]) { // bb1 + // if (a[1] == b[1]) { // bb2 + // some_value = 3; //bb3 + // if (a[2] == b[2]) { //bb3 + // do a ton of stuff //bb4 + // } + // } + // } + // + // This is: + // + // bb1 --eq--> bb2 --eq--> bb3* -eq--> bb4 --+ + // \ \ \ \ + // ne ne ne \ + // \ \ \ v + // +------------+-----------+----------> bb_phi + // + // We can only merge the first two comparisons, because bb3* does + // "other work" (setting some_value to 3). + // We could still merge bb1 and bb2 though. + return; + } + DEBUG(dbgs() << "*Found cmp of " << Comparison.SizeBits() + << " bits between " << Comparison.Lhs().Base() << " + " + << Comparison.Lhs().Offset << " and " + << Comparison.Rhs().Base() << " + " << Comparison.Rhs().Offset + << "\n"); + DEBUG(dbgs() << "\n"); + Comparisons.push_back(Comparison); + } + EntryBlock_ = Comparisons[0].BB; + Comparisons_ = std::move(Comparisons); +#ifdef MERGEICMPS_DOT_ON + errs() << "BEFORE REORDERING:\n\n"; + dump(); +#endif // MERGEICMPS_DOT_ON + // Reorder blocks by LHS. We can do that without changing the + // semantics because we are only accessing dereferencable memory. + std::sort(Comparisons_.begin(), Comparisons_.end(), + [](const BCECmpBlock &a, const BCECmpBlock &b) { + return a.Lhs() < b.Lhs(); + }); +#ifdef MERGEICMPS_DOT_ON + errs() << "AFTER REORDERING:\n\n"; + dump(); +#endif // MERGEICMPS_DOT_ON +} + +#ifdef MERGEICMPS_DOT_ON +void BCECmpChain::dump() const { + errs() << "digraph dag {\n"; + errs() << " graph [bgcolor=transparent];\n"; + errs() << " node [color=black,style=filled,fillcolor=lightyellow];\n"; + errs() << " edge [color=black];\n"; + for (size_t I = 0; I < Comparisons_.size(); ++I) { + const auto &Comparison = Comparisons_[I]; + errs() << " \"" << I << "\" [label=\"%" + << Comparison.Lhs().Base()->getName() << " + " + << Comparison.Lhs().Offset << " == %" + << Comparison.Rhs().Base()->getName() << " + " + << Comparison.Rhs().Offset << " (" << (Comparison.SizeBits() / 8) + << " bytes)\"];\n"; + const Value *const Val = Phi_.getIncomingValueForBlock(Comparison.BB); + if (I > 0) errs() << " \"" << (I - 1) << "\" -> \"" << I << "\";\n"; + errs() << " \"" << I << "\" -> \"Phi\" [label=\"" << *Val << "\"];\n"; + } + errs() << " \"Phi\" [label=\"Phi\"];\n"; + errs() << "}\n\n"; +} +#endif // MERGEICMPS_DOT_ON + +bool BCECmpChain::simplify(const TargetLibraryInfo *const TLI) { + // First pass to check if there is at least one merge. If not, we don't do + // anything and we keep analysis passes intact. 
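As context for simplify() and mergeComparisons() below, this is roughly the source-level shape MergeICmps is after; the struct is made up for illustration. The member compares load from contiguous offsets (0, 4, 8), forming the bb1 --eq--> bb2 --eq--> bb3 chain sketched above, which the pass can then replace with a single 12-byte memcmp for the backend to expand.

    #include <cstdio>

    struct Triple {
      int a, b, c; // contiguous 4-byte fields
    };

    // Chained equality compares over adjacent fields: one icmp per block,
    // all feeding a single phi.
    bool operator==(const Triple &L, const Triple &R) {
      return L.a == R.a && L.b == R.b && L.c == R.c;
    }

    int main() {
      Triple X{1, 2, 3}, Y{1, 2, 3};
      std::printf("%d\n", X == Y ? 1 : 0);
    }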
+ { + bool AtLeastOneMerged = false; + for (size_t I = 1; I < Comparisons_.size(); ++I) { + if (IsContiguous(Comparisons_[I - 1], Comparisons_[I])) { + AtLeastOneMerged = true; + break; + } + } + if (!AtLeastOneMerged) return false; + } + + // Remove phi references to comparison blocks, they will be rebuilt as we + // merge the blocks. + for (const auto &Comparison : Comparisons_) { + Phi_.removeIncomingValue(Comparison.BB, false); + } + + // Point the predecessors of the chain to the first comparison block (which is + // the new entry point). + if (EntryBlock_ != Comparisons_[0].BB) + EntryBlock_->replaceAllUsesWith(Comparisons_[0].BB); + + // Effectively merge blocks. + int NumMerged = 1; + for (size_t I = 1; I < Comparisons_.size(); ++I) { + if (IsContiguous(Comparisons_[I - 1], Comparisons_[I])) { + ++NumMerged; + } else { + // Merge all previous comparisons and start a new merge block. + mergeComparisons( + makeArrayRef(Comparisons_).slice(I - NumMerged, NumMerged), + Comparisons_[I].BB, Phi_, TLI); + NumMerged = 1; + } + } + mergeComparisons(makeArrayRef(Comparisons_) + .slice(Comparisons_.size() - NumMerged, NumMerged), + nullptr, Phi_, TLI); + + return true; +} + +void BCECmpChain::mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, + BasicBlock *const NextBBInChain, + PHINode &Phi, + const TargetLibraryInfo *const TLI) { + assert(!Comparisons.empty()); + const auto &FirstComparison = *Comparisons.begin(); + BasicBlock *const BB = FirstComparison.BB; + LLVMContext &Context = BB->getContext(); + + if (Comparisons.size() >= 2) { + DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons\n"); + const auto TotalSize = + std::accumulate(Comparisons.begin(), Comparisons.end(), 0, + [](int Size, const BCECmpBlock &C) { + return Size + C.SizeBits(); + }) / + 8; + + // Incoming edges do not need to be updated, and both GEPs are already + // computing the right address, we just need to: + // - replace the two loads and the icmp with the memcmp + // - update the branch + // - update the incoming values in the phi. + FirstComparison.BranchI->eraseFromParent(); + FirstComparison.CmpI->eraseFromParent(); + FirstComparison.Lhs().LoadI->eraseFromParent(); + FirstComparison.Rhs().LoadI->eraseFromParent(); + + IRBuilder<> Builder(BB); + const auto &DL = Phi.getModule()->getDataLayout(); + Value *const MemCmpCall = emitMemCmp( + FirstComparison.Lhs().GEP, FirstComparison.Rhs().GEP, ConstantInt::get(DL.getIntPtrType(Context), TotalSize), + Builder, DL, TLI); + Value *const MemCmpIsZero = Builder.CreateICmpEQ( + MemCmpCall, ConstantInt::get(Type::getInt32Ty(Context), 0)); + + // Add a branch to the next basic block in the chain. + if (NextBBInChain) { + Builder.CreateCondBr(MemCmpIsZero, NextBBInChain, Phi.getParent()); + Phi.addIncoming(ConstantInt::getFalse(Context), BB); + } else { + Builder.CreateBr(Phi.getParent()); + Phi.addIncoming(MemCmpIsZero, BB); + } + + // Delete merged blocks. + for (size_t I = 1; I < Comparisons.size(); ++I) { + BasicBlock *CBB = Comparisons[I].BB; + CBB->replaceAllUsesWith(BB); + CBB->eraseFromParent(); + } + } else { + assert(Comparisons.size() == 1); + // There are no blocks to merge, but we still need to update the branches. + DEBUG(dbgs() << "Only one comparison, updating branches\n"); + if (NextBBInChain) { + if (FirstComparison.BranchI->isConditional()) { + DEBUG(dbgs() << "conditional -> conditional\n"); + // Just update the "true" target, the "false" target should already be + // the phi block. 
+ assert(FirstComparison.BranchI->getSuccessor(1) == Phi.getParent()); + FirstComparison.BranchI->setSuccessor(0, NextBBInChain); + Phi.addIncoming(ConstantInt::getFalse(Context), BB); + } else { + DEBUG(dbgs() << "unconditional -> conditional\n"); + // Replace the unconditional branch by a conditional one. + FirstComparison.BranchI->eraseFromParent(); + IRBuilder<> Builder(BB); + Builder.CreateCondBr(FirstComparison.CmpI, NextBBInChain, + Phi.getParent()); + Phi.addIncoming(FirstComparison.CmpI, BB); + } + } else { + if (FirstComparison.BranchI->isConditional()) { + DEBUG(dbgs() << "conditional -> unconditional\n"); + // Replace the conditional branch by an unconditional one. + FirstComparison.BranchI->eraseFromParent(); + IRBuilder<> Builder(BB); + Builder.CreateBr(Phi.getParent()); + Phi.addIncoming(FirstComparison.CmpI, BB); + } else { + DEBUG(dbgs() << "unconditional -> unconditional\n"); + Phi.addIncoming(FirstComparison.CmpI, BB); + } + } + } +} + +std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, + BasicBlock *const LastBlock, + int NumBlocks) { + // Walk up from the last block to find other blocks. + std::vector<BasicBlock *> Blocks(NumBlocks); + BasicBlock *CurBlock = LastBlock; + for (int BlockIndex = NumBlocks - 1; BlockIndex > 0; --BlockIndex) { + if (CurBlock->hasAddressTaken()) { + // Somebody is jumping to the block through an address, all bets are + // off. + DEBUG(dbgs() << "skip: block " << BlockIndex + << " has its address taken\n"); + return {}; + } + Blocks[BlockIndex] = CurBlock; + auto *SinglePredecessor = CurBlock->getSinglePredecessor(); + if (!SinglePredecessor) { + // The block has two or more predecessors. + DEBUG(dbgs() << "skip: block " << BlockIndex + << " has two or more predecessors\n"); + return {}; + } + if (Phi.getBasicBlockIndex(SinglePredecessor) < 0) { + // The block does not link back to the phi. + DEBUG(dbgs() << "skip: block " << BlockIndex + << " does not link back to the phi\n"); + return {}; + } + CurBlock = SinglePredecessor; + } + Blocks[0] = CurBlock; + return Blocks; +} + +bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI) { + DEBUG(dbgs() << "processPhi()\n"); + if (Phi.getNumIncomingValues() <= 1) { + DEBUG(dbgs() << "skip: only one incoming value in phi\n"); + return false; + } + // We are looking for something that has the following structure: + // bb1 --eq--> bb2 --eq--> bb3 --eq--> bb4 --+ + // \ \ \ \ + // ne ne ne \ + // \ \ \ v + // +------------+-----------+----------> bb_phi + // + // - The last basic block (bb4 here) must branch unconditionally to bb_phi. + // It's the only block that contributes a non-constant value to the Phi. + // - All other blocks (b1, b2, b3) must have exactly two successors, one of + // them being the the phi block. + // - All intermediate blocks (bb2, bb3) must have only one predecessor. + // - Blocks cannot do other work besides the comparison, see doesOtherWork() + + // The blocks are not necessarily ordered in the phi, so we start from the + // last block and reconstruct the order. + BasicBlock *LastBlock = nullptr; + for (unsigned I = 0; I < Phi.getNumIncomingValues(); ++I) { + if (isa<ConstantInt>(Phi.getIncomingValue(I))) continue; + if (LastBlock) { + // There are several non-constant values. + DEBUG(dbgs() << "skip: several non-constant values\n"); + return false; + } + LastBlock = Phi.getIncomingBlock(I); + } + if (!LastBlock) { + // There is no non-constant block. 
+ DEBUG(dbgs() << "skip: no non-constant block\n"); + return false; + } + if (LastBlock->getSingleSuccessor() != Phi.getParent()) { + DEBUG(dbgs() << "skip: last block non-phi successor\n"); + return false; + } + + const auto Blocks = + getOrderedBlocks(Phi, LastBlock, Phi.getNumIncomingValues()); + if (Blocks.empty()) return false; + BCECmpChain CmpChain(Blocks, Phi); + + if (CmpChain.size() < 2) { + DEBUG(dbgs() << "skip: only one compare block\n"); + return false; + } + + return CmpChain.simplify(TLI); +} + +class MergeICmps : public FunctionPass { + public: + static char ID; + + MergeICmps() : FunctionPass(ID) { + initializeMergeICmpsPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) return false; + const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto PA = runImpl(F, &TLI, &TTI); + return !PA.areAllPreserved(); + } + + private: + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + } + + PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI, + const TargetTransformInfo *TTI); +}; + +PreservedAnalyses MergeICmps::runImpl(Function &F, const TargetLibraryInfo *TLI, + const TargetTransformInfo *TTI) { + DEBUG(dbgs() << "MergeICmpsPass: " << F.getName() << "\n"); + + // We only try merging comparisons if the target wants to expand memcmp later. + // The rationale is to avoid turning small chains into memcmp calls. + if (!TTI->enableMemCmpExpansion(true)) return PreservedAnalyses::all(); + + bool MadeChange = false; + + for (auto BBIt = ++F.begin(); BBIt != F.end(); ++BBIt) { + // A Phi operation is always first in a basic block. 
+ if (auto *const Phi = dyn_cast<PHINode>(&*BBIt->begin())) + MadeChange |= processPhi(*Phi, TLI); + } + + if (MadeChange) return PreservedAnalyses::none(); + return PreservedAnalyses::all(); +} + +} // namespace + +char MergeICmps::ID = 0; +INITIALIZE_PASS_BEGIN(MergeICmps, "mergeicmps", + "Merge contiguous icmps into a memcmp", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(MergeICmps, "mergeicmps", + "Merge contiguous icmps into a memcmp", false, false) + +Pass *llvm::createMergeICmpsPass() { return new MergeICmps(); } diff --git a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp index 6727cf0179c1..f2f615cb9b0f 100644 --- a/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp +++ b/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp @@ -80,11 +80,9 @@ #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/Loads.h" -#include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/Metadata.h" -#include "llvm/IR/PatternMatch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" @@ -195,7 +193,7 @@ bool MergedLoadStoreMotion::isStoreSinkBarrierInRange(const Instruction &Start, make_range(Start.getIterator(), End.getIterator())) if (Inst.mayThrow()) return true; - return AA->canInstructionRangeModRef(Start, End, Loc, MRI_ModRef); + return AA->canInstructionRangeModRef(Start, End, Loc, ModRefInfo::ModRef); } /// diff --git a/lib/Transforms/Scalar/NaryReassociate.cpp b/lib/Transforms/Scalar/NaryReassociate.cpp index d0bfe3603897..b026c8d692c3 100644 --- a/lib/Transforms/Scalar/NaryReassociate.cpp +++ b/lib/Transforms/Scalar/NaryReassociate.cpp @@ -77,19 +77,45 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/NaryReassociate.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Support/Debug.h" -#include "llvm/Support/raw_ostream.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" +#include <cassert> +#include <cstdint> + using namespace llvm; using namespace PatternMatch; #define DEBUG_TYPE "nary-reassociate" namespace { + class NaryReassociateLegacyPass : public FunctionPass { public: static char ID; @@ -101,6 +127,7 @@ public: bool doInitialization(Module &M) override { return false; } + bool runOnFunction(Function &F) override; void getAnalysisUsage(AnalysisUsage 
&AU) const override { @@ -118,9 +145,11 @@ public: private: NaryReassociatePass Impl; }; -} // anonymous namespace + +} // end anonymous namespace char NaryReassociateLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(NaryReassociateLegacyPass, "nary-reassociate", "Nary reassociation", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) diff --git a/lib/Transforms/Scalar/NewGVN.cpp b/lib/Transforms/Scalar/NewGVN.cpp index 8ac10348eb77..9ebf2d769356 100644 --- a/lib/Transforms/Scalar/NewGVN.cpp +++ b/lib/Transforms/Scalar/NewGVN.cpp @@ -1,4 +1,4 @@ -//===---- NewGVN.cpp - Global Value Numbering Pass --------------*- C++ -*-===// +//===- NewGVN.cpp - Global Value Numbering Pass ---------------------------===// // // The LLVM Compiler Infrastructure // @@ -6,6 +6,7 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// /// \file /// This file implements the new LLVM's Global Value Numbering pass. /// GVN partitions values computed by a function into congruence classes. @@ -48,59 +49,81 @@ /// published algorithms are O(Instructions). Instead, we use a technique that /// is O(number of operations with the same value number), enabling us to skip /// trying to eliminate things that have unique value numbers. +// //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/NewGVN.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/BitVector.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/GraphTraits.h" #include "llvm/ADT/Hashing.h" -#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/PostOrderIterator.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SparseBitVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/TinyPtrVector.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" -#include "llvm/Analysis/CFG.h" #include "llvm/Analysis/CFGPrinter.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/MemoryBuiltins.h" -#include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/MemorySSA.h" #include "llvm/Analysis/TargetLibraryInfo.h" -#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/Dominators.h" -#include "llvm/IR/GlobalVariable.h" -#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" -#include "llvm/IR/Metadata.h" -#include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" #include "llvm/Support/Allocator.h" +#include "llvm/Support/ArrayRecycler.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/DebugCounter.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/PointerLikeTypeTraits.h" 
+#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Scalar/GVNExpression.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PredicateInfo.h" #include "llvm/Transforms/Utils/VNCoercion.h" -#include <numeric> -#include <unordered_map> +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <iterator> +#include <map> +#include <memory> +#include <set> +#include <string> +#include <tuple> #include <utility> #include <vector> + using namespace llvm; -using namespace PatternMatch; using namespace llvm::GVNExpression; using namespace llvm::VNCoercion; + #define DEBUG_TYPE "newgvn" STATISTIC(NumGVNInstrDeleted, "Number of instructions deleted"); @@ -118,15 +141,19 @@ STATISTIC(NumGVNPHIOfOpsCreated, "Number of PHI of ops created"); STATISTIC(NumGVNPHIOfOpsEliminations, "Number of things eliminated using PHI of ops"); DEBUG_COUNTER(VNCounter, "newgvn-vn", - "Controls which instructions are value numbered") + "Controls which instructions are value numbered"); DEBUG_COUNTER(PHIOfOpsCounter, "newgvn-phi", - "Controls which instructions we create phi of ops for") + "Controls which instructions we create phi of ops for"); // Currently store defining access refinement is too slow due to basicaa being // egregiously slow. This flag lets us keep it working while we work on this // issue. static cl::opt<bool> EnableStoreRefinement("enable-store-refinement", cl::init(false), cl::Hidden); +/// Currently, the generation "phi of ops" can result in correctness issues. +static cl::opt<bool> EnablePhiOfOps("enable-phi-of-ops", cl::init(true), + cl::Hidden); + //===----------------------------------------------------------------------===// // GVN Pass //===----------------------------------------------------------------------===// @@ -134,6 +161,7 @@ static cl::opt<bool> EnableStoreRefinement("enable-store-refinement", // Anchor methods. namespace llvm { namespace GVNExpression { + Expression::~Expression() = default; BasicExpression::~BasicExpression() = default; CallExpression::~CallExpression() = default; @@ -141,8 +169,11 @@ LoadExpression::~LoadExpression() = default; StoreExpression::~StoreExpression() = default; AggregateValueExpression::~AggregateValueExpression() = default; PHIExpression::~PHIExpression() = default; -} -} + +} // end namespace GVNExpression +} // end namespace llvm + +namespace { // Tarjan's SCC finding algorithm with Nuutila's improvements // SCCIterator is actually fairly complex for the simple thing we want. @@ -153,7 +184,6 @@ PHIExpression::~PHIExpression() = default; // instructions, // not generic values (arguments, etc). struct TarjanSCC { - TarjanSCC() : Components(1) {} void Start(const Instruction *Start) { @@ -208,15 +238,19 @@ private: Stack.push_back(I); } } + unsigned int DFSNum = 1; SmallPtrSet<const Value *, 8> InComponent; DenseMap<const Value *, unsigned int> Root; SmallVector<const Value *, 8> Stack; + // Store the components as vector of ptr sets, because we need the topo order // of SCC's, but not individual member order SmallVector<SmallPtrSet<const Value *, 8>, 8> Components; + DenseMap<const Value *, unsigned> ValueToComponent; }; + // Congruence classes represent the set of expressions/instructions // that are all the same *during some scope in the function*. 
// That is, because of the way we perform equality propagation, and @@ -265,7 +299,9 @@ public: explicit CongruenceClass(unsigned ID) : ID(ID) {} CongruenceClass(unsigned ID, Value *Leader, const Expression *E) : ID(ID), RepLeader(Leader), DefiningExpr(E) {} + unsigned getID() const { return ID; } + // True if this class has no members left. This is mainly used for assertion // purposes, and for skipping empty classes. bool isDead() const { @@ -273,6 +309,7 @@ public: // perspective, it's really dead. return empty() && memory_empty(); } + // Leader functions Value *getLeader() const { return RepLeader; } void setLeader(Value *Leader) { RepLeader = Leader; } @@ -280,7 +317,6 @@ public: return NextLeader; } void resetNextLeader() { NextLeader = {nullptr, ~0}; } - void addPossibleNextLeader(std::pair<Value *, unsigned int> LeaderPair) { if (LeaderPair.second < NextLeader.second) NextLeader = LeaderPair; @@ -315,6 +351,7 @@ public: iterator_range<MemoryMemberSet::const_iterator> memory() const { return make_range(memory_begin(), memory_end()); } + void memory_insert(const MemoryMemberType *M) { MemoryMembers.insert(M); } void memory_erase(const MemoryMemberType *M) { MemoryMembers.erase(M); } @@ -354,34 +391,48 @@ public: private: unsigned ID; + // Representative leader. Value *RepLeader = nullptr; + // The most dominating leader after our current leader, because the member set // is not sorted and is expensive to keep sorted all the time. std::pair<Value *, unsigned int> NextLeader = {nullptr, ~0U}; + // If this is represented by a store, the value of the store. Value *RepStoredValue = nullptr; + // If this class contains MemoryDefs or MemoryPhis, this is the leading memory // access. const MemoryAccess *RepMemoryAccess = nullptr; + // Defining Expression. const Expression *DefiningExpr = nullptr; + // Actual members of this class. MemberSet Members; + // This is the set of MemoryPhis that exist in the class. MemoryDefs and // MemoryUses have real instructions representing them, so we only need to // track MemoryPhis here. MemoryMemberSet MemoryMembers; + // Number of stores in this congruence class. // This is used so we can detect store equivalence changes properly. int StoreCount = 0; }; +} // end anonymous namespace + namespace llvm { + struct ExactEqualsExpression { const Expression &E; + explicit ExactEqualsExpression(const Expression &E) : E(E) {} + hash_code getComputedHash() const { return E.getComputedHash(); } + bool operator==(const Expression &Other) const { return E.exactlyEquals(Other); } @@ -393,17 +444,21 @@ template <> struct DenseMapInfo<const Expression *> { Val <<= PointerLikeTypeTraits<const Expression *>::NumLowBitsAvailable; return reinterpret_cast<const Expression *>(Val); } + static const Expression *getTombstoneKey() { auto Val = static_cast<uintptr_t>(~1U); Val <<= PointerLikeTypeTraits<const Expression *>::NumLowBitsAvailable; return reinterpret_cast<const Expression *>(Val); } + static unsigned getHashValue(const Expression *E) { return E->getComputedHash(); } + static unsigned getHashValue(const ExactEqualsExpression &E) { return E.getComputedHash(); } + static bool isEqual(const ExactEqualsExpression &LHS, const Expression *RHS) { if (RHS == getTombstoneKey() || RHS == getEmptyKey()) return false; @@ -425,9 +480,11 @@ template <> struct DenseMapInfo<const Expression *> { return *LHS == *RHS; } }; + } // end namespace llvm namespace { + class NewGVN { Function &F; DominatorTree *DT; @@ -464,16 +521,22 @@ class NewGVN { // Value Mappings. 
DenseMap<Value *, CongruenceClass *> ValueToClass; DenseMap<Value *, const Expression *> ValueToExpression; + // Value PHI handling, used to make equivalence between phi(op, op) and // op(phi, phi). // These mappings just store various data that would normally be part of the // IR. - DenseSet<const Instruction *> PHINodeUses; + SmallPtrSet<const Instruction *, 8> PHINodeUses; + + DenseMap<const Value *, bool> OpSafeForPHIOfOps; + // Map a temporary instruction we created to a parent block. DenseMap<const Value *, BasicBlock *> TempToBlock; - // Map between the temporary phis we created and the real instructions they - // are known equivalent to. + + // Map between the already in-program instructions and the temporary phis we + // created that they are known equivalent to. DenseMap<const Value *, PHINode *> RealToTemp; + // In order to know when we should re-process instructions that have // phi-of-ops, we track the set of expressions that they needed as // leaders. When we discover new leaders for those expressions, we process the @@ -485,19 +548,32 @@ class NewGVN { mutable DenseMap<const Value *, SmallPtrSet<Value *, 2>> AdditionalUsers; DenseMap<const Expression *, SmallPtrSet<Instruction *, 2>> ExpressionToPhiOfOps; - // Map from basic block to the temporary operations we created - DenseMap<const BasicBlock *, SmallVector<PHINode *, 8>> PHIOfOpsPHIs; + // Map from temporary operation to MemoryAccess. DenseMap<const Instruction *, MemoryUseOrDef *> TempToMemory; + // Set of all temporary instructions we created. + // Note: This will include instructions that were just created during value + // numbering. The way to test if something is using them is to check + // RealToTemp. DenseSet<Instruction *> AllTempInstructions; + // This is the set of instructions to revisit on a reachability change. At + // the end of the main iteration loop it will contain at least all the phi of + // ops instructions that will be changed to phis, as well as regular phis. + // During the iteration loop, it may contain other things, such as phi of ops + // instructions that used edge reachability to reach a result, and so need to + // be revisited when the edge changes, independent of whether the phi they + // depended on changes. + DenseMap<BasicBlock *, SparseBitVector<>> RevisitOnReachabilityChange; + // Mapping from predicate info we used to the instructions we used it with. // In order to correctly ensure propagation, we must keep track of what // comparisons we used, so that when the values of the comparisons change, we // propagate the information to the places we used the comparison. mutable DenseMap<const Value *, SmallPtrSet<Instruction *, 2>> PredicateToUsers; + // the same reasoning as PredicateToUsers. When we skip MemoryAccesses for // stores, we no longer can rely solely on the def-use chains of MemorySSA. mutable DenseMap<const MemoryAccess *, SmallPtrSet<MemoryAccess *, 2>> @@ -525,6 +601,7 @@ class NewGVN { enum InstCycleState { ICS_Unknown, ICS_CycleFree, ICS_Cycle }; mutable DenseMap<const Instruction *, InstCycleState> InstCycleState; + // Expression to class mapping. 
using ExpressionClassMap = DenseMap<const Expression *, CongruenceClass *>; ExpressionClassMap ExpressionToClass; @@ -581,6 +658,7 @@ public: : F(F), DT(DT), TLI(TLI), AA(AA), MSSA(MSSA), DL(DL), PredInfo(make_unique<PredicateInfo>(F, *DT, *AC)), SQ(DL, TLI, DT, AC) { } + bool runGVN(); private: @@ -588,7 +666,13 @@ private: const Expression *createExpression(Instruction *) const; const Expression *createBinaryExpression(unsigned, Type *, Value *, Value *, Instruction *) const; - PHIExpression *createPHIExpression(Instruction *, bool &HasBackEdge, + + // Our canonical form for phi arguments is a pair of incoming value, incoming + // basic block. + using ValPair = std::pair<Value *, BasicBlock *>; + + PHIExpression *createPHIExpression(ArrayRef<ValPair>, const Instruction *, + BasicBlock *, bool &HasBackEdge, bool &OriginalOpsConstant) const; const DeadExpression *createDeadExpression() const; const VariableExpression *createVariableExpression(Value *) const; @@ -617,6 +701,7 @@ private: CC->setMemoryLeader(MA); return CC; } + CongruenceClass *ensureLeaderOfMemoryClass(MemoryAccess *MA) { auto *CC = getMemoryClass(MA); if (CC->getMemoryLeader() != MA) @@ -630,10 +715,21 @@ private: ValueToClass[Member] = CClass; return CClass; } + void initializeCongruenceClasses(Function &F); - const Expression *makePossiblePhiOfOps(Instruction *, + const Expression *makePossiblePHIOfOps(Instruction *, SmallPtrSetImpl<Value *> &); + Value *findLeaderForInst(Instruction *ValueOp, + SmallPtrSetImpl<Value *> &Visited, + MemoryAccess *MemAccess, Instruction *OrigInst, + BasicBlock *PredBB); + bool OpIsSafeForPHIOfOpsHelper(Value *V, const BasicBlock *PHIBlock, + SmallPtrSetImpl<const Value *> &Visited, + SmallVectorImpl<Instruction *> &Worklist); + bool OpIsSafeForPHIOfOps(Value *Op, const BasicBlock *PHIBlock, + SmallPtrSetImpl<const Value *> &); void addPhiOfOps(PHINode *Op, BasicBlock *BB, Instruction *ExistingValue); + void removePhiOfOps(Instruction *I, PHINode *PHITemp); // Value number an Instruction or MemoryPhi. void valueNumberMemoryPhi(MemoryPhi *); @@ -650,7 +746,10 @@ private: const Expression *performSymbolicLoadEvaluation(Instruction *) const; const Expression *performSymbolicStoreEvaluation(Instruction *) const; const Expression *performSymbolicCallEvaluation(Instruction *) const; - const Expression *performSymbolicPHIEvaluation(Instruction *) const; + void sortPHIOps(MutableArrayRef<ValPair> Ops) const; + const Expression *performSymbolicPHIEvaluation(ArrayRef<ValPair>, + Instruction *I, + BasicBlock *PHIBlock) const; const Expression *performSymbolicAggrValueEvaluation(Instruction *) const; const Expression *performSymbolicCmpEvaluation(Instruction *) const; const Expression *performSymbolicPredicateInfoEvaluation(Instruction *) const; @@ -658,6 +757,7 @@ private: // Congruence finding. 
bool someEquivalentDominates(const Instruction *, const Instruction *) const; Value *lookupOperandLeader(Value *) const; + CongruenceClass *getClassForExpression(const Expression *E) const; void performCongruenceFinding(Instruction *, const Expression *); void moveValueToNewCongruenceClass(Instruction *, const Expression *, CongruenceClass *, CongruenceClass *); @@ -692,10 +792,11 @@ private: void replaceInstruction(Instruction *, Value *); void markInstructionForDeletion(Instruction *); void deleteInstructionsInBlock(BasicBlock *); - Value *findPhiOfOpsLeader(const Expression *E, const BasicBlock *BB) const; + Value *findPHIOfOpsLeader(const Expression *, const Instruction *, + const BasicBlock *) const; // New instruction creation. - void handleNewInstruction(Instruction *){}; + void handleNewInstruction(Instruction *) {} // Various instruction touch utilities template <typename Map, typename KeyType, typename Func> @@ -731,6 +832,7 @@ private: MemoryAccess *getDefiningAccess(const MemoryAccess *) const; MemoryPhi *getMemoryAccess(const BasicBlock *) const; template <class T, class Range> T *getMinDFSOfRange(const Range &) const; + unsigned InstrToDFSNum(const Value *V) const { assert(isa<Instruction>(V) && "This should not be used for MemoryAccesses"); return InstrDFS.lookup(V); @@ -739,7 +841,9 @@ private: unsigned InstrToDFSNum(const MemoryAccess *MA) const { return MemoryToDFSNum(MA); } + Value *InstrFromDFSNum(unsigned DFSNum) { return DFSToInstr[DFSNum]; } + // Given a MemoryAccess, return the relevant instruction DFS number. Note: // This deliberately takes a value so it can be used with Use's, which will // auto-convert to Value's but not to MemoryAccess's. @@ -750,12 +854,15 @@ private: ? InstrToDFSNum(cast<MemoryUseOrDef>(MA)->getMemoryInst()) : InstrDFS.lookup(MA); } + bool isCycleFree(const Instruction *) const; bool isBackedge(BasicBlock *From, BasicBlock *To) const; + // Debug counter info. When verifying, we have to reset the value numbering // debug counter to the same state it started in to get the same results. std::pair<int, int> StartingVNCounter; }; + } // end anonymous namespace template <typename T> @@ -781,11 +888,9 @@ bool StoreExpression::equals(const Expression &Other) const { // Determine if the edge From->To is a backedge bool NewGVN::isBackedge(BasicBlock *From, BasicBlock *To) const { - if (From == To) - return true; - auto *FromDTN = DT->getNode(From); - auto *ToDTN = DT->getNode(To); - return RPOOrdering.lookup(FromDTN) >= RPOOrdering.lookup(ToDTN); + return From == To || + RPOOrdering.lookup(DT->getNode(From)) >= + RPOOrdering.lookup(DT->getNode(To)); } #ifndef NDEBUG @@ -830,51 +935,77 @@ void NewGVN::deleteExpression(const Expression *E) const { const_cast<BasicExpression *>(BE)->deallocateOperands(ArgRecycler); ExpressionAllocator.Deallocate(E); } -PHIExpression *NewGVN::createPHIExpression(Instruction *I, bool &HasBackedge, + +// If V is a predicateinfo copy, get the thing it is a copy of. +static Value *getCopyOf(const Value *V) { + if (auto *II = dyn_cast<IntrinsicInst>(V)) + if (II->getIntrinsicID() == Intrinsic::ssa_copy) + return II->getOperand(0); + return nullptr; +} + +// Return true if V is really PN, even accounting for predicateinfo copies. +static bool isCopyOfPHI(const Value *V, const PHINode *PN) { + return V == PN || getCopyOf(V) == PN; +} + +static bool isCopyOfAPHI(const Value *V) { + auto *CO = getCopyOf(V); + return CO && isa<PHINode>(CO); +} + +// Sort PHI Operands into a canonical order. What we use here is an RPO +// order. 
The BlockInstRange numbers are generated in an RPO walk of the basic +// blocks. +void NewGVN::sortPHIOps(MutableArrayRef<ValPair> Ops) const { + std::sort(Ops.begin(), Ops.end(), [&](const ValPair &P1, const ValPair &P2) { + return BlockInstRange.lookup(P1.second).first < + BlockInstRange.lookup(P2.second).first; + }); +} + +// Return true if V is a value that will always be available (IE can +// be placed anywhere) in the function. We don't do globals here +// because they are often worse to put in place. +static bool alwaysAvailable(Value *V) { + return isa<Constant>(V) || isa<Argument>(V); +} + +// Create a PHIExpression from an array of {incoming edge, value} pairs. I is +// the original instruction we are creating a PHIExpression for (but may not be +// a phi node). We require, as an invariant, that all the PHIOperands in the +// same block are sorted the same way. sortPHIOps will sort them into a +// canonical order. +PHIExpression *NewGVN::createPHIExpression(ArrayRef<ValPair> PHIOperands, + const Instruction *I, + BasicBlock *PHIBlock, + bool &HasBackedge, bool &OriginalOpsConstant) const { - BasicBlock *PHIBlock = getBlockForValue(I); - auto *PN = cast<PHINode>(I); - auto *E = - new (ExpressionAllocator) PHIExpression(PN->getNumOperands(), PHIBlock); + unsigned NumOps = PHIOperands.size(); + auto *E = new (ExpressionAllocator) PHIExpression(NumOps, PHIBlock); E->allocateOperands(ArgRecycler, ExpressionAllocator); - E->setType(I->getType()); - E->setOpcode(I->getOpcode()); - - // NewGVN assumes the operands of a PHI node are in a consistent order across - // PHIs. LLVM doesn't seem to always guarantee this. While we need to fix - // this in LLVM at some point we don't want GVN to find wrong congruences. - // Therefore, here we sort uses in predecessor order. - // We're sorting the values by pointer. In theory this might be cause of - // non-determinism, but here we don't rely on the ordering for anything - // significant, e.g. we don't create new instructions based on it so we're - // fine. - SmallVector<const Use *, 4> PHIOperands; - for (const Use &U : PN->operands()) - PHIOperands.push_back(&U); - std::sort(PHIOperands.begin(), PHIOperands.end(), - [&](const Use *U1, const Use *U2) { - return PN->getIncomingBlock(*U1) < PN->getIncomingBlock(*U2); - }); + E->setType(PHIOperands.begin()->first->getType()); + E->setOpcode(Instruction::PHI); // Filter out unreachable phi operands. - auto Filtered = make_filter_range(PHIOperands, [&](const Use *U) { - if (*U == PN) - return false; - if (!ReachableEdges.count({PN->getIncomingBlock(*U), PHIBlock})) + auto Filtered = make_filter_range(PHIOperands, [&](const ValPair &P) { + auto *BB = P.second; + if (auto *PHIOp = dyn_cast<PHINode>(I)) + if (isCopyOfPHI(P.first, PHIOp)) + return false; + if (!ReachableEdges.count({BB, PHIBlock})) return false; // Things in TOPClass are equivalent to everything. 
- if (ValueToClass.lookup(*U) == TOPClass) + if (ValueToClass.lookup(P.first) == TOPClass) return false; - return lookupOperandLeader(*U) != PN; + OriginalOpsConstant = OriginalOpsConstant && isa<Constant>(P.first); + HasBackedge = HasBackedge || isBackedge(BB, PHIBlock); + return lookupOperandLeader(P.first) != I; }); std::transform(Filtered.begin(), Filtered.end(), op_inserter(E), - [&](const Use *U) -> Value * { - auto *BB = PN->getIncomingBlock(*U); - HasBackedge = HasBackedge || isBackedge(BB, PHIBlock); - OriginalOpsConstant = - OriginalOpsConstant && isa<Constant>(*U); - return lookupOperandLeader(*U); + [&](const ValPair &P) -> Value * { + return lookupOperandLeader(P.first); }); return E; } @@ -929,8 +1060,6 @@ const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T, // Take a Value returned by simplification of Expression E/Instruction // I, and see if it resulted in a simpler expression. If so, return // that expression. -// TODO: Once finished, this should not take an Instruction, we only -// use it for printing. const Expression *NewGVN::checkSimplificationResults(Expression *E, Instruction *I, Value *V) const { @@ -954,25 +1083,37 @@ const Expression *NewGVN::checkSimplificationResults(Expression *E, } CongruenceClass *CC = ValueToClass.lookup(V); - if (CC && CC->getDefiningExpr()) { - // If we simplified to something else, we need to communicate - // that we're users of the value we simplified to. - if (I != V) { + if (CC) { + if (CC->getLeader() && CC->getLeader() != I) { // Don't add temporary instructions to the user lists. if (!AllTempInstructions.count(I)) addAdditionalUsers(V, I); + return createVariableOrConstant(CC->getLeader()); } + if (CC->getDefiningExpr()) { + // If we simplified to something else, we need to communicate + // that we're users of the value we simplified to. + if (I != V) { + // Don't add temporary instructions to the user lists. + if (!AllTempInstructions.count(I)) + addAdditionalUsers(V, I); + } - if (I) - DEBUG(dbgs() << "Simplified " << *I << " to " - << " expression " << *CC->getDefiningExpr() << "\n"); - NumGVNOpsSimplified++; - deleteExpression(E); - return CC->getDefiningExpr(); + if (I) + DEBUG(dbgs() << "Simplified " << *I << " to " + << " expression " << *CC->getDefiningExpr() << "\n"); + NumGVNOpsSimplified++; + deleteExpression(E); + return CC->getDefiningExpr(); + } } + return nullptr; } +// Create a value expression from the instruction I, replacing operands with +// their leaders. + const Expression *NewGVN::createExpression(Instruction *I) const { auto *E = new (ExpressionAllocator) BasicExpression(I->getNumOperands()); @@ -987,15 +1128,7 @@ const Expression *NewGVN::createExpression(Instruction *I) const { if (shouldSwapOperands(E->getOperand(0), E->getOperand(1))) E->swapOperands(0, 1); } - - // Perform simplificaiton - // TODO: Right now we only check to see if we get a constant result. - // We may get a less than constant, but still better, result for - // some operations. - // IE - // add 0, x -> x - // and x, x -> x - // We should handle this by simply rewriting the expression. + // Perform simplification. if (auto *CI = dyn_cast<CmpInst>(I)) { // Sort the operand value numbers so x<y and y>x get the same value // number. 
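// A standalone sketch (not the NewGVN code itself) of the comparison
// canonicalization mentioned just above: if operands are stored in a fixed
// canonical order, "x < y" and "y > x" produce the same (predicate, lhs, rhs)
// key and therefore the same value number. The lexicographic ordering below is
// only a stand-in for NewGVN's shouldSwapOperands; all names are hypothetical.
#include <cassert>
#include <string>
#include <tuple>
#include <utility>

enum class Pred { LT, GT, LE, GE, EQ, NE };

static Pred swappedPred(Pred P) {
  switch (P) {
  case Pred::LT: return Pred::GT;
  case Pred::GT: return Pred::LT;
  case Pred::LE: return Pred::GE;
  case Pred::GE: return Pred::LE;
  default:       return P; // EQ and NE are symmetric under operand swap
  }
}

using CmpKey = std::tuple<Pred, std::string, std::string>;

// Build a canonical key: always place the "smaller" operand first, adjusting
// the predicate whenever the operands are swapped.
CmpKey canonicalCmpKey(Pred P, std::string LHS, std::string RHS) {
  if (RHS < LHS) {
    std::swap(LHS, RHS);
    P = swappedPred(P);
  }
  return {P, LHS, RHS};
}

int main() {
  // "x < y" and "y > x" canonicalize to the same key, so a value-numbering
  // table keyed on CmpKey would place them in the same congruence class.
  assert(canonicalCmpKey(Pred::LT, "x", "y") == canonicalCmpKey(Pred::GT, "y", "x"));
  return 0;
}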
@@ -1016,7 +1149,7 @@ const Expression *NewGVN::createExpression(Instruction *I) const { return SimplifiedE; } else if (isa<SelectInst>(I)) { if (isa<Constant>(E->getOperand(0)) || - E->getOperand(0) == E->getOperand(1)) { + E->getOperand(1) == E->getOperand(2)) { assert(E->getOperand(1)->getType() == I->getOperand(1)->getType() && E->getOperand(2)->getType() == I->getOperand(2)->getType()); Value *V = SimplifySelectInst(E->getOperand(0), E->getOperand(1), @@ -1121,7 +1254,7 @@ NewGVN::createCallExpression(CallInst *CI, const MemoryAccess *MA) const { bool NewGVN::someEquivalentDominates(const Instruction *Inst, const Instruction *U) const { auto *CC = ValueToClass.lookup(Inst); - // This must be an instruction because we are only called from phi nodes + // This must be an instruction because we are only called from phi nodes // in the case that the value it needs to check against is an instruction. // The most likely candiates for dominance are the leader and the next leader. @@ -1139,6 +1272,8 @@ bool NewGVN::someEquivalentDominates(const Instruction *Inst, // any of these siblings. if (!CC) return false; + if (alwaysAvailable(CC->getLeader())) + return true; if (DT->dominates(cast<Instruction>(CC->getLeader()), U)) return true; if (CC->getNextLeader().first && @@ -1229,9 +1364,9 @@ const Expression *NewGVN::performSymbolicStoreEvaluation(Instruction *I) const { if (EnableStoreRefinement) StoreRHS = MSSAWalker->getClobberingMemoryAccess(StoreAccess); // If we bypassed the use-def chains, make sure we add a use. + StoreRHS = lookupMemoryLeader(StoreRHS); if (StoreRHS != StoreAccess->getDefiningAccess()) addMemoryUsers(StoreRHS, StoreAccess); - StoreRHS = lookupMemoryLeader(StoreRHS); // If we are defined by ourselves, use the live on entry def. if (StoreRHS == StoreAccess) StoreRHS = MSSA->getLiveOnEntryDef(); @@ -1278,7 +1413,7 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr, if (auto *DepSI = dyn_cast<StoreInst>(DepInst)) { // Can't forward from non-atomic to atomic without violating memory model. // Also don't need to coerce if they are the same type, we will just - // propogate.. + // propagate. if (LI->isAtomic() > DepSI->isAtomic() || LoadType == DepSI->getValueOperand()->getType()) return nullptr; @@ -1292,14 +1427,13 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr, getConstantStoreValueForLoad(C, Offset, LoadType, DL)); } } - - } else if (LoadInst *DepLI = dyn_cast<LoadInst>(DepInst)) { + } else if (auto *DepLI = dyn_cast<LoadInst>(DepInst)) { // Can't forward from non-atomic to atomic without violating memory model. if (LI->isAtomic() > DepLI->isAtomic()) return nullptr; int Offset = analyzeLoadFromClobberingLoad(LoadType, LoadPtr, DepLI, DL); if (Offset >= 0) { - // We can coerce a constant load into a load + // We can coerce a constant load into a load. 
if (auto *C = dyn_cast<Constant>(lookupOperandLeader(DepLI))) if (auto *PossibleConstant = getConstantLoadValueForLoad(C, Offset, LoadType, DL)) { @@ -1308,8 +1442,7 @@ NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr, return createConstantExpression(PossibleConstant); } } - - } else if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(DepInst)) { + } else if (auto *DepMI = dyn_cast<MemIntrinsic>(DepInst)) { int Offset = analyzeLoadFromClobberingMemInst(LoadType, LoadPtr, DepMI, DL); if (Offset >= 0) { if (auto *PossibleConstant = @@ -1381,9 +1514,13 @@ const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I) const { } } - const Expression *E = createLoadExpression(LI->getType(), LoadAddressLeader, - LI, DefiningAccess); - return E; + const auto *LE = createLoadExpression(LI->getType(), LoadAddressLeader, LI, + DefiningAccess); + // If our MemoryLeader is not our defining access, add a use to the + // MemoryLeader, so that we get reprocessed when it changes. + if (LE->getMemoryLeader() != DefiningAccess) + addMemoryUsers(LE->getMemoryLeader(), OriginalAccess); + return LE; } const Expression * @@ -1402,7 +1539,7 @@ NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) const { auto *Cond = PWC->Condition; // If this a copy of the condition, it must be either true or false depending - // on the predicate info type and edge + // on the predicate info type and edge. if (CopyOf == Cond) { // We should not need to add predicate users because the predicate info is // already a use of this operand. @@ -1438,7 +1575,7 @@ NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) const { Value *FirstOp = lookupOperandLeader(Cmp->getOperand(0)); Value *SecondOp = lookupOperandLeader(Cmp->getOperand(1)); bool SwappedOps = false; - // Sort the ops + // Sort the ops. if (shouldSwapOperands(FirstOp, SecondOp)) { std::swap(FirstOp, SecondOp); SwappedOps = true; @@ -1464,7 +1601,8 @@ NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) const { if ((PBranch->TrueEdge && Predicate == CmpInst::ICMP_EQ) || (!PBranch->TrueEdge && Predicate == CmpInst::ICMP_NE)) { addPredicateUsers(PI, I); - addAdditionalUsers(Cmp->getOperand(0), I); + addAdditionalUsers(SwappedOps ? Cmp->getOperand(1) : Cmp->getOperand(0), + I); return createVariableOrConstant(FirstOp); } // Handle the special case of floating point. @@ -1472,7 +1610,8 @@ NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) const { (!PBranch->TrueEdge && Predicate == CmpInst::FCMP_UNE)) && isa<ConstantFP>(FirstOp) && !cast<ConstantFP>(FirstOp)->isZero()) { addPredicateUsers(PI, I); - addAdditionalUsers(Cmp->getOperand(0), I); + addAdditionalUsers(SwappedOps ? Cmp->getOperand(1) : Cmp->getOperand(0), + I); return createConstantExpression(cast<Constant>(FirstOp)); } } @@ -1502,7 +1641,6 @@ const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I) const { // Retrieve the memory class for a given MemoryAccess. CongruenceClass *NewGVN::getMemoryClass(const MemoryAccess *MA) const { - auto *Result = MemoryAccessToClass.lookup(MA); assert(Result && "Should have found memory class"); return Result; @@ -1571,8 +1709,9 @@ bool NewGVN::isCycleFree(const Instruction *I) const { if (SCC.size() == 1) InstCycleState.insert({I, ICS_CycleFree}); else { - bool AllPhis = - llvm::all_of(SCC, [](const Value *V) { return isa<PHINode>(V); }); + bool AllPhis = llvm::all_of(SCC, [](const Value *V) { + return isa<PHINode>(V) || isCopyOfAPHI(V); + }); ICS = AllPhis ? 
ICS_CycleFree : ICS_Cycle; for (auto *Member : SCC) if (auto *MemberPhi = dyn_cast<PHINode>(Member)) @@ -1584,17 +1723,20 @@ bool NewGVN::isCycleFree(const Instruction *I) const { return true; } -// Evaluate PHI nodes symbolically, and create an expression result. -const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I) const { +// Evaluate PHI nodes symbolically and create an expression result. +const Expression * +NewGVN::performSymbolicPHIEvaluation(ArrayRef<ValPair> PHIOps, + Instruction *I, + BasicBlock *PHIBlock) const { // True if one of the incoming phi edges is a backedge. bool HasBackedge = false; // All constant tracks the state of whether all the *original* phi operands // This is really shorthand for "this phi cannot cycle due to forward // change in value of the phi is guaranteed not to later change the value of // the phi. IE it can't be v = phi(undef, v+1) - bool AllConstant = true; - auto *E = - cast<PHIExpression>(createPHIExpression(I, HasBackedge, AllConstant)); + bool OriginalOpsConstant = true; + auto *E = cast<PHIExpression>(createPHIExpression( + PHIOps, I, PHIBlock, HasBackedge, OriginalOpsConstant)); // We match the semantics of SimplifyPhiNode from InstructionSimplify here. // See if all arguments are the same. // We track if any were undef because they need special handling. @@ -1620,14 +1762,10 @@ const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I) const { deleteExpression(E); return createDeadExpression(); } - unsigned NumOps = 0; Value *AllSameValue = *(Filtered.begin()); ++Filtered.begin(); // Can't use std::equal here, sadly, because filter.begin moves. - if (llvm::all_of(Filtered, [&](Value *Arg) { - ++NumOps; - return Arg == AllSameValue; - })) { + if (llvm::all_of(Filtered, [&](Value *Arg) { return Arg == AllSameValue; })) { // In LLVM's non-standard representation of phi nodes, it's possible to have // phi nodes with cycles (IE dependent on other phis that are .... dependent // on the original phi node), especially in weird CFG's where some arguments @@ -1642,9 +1780,8 @@ const Expression *NewGVN::performSymbolicPHIEvaluation(Instruction *I) const { // multivalued phi, and we need to know if it's cycle free in order to // evaluate whether we can ignore the undef. The other parts of this are // just shortcuts. If there is no backedge, or all operands are - // constants, or all operands are ignored but the undef, it also must be - // cycle free. - if (!AllConstant && HasBackedge && NumOps > 0 && + // constants, it also must be cycle free. + if (HasBackedge && !OriginalOpsConstant && !isa<UndefValue>(AllSameValue) && !isCycleFree(I)) return E; @@ -1708,8 +1845,11 @@ NewGVN::performSymbolicAggrValueEvaluation(Instruction *I) const { return createAggregateValueExpression(I); } + const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { - auto *CI = dyn_cast<CmpInst>(I); + assert(isa<CmpInst>(I) && "Expected a cmp instruction."); + + auto *CI = cast<CmpInst>(I); // See if our operands are equal to those of a previous predicate, and if so, // if it implies true or false. auto Op0 = lookupOperandLeader(CI->getOperand(0)); @@ -1720,7 +1860,7 @@ const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { OurPredicate = CI->getSwappedPredicate(); } - // Avoid processing the same info twice + // Avoid processing the same info twice. const PredicateBase *LastPredInfo = nullptr; // See if we know something about the comparison itself, like it is the target // of an assume. 
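// A standalone sketch of the "all incoming values agree" rule used by the
// performSymbolicPHIEvaluation changes above: after filtering self-references
// and unreachable edges, a phi whose remaining operands are all the same value
// is congruent to that value. The final guard below is a simplified stand-in
// for the HasBackedge / OriginalOpsConstant / isCycleFree check in the patch,
// not the exact condition. All names here are hypothetical.
#include <optional>
#include <string>
#include <vector>

struct Incoming {
  std::string Value;
  bool IsUndef = false;
};

std::optional<std::string> simplifyPhi(const std::vector<Incoming> &Ops,
                                       bool HasBackedge, bool CycleFree) {
  std::optional<std::string> Same;
  bool SawUndef = false;
  for (const Incoming &In : Ops) {
    if (In.IsUndef) {          // undef is compatible with anything
      SawUndef = true;
      continue;
    }
    if (Same && *Same != In.Value)
      return std::nullopt;     // two distinct non-undef values: keep the phi
    Same = In.Value;
  }
  if (!Same)
    return std::string("undef"); // every operand was undef: phi is dead/undef
  // If we relied on "undef matches anything" and the phi sits on a backedge,
  // only fold when the phi provably cannot feed its own value around a cycle.
  if (SawUndef && HasBackedge && !CycleFree)
    return std::nullopt;
  return Same;
}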
@@ -1754,7 +1894,7 @@ const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { // %operands are considered users of the icmp. // *Currently* we only check one level of comparisons back, and only mark one - // level back as touched when changes appen . If you modify this code to look + // level back as touched when changes happen. If you modify this code to look // back farther through comparisons, you *must* mark the appropriate // comparisons as users in PredicateInfo.cpp, or you will cause bugs. See if // we know something just from the operands themselves @@ -1767,10 +1907,15 @@ const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { if (PI == LastPredInfo) continue; LastPredInfo = PI; - - // TODO: Along the false edge, we may know more things too, like icmp of + // In phi of ops cases, we may have predicate info that we are evaluating + // in a different context. + if (!DT->dominates(PBranch->To, getBlockForValue(I))) + continue; + // TODO: Along the false edge, we may know more things too, like + // icmp of // same operands is false. - // TODO: We only handle actual comparison conditions below, not and/or. + // TODO: We only handle actual comparison conditions below, not + // and/or. auto *BranchCond = dyn_cast<CmpInst>(PBranch->Condition); if (!BranchCond) continue; @@ -1798,7 +1943,6 @@ const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { return createConstantExpression( ConstantInt::getFalse(CI->getType())); } - } else { // Just handle the ne and eq cases, where if we have the same // operands, we may know something. @@ -1822,14 +1966,6 @@ const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { return createExpression(I); } -// Return true if V is a value that will always be available (IE can -// be placed anywhere) in the function. We don't do globals here -// because they are often worse to put in place. -// TODO: Separate cost from availability -static bool alwaysAvailable(Value *V) { - return isa<Constant>(V) || isa<Argument>(V); -} - // Substitute and symbolize the value before value numbering. const Expression * NewGVN::performSymbolicEvaluation(Value *V, @@ -1849,9 +1985,15 @@ NewGVN::performSymbolicEvaluation(Value *V, case Instruction::InsertValue: E = performSymbolicAggrValueEvaluation(I); break; - case Instruction::PHI: - E = performSymbolicPHIEvaluation(I); - break; + case Instruction::PHI: { + SmallVector<ValPair, 3> Ops; + auto *PN = cast<PHINode>(I); + for (unsigned i = 0; i < PN->getNumOperands(); ++i) + Ops.push_back({PN->getIncomingValue(i), PN->getIncomingBlock(i)}); + // Sort to ensure the invariant createPHIExpression requires is met. 
+ sortPHIOps(Ops); + E = performSymbolicPHIEvaluation(Ops, I, getBlockForValue(I)); + } break; case Instruction::Call: E = performSymbolicCallEvaluation(I); break; @@ -1861,13 +2003,13 @@ NewGVN::performSymbolicEvaluation(Value *V, case Instruction::Load: E = performSymbolicLoadEvaluation(I); break; - case Instruction::BitCast: { + case Instruction::BitCast: E = createExpression(I); - } break; + break; case Instruction::ICmp: - case Instruction::FCmp: { + case Instruction::FCmp: E = performSymbolicCmpEvaluation(I); - } break; + break; case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -2017,7 +2159,7 @@ T *NewGVN::getMinDFSOfRange(const Range &R) const { const MemoryAccess *NewGVN::getNextMemoryLeader(CongruenceClass *CC) const { // TODO: If this ends up to slow, we can maintain a next memory leader like we // do for regular leaders. - // Make sure there will be a leader to find + // Make sure there will be a leader to find. assert(!CC->definesNoMemory() && "Can't get next leader if there is none"); if (CC->getStoreCount() > 0) { if (auto *NL = dyn_cast_or_null<StoreInst>(CC->getNextLeader().first)) @@ -2194,7 +2336,7 @@ void NewGVN::moveValueToNewCongruenceClass(Instruction *I, const Expression *E, // For a given expression, mark the phi of ops instructions that could have // changed as a result. void NewGVN::markPhiOfOpsChanged(const Expression *E) { - touchAndErase(ExpressionToPhiOfOps, ExactEqualsExpression(*E)); + touchAndErase(ExpressionToPhiOfOps, E); } // Perform congruence finding on a given value numbering expression. @@ -2315,14 +2457,11 @@ void NewGVN::updateReachableEdge(BasicBlock *From, BasicBlock *To) { if (MemoryAccess *MemPhi = getMemoryAccess(To)) TouchedInstructions.set(InstrToDFSNum(MemPhi)); - auto BI = To->begin(); - while (isa<PHINode>(BI)) { - TouchedInstructions.set(InstrToDFSNum(&*BI)); - ++BI; - } - for_each_found(PHIOfOpsPHIs, To, [&](const PHINode *I) { - TouchedInstructions.set(InstrToDFSNum(I)); - }); + // FIXME: We should just add a union op on a Bitvector and + // SparseBitVector. We can do it word by word faster than we are doing it + // here. + for (auto InstNum : RevisitOnReachabilityChange[To]) + TouchedInstructions.set(InstNum); } } } @@ -2419,24 +2558,146 @@ void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) { } } +// Remove the PHI of Ops PHI for I +void NewGVN::removePhiOfOps(Instruction *I, PHINode *PHITemp) { + InstrDFS.erase(PHITemp); + // It's still a temp instruction. We keep it in the array so it gets erased. + // However, it's no longer used by I, or in the block + TempToBlock.erase(PHITemp); + RealToTemp.erase(I); + // We don't remove the users from the phi node uses. This wastes a little + // time, but such is life. We could use two sets to track which were there + // are the start of NewGVN, and which were added, but right nowt he cost of + // tracking is more than the cost of checking for more phi of ops. +} + +// Add PHI Op in BB as a PHI of operations version of ExistingValue. void NewGVN::addPhiOfOps(PHINode *Op, BasicBlock *BB, Instruction *ExistingValue) { InstrDFS[Op] = InstrToDFSNum(ExistingValue); AllTempInstructions.insert(Op); - PHIOfOpsPHIs[BB].push_back(Op); TempToBlock[Op] = BB; RealToTemp[ExistingValue] = Op; + // Add all users to phi node use, as they are now uses of the phi of ops phis + // and may themselves be phi of ops. 
+ for (auto *U : ExistingValue->users()) + if (auto *UI = dyn_cast<Instruction>(U)) + PHINodeUses.insert(UI); } static bool okayForPHIOfOps(const Instruction *I) { + if (!EnablePhiOfOps) + return false; return isa<BinaryOperator>(I) || isa<SelectInst>(I) || isa<CmpInst>(I) || isa<LoadInst>(I); } +bool NewGVN::OpIsSafeForPHIOfOpsHelper( + Value *V, const BasicBlock *PHIBlock, + SmallPtrSetImpl<const Value *> &Visited, + SmallVectorImpl<Instruction *> &Worklist) { + + if (!isa<Instruction>(V)) + return true; + auto OISIt = OpSafeForPHIOfOps.find(V); + if (OISIt != OpSafeForPHIOfOps.end()) + return OISIt->second; + + // Keep walking until we either dominate the phi block, or hit a phi, or run + // out of things to check. + if (DT->properlyDominates(getBlockForValue(V), PHIBlock)) { + OpSafeForPHIOfOps.insert({V, true}); + return true; + } + // PHI in the same block. + if (isa<PHINode>(V) && getBlockForValue(V) == PHIBlock) { + OpSafeForPHIOfOps.insert({V, false}); + return false; + } + + auto *OrigI = cast<Instruction>(V); + for (auto *Op : OrigI->operand_values()) { + if (!isa<Instruction>(Op)) + continue; + // Stop now if we find an unsafe operand. + auto OISIt = OpSafeForPHIOfOps.find(OrigI); + if (OISIt != OpSafeForPHIOfOps.end()) { + if (!OISIt->second) { + OpSafeForPHIOfOps.insert({V, false}); + return false; + } + continue; + } + if (!Visited.insert(Op).second) + continue; + Worklist.push_back(cast<Instruction>(Op)); + } + return true; +} + +// Return true if this operand will be safe to use for phi of ops. +// +// The reason some operands are unsafe is that we are not trying to recursively +// translate everything back through phi nodes. We actually expect some lookups +// of expressions to fail. In particular, a lookup where the expression cannot +// exist in the predecessor. This is true even if the expression, as shown, can +// be determined to be constant. +bool NewGVN::OpIsSafeForPHIOfOps(Value *V, const BasicBlock *PHIBlock, + SmallPtrSetImpl<const Value *> &Visited) { + SmallVector<Instruction *, 4> Worklist; + if (!OpIsSafeForPHIOfOpsHelper(V, PHIBlock, Visited, Worklist)) + return false; + while (!Worklist.empty()) { + auto *I = Worklist.pop_back_val(); + if (!OpIsSafeForPHIOfOpsHelper(I, PHIBlock, Visited, Worklist)) + return false; + } + OpSafeForPHIOfOps.insert({V, true}); + return true; +} + +// Try to find a leader for instruction TransInst, which is a phi translated +// version of something in our original program. Visited is used to ensure we +// don't infinite loop during translations of cycles. OrigInst is the +// instruction in the original program, and PredBB is the predecessor we +// translated it through. +Value *NewGVN::findLeaderForInst(Instruction *TransInst, + SmallPtrSetImpl<Value *> &Visited, + MemoryAccess *MemAccess, Instruction *OrigInst, + BasicBlock *PredBB) { + unsigned IDFSNum = InstrToDFSNum(OrigInst); + // Make sure it's marked as a temporary instruction. + AllTempInstructions.insert(TransInst); + // and make sure anything that tries to add it's DFS number is + // redirected to the instruction we are making a phi of ops + // for. 
+ TempToBlock.insert({TransInst, PredBB}); + InstrDFS.insert({TransInst, IDFSNum}); + + const Expression *E = performSymbolicEvaluation(TransInst, Visited); + InstrDFS.erase(TransInst); + AllTempInstructions.erase(TransInst); + TempToBlock.erase(TransInst); + if (MemAccess) + TempToMemory.erase(TransInst); + if (!E) + return nullptr; + auto *FoundVal = findPHIOfOpsLeader(E, OrigInst, PredBB); + if (!FoundVal) { + ExpressionToPhiOfOps[E].insert(OrigInst); + DEBUG(dbgs() << "Cannot find phi of ops operand for " << *TransInst + << " in block " << getBlockName(PredBB) << "\n"); + return nullptr; + } + if (auto *SI = dyn_cast<StoreInst>(FoundVal)) + FoundVal = SI->getValueOperand(); + return FoundVal; +} + // When we see an instruction that is an op of phis, generate the equivalent phi // of ops form. const Expression * -NewGVN::makePossiblePhiOfOps(Instruction *I, +NewGVN::makePossiblePHIOfOps(Instruction *I, SmallPtrSetImpl<Value *> &Visited) { if (!okayForPHIOfOps(I)) return nullptr; @@ -2450,7 +2711,6 @@ NewGVN::makePossiblePhiOfOps(Instruction *I, if (!isCycleFree(I)) return nullptr; - unsigned IDFSNum = InstrToDFSNum(I); SmallPtrSet<const Value *, 8> ProcessedPHIs; // TODO: We don't do phi translation on memory accesses because it's // complicated. For a load, we'd need to be able to simulate a new memoryuse, @@ -2463,81 +2723,94 @@ NewGVN::makePossiblePhiOfOps(Instruction *I, MemAccess->getDefiningAccess()->getBlock() == I->getParent()) return nullptr; + SmallPtrSet<const Value *, 10> VisitedOps; // Convert op of phis to phi of ops - for (auto &Op : I->operands()) { - // TODO: We can't handle expressions that must be recursively translated - // IE - // a = phi (b, c) - // f = use a - // g = f + phi of something - // To properly make a phi of ops for g, we'd have to properly translate and - // use the instruction for f. We should add this by splitting out the - // instruction creation we do below. - if (isa<Instruction>(Op) && PHINodeUses.count(cast<Instruction>(Op))) - return nullptr; - if (!isa<PHINode>(Op)) - continue; + for (auto *Op : I->operand_values()) { + if (!isa<PHINode>(Op)) { + auto *ValuePHI = RealToTemp.lookup(Op); + if (!ValuePHI) + continue; + DEBUG(dbgs() << "Found possible dependent phi of ops\n"); + Op = ValuePHI; + } auto *OpPHI = cast<PHINode>(Op); // No point in doing this for one-operand phis. if (OpPHI->getNumOperands() == 1) continue; if (!DebugCounter::shouldExecute(PHIOfOpsCounter)) return nullptr; - SmallVector<std::pair<Value *, BasicBlock *>, 4> Ops; + SmallVector<ValPair, 4> Ops; + SmallPtrSet<Value *, 4> Deps; auto *PHIBlock = getBlockForValue(OpPHI); - for (auto PredBB : OpPHI->blocks()) { + RevisitOnReachabilityChange[PHIBlock].reset(InstrToDFSNum(I)); + for (unsigned PredNum = 0; PredNum < OpPHI->getNumOperands(); ++PredNum) { + auto *PredBB = OpPHI->getIncomingBlock(PredNum); Value *FoundVal = nullptr; // We could just skip unreachable edges entirely but it's tricky to do // with rewriting existing phi nodes. if (ReachableEdges.count({PredBB, PHIBlock})) { - // Clone the instruction, create an expression from it, and see if we - // have a leader. + // Clone the instruction, create an expression from it that is + // translated back into the predecessor, and see if we have a leader. 
Instruction *ValueOp = I->clone(); if (MemAccess) TempToMemory.insert({ValueOp, MemAccess}); - + bool SafeForPHIOfOps = true; + VisitedOps.clear(); for (auto &Op : ValueOp->operands()) { - Op = Op->DoPHITranslation(PHIBlock, PredBB); - // When this operand changes, it could change whether there is a - // leader for us or not. - addAdditionalUsers(Op, I); + auto *OrigOp = &*Op; + // When these operand changes, it could change whether there is a + // leader for us or not, so we have to add additional users. + if (isa<PHINode>(Op)) { + Op = Op->DoPHITranslation(PHIBlock, PredBB); + if (Op != OrigOp && Op != I) + Deps.insert(Op); + } else if (auto *ValuePHI = RealToTemp.lookup(Op)) { + if (getBlockForValue(ValuePHI) == PHIBlock) + Op = ValuePHI->getIncomingValueForBlock(PredBB); + } + // If we phi-translated the op, it must be safe. + SafeForPHIOfOps = + SafeForPHIOfOps && + (Op != OrigOp || OpIsSafeForPHIOfOps(Op, PHIBlock, VisitedOps)); } - // Make sure it's marked as a temporary instruction. - AllTempInstructions.insert(ValueOp); - // and make sure anything that tries to add it's DFS number is - // redirected to the instruction we are making a phi of ops - // for. - InstrDFS.insert({ValueOp, IDFSNum}); - const Expression *E = performSymbolicEvaluation(ValueOp, Visited); - InstrDFS.erase(ValueOp); - AllTempInstructions.erase(ValueOp); + // FIXME: For those things that are not safe we could generate + // expressions all the way down, and see if this comes out to a + // constant. For anything where that is true, and unsafe, we should + // have made a phi-of-ops (or value numbered it equivalent to something) + // for the pieces already. + FoundVal = !SafeForPHIOfOps ? nullptr + : findLeaderForInst(ValueOp, Visited, + MemAccess, I, PredBB); ValueOp->deleteValue(); - if (MemAccess) - TempToMemory.erase(ValueOp); - if (!E) - return nullptr; - FoundVal = findPhiOfOpsLeader(E, PredBB); - if (!FoundVal) { - ExpressionToPhiOfOps[E].insert(I); + if (!FoundVal) return nullptr; - } - if (auto *SI = dyn_cast<StoreInst>(FoundVal)) - FoundVal = SI->getValueOperand(); } else { DEBUG(dbgs() << "Skipping phi of ops operand for incoming block " << getBlockName(PredBB) << " because the block is unreachable\n"); FoundVal = UndefValue::get(I->getType()); + RevisitOnReachabilityChange[PHIBlock].set(InstrToDFSNum(I)); } Ops.push_back({FoundVal, PredBB}); DEBUG(dbgs() << "Found phi of ops operand " << *FoundVal << " in " << getBlockName(PredBB) << "\n"); } + for (auto Dep : Deps) + addAdditionalUsers(Dep, I); + sortPHIOps(Ops); + auto *E = performSymbolicPHIEvaluation(Ops, I, PHIBlock); + if (isa<ConstantExpression>(E) || isa<VariableExpression>(E)) { + DEBUG(dbgs() + << "Not creating real PHI of ops because it simplified to existing " + "value or constant\n"); + return E; + } auto *ValuePHI = RealToTemp.lookup(I); bool NewPHI = false; if (!ValuePHI) { - ValuePHI = PHINode::Create(I->getType(), OpPHI->getNumOperands()); + ValuePHI = + PHINode::Create(I->getType(), OpPHI->getNumOperands(), "phiofops"); addPhiOfOps(ValuePHI, PHIBlock, I); NewPHI = true; NumGVNPHIOfOpsCreated++; @@ -2553,10 +2826,11 @@ NewGVN::makePossiblePhiOfOps(Instruction *I, ++i; } } - + RevisitOnReachabilityChange[PHIBlock].set(InstrToDFSNum(I)); DEBUG(dbgs() << "Created phi of ops " << *ValuePHI << " for " << *I << "\n"); - return performSymbolicEvaluation(ValuePHI, Visited); + + return E; } return nullptr; } @@ -2602,8 +2876,11 @@ void NewGVN::initializeCongruenceClasses(Function &F) { if (MD && isa<StoreInst>(MD->getMemoryInst())) 
TOPClass->incStoreCount(); } + + // FIXME: This is trying to discover which instructions are uses of phi + // nodes. We should move this into one of the myriad of places that walk + // all the operands already. for (auto &I : *BB) { - // TODO: Move to helper if (isa<PHINode>(&I)) for (auto *U : I.users()) if (auto *UInst = dyn_cast<Instruction>(U)) @@ -2661,7 +2938,8 @@ void NewGVN::cleanupTables() { ExpressionToPhiOfOps.clear(); TempToBlock.clear(); TempToMemory.clear(); - PHIOfOpsPHIs.clear(); + PHINodeUses.clear(); + OpSafeForPHIOfOps.clear(); ReachableBlocks.clear(); ReachableEdges.clear(); #ifndef NDEBUG @@ -2675,6 +2953,7 @@ void NewGVN::cleanupTables() { MemoryAccessToClass.clear(); PredicateToUsers.clear(); MemoryToUsers.clear(); + RevisitOnReachabilityChange.clear(); } // Assign local DFS number mapping to instructions, and leave space for Value @@ -2698,6 +2977,8 @@ std::pair<unsigned, unsigned> NewGVN::assignDFSNumbers(BasicBlock *B, markInstructionForDeletion(&I); continue; } + if (isa<PHINode>(&I)) + RevisitOnReachabilityChange[B].set(End); InstrDFS[&I] = End++; DFSToInstr.emplace_back(&I); } @@ -2719,6 +3000,7 @@ void NewGVN::updateProcessedCount(const Value *V) { } #endif } + // Evaluate MemoryPhi nodes symbolically, just like PHI nodes void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) { // If all the arguments are the same, the MemoryPhi has the same value as the @@ -2787,11 +3069,15 @@ void NewGVN::valueNumberInstruction(Instruction *I) { // Make a phi of ops if necessary if (Symbolized && !isa<ConstantExpression>(Symbolized) && !isa<VariableExpression>(Symbolized) && PHINodeUses.count(I)) { - auto *PHIE = makePossiblePhiOfOps(I, Visited); - if (PHIE) + auto *PHIE = makePossiblePHIOfOps(I, Visited); + // If we created a phi of ops, use it. + // If we couldn't create one, make sure we don't leave one lying around + if (PHIE) { Symbolized = PHIE; + } else if (auto *Op = RealToTemp.lookup(I)) { + removePhiOfOps(I, Op); + } } - } else { // Mark the instruction as unused so we don't value number it again. InstrDFS[I] = 0; @@ -2905,7 +3191,7 @@ void NewGVN::verifyMemoryCongruency() const { // so we don't process them. if (auto *MemPHI = dyn_cast<MemoryPhi>(Pair.first)) { for (auto &U : MemPHI->incoming_values()) { - if (Instruction *I = dyn_cast<Instruction>(U.get())) { + if (auto *I = dyn_cast<Instruction>(&*U)) { if (!isInstructionTriviallyDead(I)) return true; } @@ -3200,11 +3486,13 @@ struct NewGVN::ValueDFS { int DFSIn = 0; int DFSOut = 0; int LocalNum = 0; + // Only one of Def and U will be set. // The bool in the Def tells us whether the Def is the stored value of a // store. PointerIntPair<Value *, 1, bool> Def; Use *U = nullptr; + bool operator<(const ValueDFS &Other) const { // It's not enough that any given field be less than - we have sets // of fields that need to be evaluated together to give a proper ordering. 
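// A standalone sketch of the (DFSIn, DFSOut) trick that the ValueDFS ordering
// above is built on: number the dominator tree with a DFS entry/exit pair per
// block, and "definition D dominates use U" becomes a constant-time interval
// containment test. The concrete numbers and names below are hypothetical.
#include <cassert>

struct DomScope {
  int DFSIn;  // DFS entry number of the defining block
  int DFSOut; // DFS exit number of the defining block
};

// True if the definition's dominator-tree interval encloses the use's
// interval, i.e. the definition is in scope (dominates) at the use.
bool isInScope(const DomScope &Def, int UseDFSIn, int UseDFSOut) {
  return Def.DFSIn <= UseDFSIn && Def.DFSOut >= UseDFSOut;
}

int main() {
  // entry [1,8] dominates then [2,3], else [4,5] and merge [6,7];
  // then [2,3] does not dominate merge [6,7].
  DomScope Entry{1, 8}, Then{2, 3};
  assert(isInScope(Entry, 6, 7));
  assert(!isInScope(Then, 6, 7));
  return 0;
}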
@@ -3439,7 +3727,6 @@ void NewGVN::markInstructionForDeletion(Instruction *I) { } void NewGVN::replaceInstruction(Instruction *I, Value *V) { - DEBUG(dbgs() << "Replacing " << *I << " with " << *V << "\n"); patchAndReplaceAllUsesWith(I, V); // We save the actual erasing to avoid invalidating memory @@ -3460,7 +3747,9 @@ public: ValueStack.emplace_back(V); DFSStack.emplace_back(DFSIn, DFSOut); } + bool empty() const { return DFSStack.empty(); } + bool isInScope(int DFSIn, int DFSOut) const { if (empty()) return false; @@ -3484,19 +3773,33 @@ private: SmallVector<Value *, 8> ValueStack; SmallVector<std::pair<int, int>, 8> DFSStack; }; + +} // end anonymous namespace + +// Given an expression, get the congruence class for it. +CongruenceClass *NewGVN::getClassForExpression(const Expression *E) const { + if (auto *VE = dyn_cast<VariableExpression>(E)) + return ValueToClass.lookup(VE->getVariableValue()); + else if (isa<DeadExpression>(E)) + return TOPClass; + return ExpressionToClass.lookup(E); } // Given a value and a basic block we are trying to see if it is available in, // see if the value has a leader available in that block. -Value *NewGVN::findPhiOfOpsLeader(const Expression *E, +Value *NewGVN::findPHIOfOpsLeader(const Expression *E, + const Instruction *OrigInst, const BasicBlock *BB) const { // It would already be constant if we could make it constant if (auto *CE = dyn_cast<ConstantExpression>(E)) return CE->getConstantValue(); - if (auto *VE = dyn_cast<VariableExpression>(E)) - return VE->getVariableValue(); + if (auto *VE = dyn_cast<VariableExpression>(E)) { + auto *V = VE->getVariableValue(); + if (alwaysAvailable(V) || DT->dominates(getBlockForValue(V), BB)) + return VE->getVariableValue(); + } - auto *CC = ExpressionToClass.lookup(E); + auto *CC = getClassForExpression(E); if (!CC) return nullptr; if (alwaysAvailable(CC->getLeader())) @@ -3504,15 +3807,13 @@ Value *NewGVN::findPhiOfOpsLeader(const Expression *E, for (auto Member : *CC) { auto *MemberInst = dyn_cast<Instruction>(Member); + if (MemberInst == OrigInst) + continue; // Anything that isn't an instruction is always available. if (!MemberInst) return Member; - // If we are looking for something in the same block as the member, it must - // be a leader because this function is looking for operands for a phi node. - if (MemberInst->getParent() == BB || - DT->dominates(MemberInst->getParent(), BB)) { + if (DT->dominates(getBlockForValue(MemberInst), BB)) return Member; - } } return nullptr; } @@ -3549,36 +3850,39 @@ bool NewGVN::eliminateInstructions(Function &F) { // Go through all of our phi nodes, and kill the arguments associated with // unreachable edges. 
- auto ReplaceUnreachablePHIArgs = [&](PHINode &PHI, BasicBlock *BB) { - for (auto &Operand : PHI.incoming_values()) - if (!ReachableEdges.count({PHI.getIncomingBlock(Operand), BB})) { + auto ReplaceUnreachablePHIArgs = [&](PHINode *PHI, BasicBlock *BB) { + for (auto &Operand : PHI->incoming_values()) + if (!ReachableEdges.count({PHI->getIncomingBlock(Operand), BB})) { DEBUG(dbgs() << "Replacing incoming value of " << PHI << " for block " - << getBlockName(PHI.getIncomingBlock(Operand)) + << getBlockName(PHI->getIncomingBlock(Operand)) << " with undef due to it being unreachable\n"); - Operand.set(UndefValue::get(PHI.getType())); + Operand.set(UndefValue::get(PHI->getType())); } }; - SmallPtrSet<BasicBlock *, 8> BlocksWithPhis; - for (auto &B : F) - if ((!B.empty() && isa<PHINode>(*B.begin())) || - (PHIOfOpsPHIs.find(&B) != PHIOfOpsPHIs.end())) - BlocksWithPhis.insert(&B); + // Replace unreachable phi arguments. + // At this point, RevisitOnReachabilityChange only contains: + // + // 1. PHIs + // 2. Temporaries that will convert to PHIs + // 3. Operations that are affected by an unreachable edge but do not fit into + // 1 or 2 (rare). + // So it is a slight overshoot of what we want. We could make it exact by + // using two SparseBitVectors per block. DenseMap<const BasicBlock *, unsigned> ReachablePredCount; - for (auto KV : ReachableEdges) + for (auto &KV : ReachableEdges) ReachablePredCount[KV.getEnd()]++; - for (auto *BB : BlocksWithPhis) - // TODO: It would be faster to use getNumIncomingBlocks() on a phi node in - // the block and subtract the pred count, but it's more complicated. - if (ReachablePredCount.lookup(BB) != - unsigned(std::distance(pred_begin(BB), pred_end(BB)))) { - for (auto II = BB->begin(); isa<PHINode>(II); ++II) { - auto &PHI = cast<PHINode>(*II); + for (auto &BBPair : RevisitOnReachabilityChange) { + for (auto InstNum : BBPair.second) { + auto *Inst = InstrFromDFSNum(InstNum); + auto *PHI = dyn_cast<PHINode>(Inst); + PHI = PHI ? PHI : dyn_cast_or_null<PHINode>(RealToTemp.lookup(Inst)); + if (!PHI) + continue; + auto *BB = BBPair.first; + if (ReachablePredCount.lookup(BB) != PHI->getNumIncomingValues()) ReplaceUnreachablePHIArgs(PHI, BB); - } - for_each_found(PHIOfOpsPHIs, BB, [&](PHINode *PHI) { - ReplaceUnreachablePHIArgs(*PHI, BB); - }); } + } // Map to store the use counts DenseMap<const Value *, unsigned int> UseCounts; @@ -3631,7 +3935,7 @@ bool NewGVN::eliminateInstructions(Function &F) { CC->swap(MembersLeft); } else { // If this is a singleton, we can skip it. - if (CC->size() != 1 || RealToTemp.lookup(Leader)) { + if (CC->size() != 1 || RealToTemp.count(Leader)) { // This is a stack because equality replacement/etc may place // constants in the middle of the member list, and we want to use // those constant values in preference to the current leader, over @@ -3873,12 +4177,16 @@ bool NewGVN::shouldSwapOperands(const Value *A, const Value *B) const { } namespace { + class NewGVNLegacyPass : public FunctionPass { public: - static char ID; // Pass identification, replacement for typeid. + // Pass identification, replacement for typeid. 
+ static char ID; + NewGVNLegacyPass() : FunctionPass(ID) { initializeNewGVNLegacyPassPass(*PassRegistry::getPassRegistry()); } + bool runOnFunction(Function &F) override; private: @@ -3892,7 +4200,8 @@ private: AU.addPreserved<GlobalsAAWrapperPass>(); } }; -} // namespace + +} // end anonymous namespace bool NewGVNLegacyPass::runOnFunction(Function &F) { if (skipFunction(F)) @@ -3906,6 +4215,8 @@ bool NewGVNLegacyPass::runOnFunction(Function &F) { .runGVN(); } +char NewGVNLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(NewGVNLegacyPass, "newgvn", "Global Value Numbering", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) @@ -3917,8 +4228,6 @@ INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) INITIALIZE_PASS_END(NewGVNLegacyPass, "newgvn", "Global Value Numbering", false, false) -char NewGVNLegacyPass::ID = 0; - // createGVNPass - The public interface to this file. FunctionPass *llvm::createNewGVNPass() { return new NewGVNLegacyPass(); } diff --git a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp index 1bfecea2f61e..1748815c5941 100644 --- a/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp +++ b/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp @@ -26,16 +26,13 @@ using namespace llvm; static bool optimizeSQRT(CallInst *Call, Function *CalledFunc, - BasicBlock &CurrBB, Function::iterator &BB) { + BasicBlock &CurrBB, Function::iterator &BB, + const TargetTransformInfo *TTI) { // There is no need to change the IR, since backend will emit sqrt // instruction if the call has already been marked read-only. if (Call->onlyReadsMemory()) return false; - // The call must have the expected result type. - if (!Call->getType()->isFloatingPointTy()) - return false; - // Do the following transformation: // // (before) @@ -43,7 +40,7 @@ static bool optimizeSQRT(CallInst *Call, Function *CalledFunc, // // (after) // v0 = sqrt_noreadmem(src) # native sqrt instruction. - // if (v0 is a NaN) + // [if (v0 is a NaN) || if (src < 0)] // v1 = sqrt(src) # library call. // dst = phi(v0, v1) // @@ -52,7 +49,8 @@ static bool optimizeSQRT(CallInst *Call, Function *CalledFunc, // Create phi and replace all uses. BasicBlock *JoinBB = llvm::SplitBlock(&CurrBB, Call->getNextNode()); IRBuilder<> Builder(JoinBB, JoinBB->begin()); - PHINode *Phi = Builder.CreatePHI(Call->getType(), 2); + Type *Ty = Call->getType(); + PHINode *Phi = Builder.CreatePHI(Ty, 2); Call->replaceAllUsesWith(Phi); // Create basic block LibCallBB and insert a call to library function sqrt. @@ -69,7 +67,10 @@ static bool optimizeSQRT(CallInst *Call, Function *CalledFunc, Call->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone); CurrBB.getTerminator()->eraseFromParent(); Builder.SetInsertPoint(&CurrBB); - Value *FCmp = Builder.CreateFCmpOEQ(Call, Call); + Value *FCmp = TTI->isFCmpOrdCheaperThanFCmpZero(Ty) + ? Builder.CreateFCmpORD(Call, Call) + : Builder.CreateFCmpOGE(Call->getOperand(0), + ConstantFP::get(Ty, 0.0)); Builder.CreateCondBr(FCmp, JoinBB, LibCallBB); // Add phi operands. @@ -96,18 +97,21 @@ static bool runPartiallyInlineLibCalls(Function &F, TargetLibraryInfo *TLI, if (!Call || !(CalledFunc = Call->getCalledFunction())) continue; + if (Call->isNoBuiltin()) + continue; + // Skip if function either has local linkage or is not a known library // function. 
LibFunc LF; - if (CalledFunc->hasLocalLinkage() || !CalledFunc->hasName() || - !TLI->getLibFunc(CalledFunc->getName(), LF)) + if (CalledFunc->hasLocalLinkage() || + !TLI->getLibFunc(*CalledFunc, LF) || !TLI->has(LF)) continue; switch (LF) { case LibFunc_sqrtf: case LibFunc_sqrt: if (TTI->haveFastSqrt(Call->getType()) && - optimizeSQRT(Call, CalledFunc, *CurrBB, BB)) + optimizeSQRT(Call, CalledFunc, *CurrBB, BB, TTI)) break; continue; default: diff --git a/lib/Transforms/Scalar/PlaceSafepoints.cpp b/lib/Transforms/Scalar/PlaceSafepoints.cpp index e47b636348e3..2d0cb6fbf211 100644 --- a/lib/Transforms/Scalar/PlaceSafepoints.cpp +++ b/lib/Transforms/Scalar/PlaceSafepoints.cpp @@ -54,6 +54,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/CallSite.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IntrinsicInst.h" @@ -113,6 +114,7 @@ struct PlaceBackedgeSafepointsImpl : public FunctionPass { ScalarEvolution *SE = nullptr; DominatorTree *DT = nullptr; LoopInfo *LI = nullptr; + TargetLibraryInfo *TLI = nullptr; PlaceBackedgeSafepointsImpl(bool CallSafepoints = false) : FunctionPass(ID), CallSafepointsEnabled(CallSafepoints) { @@ -131,6 +133,7 @@ struct PlaceBackedgeSafepointsImpl : public FunctionPass { SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); for (Loop *I : *LI) { runOnLoopAndSubLoops(I); } @@ -141,6 +144,7 @@ struct PlaceBackedgeSafepointsImpl : public FunctionPass { AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<ScalarEvolutionWrapperPass>(); AU.addRequired<LoopInfoWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); // We no longer modify the IR at all in this pass. Thus all // analysis are preserved. AU.setPreservesAll(); @@ -165,6 +169,7 @@ struct PlaceSafepoints : public FunctionPass { // We modify the graph wholesale (inlining, block insertion, etc). We // preserve nothing at the moment. We could potentially preserve dom tree // if that was worth doing + AU.addRequired<TargetLibraryInfoWrapperPass>(); } }; } @@ -174,10 +179,11 @@ struct PlaceSafepoints : public FunctionPass { // callers job. static void InsertSafepointPoll(Instruction *InsertBefore, - std::vector<CallSite> &ParsePointsNeeded /*rval*/); + std::vector<CallSite> &ParsePointsNeeded /*rval*/, + const TargetLibraryInfo &TLI); -static bool needsStatepoint(const CallSite &CS) { - if (callsGCLeafFunction(CS)) +static bool needsStatepoint(const CallSite &CS, const TargetLibraryInfo &TLI) { + if (callsGCLeafFunction(CS, TLI)) return false; if (CS.isCall()) { CallInst *call = cast<CallInst>(CS.getInstruction()); @@ -194,7 +200,8 @@ static bool needsStatepoint(const CallSite &CS) { /// answer; i.e. false is always valid. static bool containsUnconditionalCallSafepoint(Loop *L, BasicBlock *Header, BasicBlock *Pred, - DominatorTree &DT) { + DominatorTree &DT, + const TargetLibraryInfo &TLI) { // In general, we're looking for any cut of the graph which ensures // there's a call safepoint along every edge between Header and Pred. // For the moment, we look only for the 'cuts' that consist of a single call @@ -217,7 +224,7 @@ static bool containsUnconditionalCallSafepoint(Loop *L, BasicBlock *Header, // unconditional poll. 
In practice, this is only a theoretical concern // since we don't have any methods with conditional-only safepoint // polls. - if (needsStatepoint(CS)) + if (needsStatepoint(CS, TLI)) return true; } @@ -321,7 +328,7 @@ bool PlaceBackedgeSafepointsImpl::runOnLoop(Loop *L) { continue; } if (CallSafepointsEnabled && - containsUnconditionalCallSafepoint(L, Header, Pred, *DT)) { + containsUnconditionalCallSafepoint(L, Header, Pred, *DT, *TLI)) { // Note: This is only semantically legal since we won't do any further // IPO or inlining before the actual call insertion.. If we hadn't, we // might latter loose this call safepoint. @@ -472,6 +479,9 @@ bool PlaceSafepoints::runOnFunction(Function &F) { if (!shouldRewriteFunction(F)) return false; + const TargetLibraryInfo &TLI = + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + bool Modified = false; // In various bits below, we rely on the fact that uses are reachable from @@ -578,7 +588,7 @@ bool PlaceSafepoints::runOnFunction(Function &F) { // safepoint polls themselves. for (Instruction *PollLocation : PollsNeeded) { std::vector<CallSite> RuntimeCalls; - InsertSafepointPoll(PollLocation, RuntimeCalls); + InsertSafepointPoll(PollLocation, RuntimeCalls, TLI); ParsePointNeeded.insert(ParsePointNeeded.end(), RuntimeCalls.begin(), RuntimeCalls.end()); } @@ -610,7 +620,8 @@ INITIALIZE_PASS_END(PlaceSafepoints, "place-safepoints", "Place Safepoints", static void InsertSafepointPoll(Instruction *InsertBefore, - std::vector<CallSite> &ParsePointsNeeded /*rval*/) { + std::vector<CallSite> &ParsePointsNeeded /*rval*/, + const TargetLibraryInfo &TLI) { BasicBlock *OrigBB = InsertBefore->getParent(); Module *M = InsertBefore->getModule(); assert(M && "must be part of a module"); @@ -669,7 +680,7 @@ InsertSafepointPoll(Instruction *InsertBefore, assert(ParsePointsNeeded.empty()); for (auto *CI : Calls) { // No safepoint needed or wanted - if (!needsStatepoint(CI)) + if (!needsStatepoint(CI, TLI)) continue; // These are likely runtime calls. 
Should we assert that via calling diff --git a/lib/Transforms/Scalar/Reassociate.cpp b/lib/Transforms/Scalar/Reassociate.cpp index e235e5eb1a06..88dcaf0f8a36 100644 --- a/lib/Transforms/Scalar/Reassociate.cpp +++ b/lib/Transforms/Scalar/Reassociate.cpp @@ -21,28 +21,45 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/Reassociate.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/PostOrderIterator.h" -#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include <algorithm> +#include <cassert> +#include <utility> + using namespace llvm; using namespace reassociate; @@ -54,7 +71,6 @@ STATISTIC(NumFactor , "Number of multiplies factored"); #ifndef NDEBUG /// Print out the expression identified in the Ops list. -/// static void PrintOps(Instruction *I, const SmallVectorImpl<ValueEntry> &Ops) { Module *M = I->getModule(); dbgs() << Instruction::getOpcodeName(I->getOpcode()) << " " @@ -128,38 +144,37 @@ XorOpnd::XorOpnd(Value *V) { /// Return true if V is an instruction of the specified opcode and if it /// only has one use. static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode) { - if (V->hasOneUse() && isa<Instruction>(V) && - cast<Instruction>(V)->getOpcode() == Opcode && - (!isa<FPMathOperator>(V) || - cast<Instruction>(V)->hasUnsafeAlgebra())) - return cast<BinaryOperator>(V); + auto *I = dyn_cast<Instruction>(V); + if (I && I->hasOneUse() && I->getOpcode() == Opcode) + if (!isa<FPMathOperator>(I) || I->isFast()) + return cast<BinaryOperator>(I); return nullptr; } static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode1, unsigned Opcode2) { - if (V->hasOneUse() && isa<Instruction>(V) && - (cast<Instruction>(V)->getOpcode() == Opcode1 || - cast<Instruction>(V)->getOpcode() == Opcode2) && - (!isa<FPMathOperator>(V) || - cast<Instruction>(V)->hasUnsafeAlgebra())) - return cast<BinaryOperator>(V); + auto *I = dyn_cast<Instruction>(V); + if (I && I->hasOneUse() && + (I->getOpcode() == Opcode1 || I->getOpcode() == Opcode2)) + if (!isa<FPMathOperator>(I) || I->isFast()) + return cast<BinaryOperator>(I); return nullptr; } void ReassociatePass::BuildRankMap(Function &F, ReversePostOrderTraversal<Function*> &RPOT) { - unsigned i = 2; + unsigned Rank = 2; // Assign distinct ranks to function arguments. 
- for (Function::arg_iterator I = F.arg_begin(), E = F.arg_end(); I != E; ++I) { - ValueRankMap[&*I] = ++i; - DEBUG(dbgs() << "Calculated Rank[" << I->getName() << "] = " << i << "\n"); + for (auto &Arg : F.args()) { + ValueRankMap[&Arg] = ++Rank; + DEBUG(dbgs() << "Calculated Rank[" << Arg.getName() << "] = " << Rank + << "\n"); } // Traverse basic blocks in ReversePostOrder for (BasicBlock *BB : RPOT) { - unsigned BBRank = RankMap[BB] = ++i << 16; + unsigned BBRank = RankMap[BB] = ++Rank << 16; // Walk the basic block, adding precomputed ranks for any instructions that // we cannot move. This ensures that the ranks for these instructions are @@ -207,13 +222,9 @@ void ReassociatePass::canonicalizeOperands(Instruction *I) { Value *LHS = I->getOperand(0); Value *RHS = I->getOperand(1); - unsigned LHSRank = getRank(LHS); - unsigned RHSRank = getRank(RHS); - - if (isa<Constant>(RHS)) + if (LHS == RHS || isa<Constant>(RHS)) return; - - if (isa<Constant>(LHS) || RHSRank < LHSRank) + if (isa<Constant>(LHS) || getRank(RHS) < getRank(LHS)) cast<BinaryOperator>(I)->swapOperands(); } @@ -357,7 +368,7 @@ static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) { } } -typedef std::pair<Value*, APInt> RepeatedValue; +using RepeatedValue = std::pair<Value*, APInt>; /// Given an associative binary expression, return the leaf /// nodes in Ops along with their weights (how many times the leaf occurs). The @@ -432,7 +443,6 @@ typedef std::pair<Value*, APInt> RepeatedValue; /// that have all uses inside the expression (i.e. only used by non-leaf nodes /// of the expression) if it can turn them into binary operators of the right /// type and thus make the expression bigger. - static bool LinearizeExprTree(BinaryOperator *I, SmallVectorImpl<RepeatedValue> &Ops) { DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); @@ -470,12 +480,12 @@ static bool LinearizeExprTree(BinaryOperator *I, // Leaves - Keeps track of the set of putative leaves as well as the number of // paths to each leaf seen so far. - typedef DenseMap<Value*, APInt> LeafMap; + using LeafMap = DenseMap<Value *, APInt>; LeafMap Leaves; // Leaf -> Total weight so far. - SmallVector<Value*, 8> LeafOrder; // Ensure deterministic leaf output order. + SmallVector<Value *, 8> LeafOrder; // Ensure deterministic leaf output order. #ifndef NDEBUG - SmallPtrSet<Value*, 8> Visited; // For sanity checking the iteration scheme. + SmallPtrSet<Value *, 8> Visited; // For sanity checking the iteration scheme. #endif while (!Worklist.empty()) { std::pair<BinaryOperator*, APInt> P = Worklist.pop_back_val(); @@ -554,7 +564,7 @@ static bool LinearizeExprTree(BinaryOperator *I, assert((!isa<Instruction>(Op) || cast<Instruction>(Op)->getOpcode() != Opcode || (isa<FPMathOperator>(Op) && - !cast<Instruction>(Op)->hasUnsafeAlgebra())) && + !cast<Instruction>(Op)->isFast())) && "Should have been handled above!"); assert(Op->hasOneUse() && "Has uses outside the expression tree!"); @@ -773,7 +783,7 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, break; ExpressionChanged->moveBefore(I); ExpressionChanged = cast<BinaryOperator>(*ExpressionChanged->user_begin()); - } while (1); + } while (true); // Throw away any left over nodes from the original expression. for (unsigned i = 0, e = NodesToRewrite.size(); i != e; ++i) @@ -789,13 +799,9 @@ void ReassociatePass::RewriteExprTree(BinaryOperator *I, /// additional opportunities have been exposed. 
static Value *NegateValue(Value *V, Instruction *BI, SetVector<AssertingVH<Instruction>> &ToRedo) { - if (Constant *C = dyn_cast<Constant>(V)) { - if (C->getType()->isFPOrFPVectorTy()) { - return ConstantExpr::getFNeg(C); - } - return ConstantExpr::getNeg(C); - } - + if (auto *C = dyn_cast<Constant>(V)) + return C->getType()->isFPOrFPVectorTy() ? ConstantExpr::getFNeg(C) : + ConstantExpr::getNeg(C); // We are trying to expose opportunity for reassociation. One of the things // that we want to do to achieve this is to push a negation as deep into an @@ -913,7 +919,6 @@ BreakUpSubtract(Instruction *Sub, SetVector<AssertingVH<Instruction>> &ToRedo) { // // Calculate the negative value of Operand 1 of the sub instruction, // and set it as the RHS of the add instruction we just made. - // Value *NegVal = NegateValue(Sub->getOperand(1), Sub, ToRedo); BinaryOperator *New = CreateAdd(Sub->getOperand(0), NegVal, "", Sub, Sub); Sub->setOperand(0, Constant::getNullValue(Sub->getType())); // Drop use of op. @@ -990,7 +995,7 @@ static Value *EmitAddTreeOfValues(Instruction *I, Value *V1 = Ops.back(); Ops.pop_back(); Value *V2 = EmitAddTreeOfValues(I, Ops); - return CreateAdd(V2, V1, "tmp", I, I); + return CreateAdd(V2, V1, "reass.add", I, I); } /// If V is an expression tree that is a multiplication sequence, @@ -1157,7 +1162,6 @@ static Value *createAndInstr(Instruction *InsertBefore, Value *Opnd, // If it was successful, true is returned, and the "R" and "C" is returned // via "Res" and "ConstOpnd", respectively; otherwise, false is returned, // and both "Res" and "ConstOpnd" remain unchanged. -// bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, APInt &ConstOpnd, Value *&Res) { // Xor-Rule 1: (x | c1) ^ c2 = (x | c1) ^ (c1 ^ c1) ^ c2 @@ -1183,7 +1187,6 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, RedoInsts.insert(T); return true; } - // Helper function of OptimizeXor(). It tries to simplify // "Opnd1 ^ Opnd2 ^ ConstOpnd" into "R ^ C", where C would be 0, and R is a @@ -1230,7 +1233,6 @@ bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1, Res = createAndInstr(I, X, C3); ConstOpnd ^= C1; - } else if (Opnd1->isOrExpr()) { // Xor-Rule 3: (x | c1) ^ (x | c2) = (x & c3) ^ c3 where c3 = c1 ^ c2 // @@ -1349,7 +1351,6 @@ Value *ReassociatePass::OptimizeXor(Instruction *I, // step 3.2: When previous and current operands share the same symbolic // value, try to simplify "PrevOpnd ^ CurrOpnd ^ ConstOpnd" - // if (CombineXorOpnd(I, CurrOpnd, PrevOpnd, ConstOpnd, CV)) { // Remove previous operand PrevOpnd->Invalidate(); @@ -1601,7 +1602,7 @@ Value *ReassociatePass::OptimizeAdd(Instruction *I, RedoInsts.insert(VI); // Create the multiply. - Instruction *V2 = CreateMul(V, MaxOccVal, "tmp", I, I); + Instruction *V2 = CreateMul(V, MaxOccVal, "reass.mul", I, I); // Rerun associate on the multiply in case the inner expression turned into // a multiply. We want to make sure that we keep things in canonical form. @@ -2012,8 +2013,8 @@ void ReassociatePass::OptimizeInst(Instruction *I) { if (I->isCommutative()) canonicalizeOperands(I); - // Don't optimize floating point instructions that don't have unsafe algebra. - if (I->getType()->isFPOrFPVectorTy() && !I->hasUnsafeAlgebra()) + // Don't optimize floating-point instructions unless they are 'fast'. + if (I->getType()->isFPOrFPVectorTy() && !I->isFast()) return; // Do not reassociate boolean (i1) expressions. 
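Aside (illustrative sketch, not part of the commit): CombineXorOpnd above relies on Xor-Rule 3, the bitwise identity (x | c1) ^ (x | c2) == (x & c3) ^ c3 with c3 = c1 ^ c2. A small self-contained check of that identity over a few sample values:

#include <cassert>
#include <cstdint>

int main() {
  const uint32_t Xs[] = {0u, 1u, 0x12345678u, 0xdeadbeefu, 0xffffffffu};
  const uint32_t C1 = 0x00ff00ffu, C2 = 0x0f0f0f0fu;
  const uint32_t C3 = C1 ^ C2;
  for (uint32_t X : Xs) {
    // Xor-Rule 3: (x | c1) ^ (x | c2) == (x & c3) ^ c3, where c3 = c1 ^ c2.
    assert(((X | C1) ^ (X | C2)) == ((X & C3) ^ C3));
  }
  return 0;
}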
We want to preserve the @@ -2140,7 +2141,8 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { DEBUG(dbgs() << "Reassoc to scalar: " << *V << '\n'); I->replaceAllUsesWith(V); if (Instruction *VI = dyn_cast<Instruction>(V)) - VI->setDebugLoc(I->getDebugLoc()); + if (I->getDebugLoc()) + VI->setDebugLoc(I->getDebugLoc()); RedoInsts.insert(I); ++NumAnnihil; return; @@ -2183,11 +2185,104 @@ void ReassociatePass::ReassociateExpression(BinaryOperator *I) { return; } + if (Ops.size() > 2 && Ops.size() <= GlobalReassociateLimit) { + // Find the pair with the highest count in the pairmap and move it to the + // back of the list so that it can later be CSE'd. + // example: + // a*b*c*d*e + // if c*e is the most "popular" pair, we can express this as + // (((c*e)*d)*b)*a + unsigned Max = 1; + unsigned BestRank = 0; + std::pair<unsigned, unsigned> BestPair; + unsigned Idx = I->getOpcode() - Instruction::BinaryOpsBegin; + for (unsigned i = 0; i < Ops.size() - 1; ++i) + for (unsigned j = i + 1; j < Ops.size(); ++j) { + unsigned Score = 0; + Value *Op0 = Ops[i].Op; + Value *Op1 = Ops[j].Op; + if (std::less<Value *>()(Op1, Op0)) + std::swap(Op0, Op1); + auto it = PairMap[Idx].find({Op0, Op1}); + if (it != PairMap[Idx].end()) + Score += it->second; + + unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank); + if (Score > Max || (Score == Max && MaxRank < BestRank)) { + BestPair = {i, j}; + Max = Score; + BestRank = MaxRank; + } + } + if (Max > 1) { + auto Op0 = Ops[BestPair.first]; + auto Op1 = Ops[BestPair.second]; + Ops.erase(&Ops[BestPair.second]); + Ops.erase(&Ops[BestPair.first]); + Ops.push_back(Op0); + Ops.push_back(Op1); + } + } // Now that we ordered and optimized the expressions, splat them back into // the expression tree, removing any unneeded nodes. RewriteExprTree(I, Ops); } +void +ReassociatePass::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) { + // Make a "pairmap" of how often each operand pair occurs. + for (BasicBlock *BI : RPOT) { + for (Instruction &I : *BI) { + if (!I.isAssociative()) + continue; + + // Ignore nodes that aren't at the root of trees. + if (I.hasOneUse() && I.user_back()->getOpcode() == I.getOpcode()) + continue; + + // Collect all operands in a single reassociable expression. + // Since Reassociate has already been run once, we can assume things + // are already canonical according to Reassociation's regime. + SmallVector<Value *, 8> Worklist = { I.getOperand(0), I.getOperand(1) }; + SmallVector<Value *, 8> Ops; + while (!Worklist.empty() && Ops.size() <= GlobalReassociateLimit) { + Value *Op = Worklist.pop_back_val(); + Instruction *OpI = dyn_cast<Instruction>(Op); + if (!OpI || OpI->getOpcode() != I.getOpcode() || !OpI->hasOneUse()) { + Ops.push_back(Op); + continue; + } + // Be paranoid about self-referencing expressions in unreachable code. + if (OpI->getOperand(0) != OpI) + Worklist.push_back(OpI->getOperand(0)); + if (OpI->getOperand(1) != OpI) + Worklist.push_back(OpI->getOperand(1)); + } + // Skip extremely long expressions. + if (Ops.size() > GlobalReassociateLimit) + continue; + + // Add all pairwise combinations of operands to the pair map. + unsigned BinaryIdx = I.getOpcode() - Instruction::BinaryOpsBegin; + SmallSet<std::pair<Value *, Value*>, 32> Visited; + for (unsigned i = 0; i < Ops.size() - 1; ++i) { + for (unsigned j = i + 1; j < Ops.size(); ++j) { + // Canonicalize operand orderings. 
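Aside (illustrative sketch, not part of the commit): the pair-map heuristic above counts how often each unordered operand pair occurs across reassociable chains, then rotates the most frequent pair to the end of the operand list so later CSE can share it (the a*b*c*d*e example in the comment). A standalone model of the counting-and-picking step, with strings standing in for Values:

#include <algorithm>
#include <cassert>
#include <map>
#include <string>
#include <utility>
#include <vector>

using OpPair = std::pair<std::string, std::string>;

// Count every unordered operand pair in each reassociable chain.
static std::map<OpPair, unsigned>
buildPairMap(const std::vector<std::vector<std::string>> &Chains) {
  std::map<OpPair, unsigned> Counts;
  for (const auto &Ops : Chains)
    for (size_t i = 0; i + 1 < Ops.size(); ++i)
      for (size_t j = i + 1; j < Ops.size(); ++j) {
        OpPair P = std::minmax(Ops[i], Ops[j]); // canonical ordering
        ++Counts[P];
      }
  return Counts;
}

int main() {
  // a*b*c*d*e and c*e*f both contain the pair (c, e), so it should win
  // and be grouped last, e.g. (((c*e)*d)*b)*a, for later CSE.
  std::vector<std::vector<std::string>> Chains = {
      {"a", "b", "c", "d", "e"}, {"c", "e", "f"}};
  auto Counts = buildPairMap(Chains);
  auto Best = std::max_element(
      Counts.begin(), Counts.end(),
      [](const auto &L, const auto &R) { return L.second < R.second; });
  assert(Best->first == OpPair("c", "e") && Best->second == 2);
}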
+ Value *Op0 = Ops[i]; + Value *Op1 = Ops[j]; + if (std::less<Value *>()(Op1, Op0)) + std::swap(Op0, Op1); + if (!Visited.insert({Op0, Op1}).second) + continue; + auto res = PairMap[BinaryIdx].insert({{Op0, Op1}, 1}); + if (!res.second) + ++res.first->second; + } + } + } + } +} + PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { // Get the functions basic blocks in Reverse Post Order. This order is used by // BuildRankMap to pre calculate ranks correctly. It also excludes dead basic @@ -2198,8 +2293,20 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { // Calculate the rank map for F. BuildRankMap(F, RPOT); + // Build the pair map before running reassociate. + // Technically this would be more accurate if we did it after one round + // of reassociation, but in practice it doesn't seem to help much on + // real-world code, so don't waste the compile time running reassociate + // twice. + // If a user wants, they could expicitly run reassociate twice in their + // pass pipeline for further potential gains. + // It might also be possible to update the pair map during runtime, but the + // overhead of that may be large if there's many reassociable chains. + BuildPairMap(RPOT); + MadeChange = false; - // Traverse the same blocks that was analysed by BuildRankMap. + + // Traverse the same blocks that were analysed by BuildRankMap. for (BasicBlock *BI : RPOT) { assert(RankMap.count(&*BI) && "BB should be ranked."); // Optimize every instruction in the basic block. @@ -2238,9 +2345,11 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { } } - // We are done with the rank map. + // We are done with the rank map and pair map. RankMap.clear(); ValueRankMap.clear(); + for (auto &Entry : PairMap) + Entry.clear(); if (MadeChange) { PreservedAnalyses PA; @@ -2253,10 +2362,13 @@ PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { } namespace { + class ReassociateLegacyPass : public FunctionPass { ReassociatePass Impl; + public: static char ID; // Pass identification, replacement for typeid + ReassociateLegacyPass() : FunctionPass(ID) { initializeReassociateLegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -2275,9 +2387,11 @@ namespace { AU.addPreserved<GlobalsAAWrapperPass>(); } }; -} + +} // end anonymous namespace char ReassociateLegacyPass::ID = 0; + INITIALIZE_PASS(ReassociateLegacyPass, "reassociate", "Reassociate expressions", false, false) diff --git a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp index f19d45329d23..3b45cfa482e6 100644 --- a/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp +++ b/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -12,36 +12,69 @@ // //===----------------------------------------------------------------------===// +#include "llvm/Transforms/Scalar/RewriteStatepointsForGC.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/MapVector.h" -#include "llvm/ADT/SetOperations.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" -#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "llvm/Analysis/CFG.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Argument.h" 
+#include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/CallingConv.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" #include "llvm/IR/Statepoint.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/IR/Verifier.h" +#include "llvm/IR/ValueHandle.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <set> +#include <string> +#include <utility> +#include <vector> #define DEBUG_TYPE "rewrite-statepoints-for-gc" @@ -52,6 +85,7 @@ static cl::opt<bool> PrintLiveSet("spp-print-liveset", cl::Hidden, cl::init(false)); static cl::opt<bool> PrintLiveSetSize("spp-print-liveset-size", cl::Hidden, cl::init(false)); + // Print out the base pointers for debugging static cl::opt<bool> PrintBasePointers("spp-print-base-pointers", cl::Hidden, cl::init(false)); @@ -67,6 +101,7 @@ static bool ClobberNonLive = true; #else static bool ClobberNonLive = false; #endif + static cl::opt<bool, true> ClobberNonLiveOverride("rs4gc-clobber-non-live", cl::location(ClobberNonLive), cl::Hidden); @@ -75,27 +110,96 @@ static cl::opt<bool> AllowStatepointWithNoDeoptInfo("rs4gc-allow-statepoint-with-no-deopt-info", cl::Hidden, cl::init(true)); +/// The IR fed into RewriteStatepointsForGC may have had attributes and +/// metadata implying dereferenceability that are no longer valid/correct after +/// RewriteStatepointsForGC has run. This is because semantically, after +/// RewriteStatepointsForGC runs, all calls to gc.statepoint "free" the entire +/// heap. stripNonValidData (conservatively) restores +/// correctness by erasing all attributes in the module that externally imply +/// dereferenceability. Similar reasoning also applies to the noalias +/// attributes and metadata. gc.statepoint can touch the entire heap including +/// noalias objects. +/// Apart from attributes and metadata, we also remove instructions that imply +/// constant physical memory: llvm.invariant.start. +static void stripNonValidData(Module &M); + +static bool shouldRewriteStatepointsIn(Function &F); + +PreservedAnalyses RewriteStatepointsForGC::run(Module &M, + ModuleAnalysisManager &AM) { + bool Changed = false; + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + for (Function &F : M) { + // Nothing to do for declarations. + if (F.isDeclaration() || F.empty()) + continue; + + // Policy choice says not to rewrite - the most common reason is that we're + // compiling code without a GCStrategy. 
+ if (!shouldRewriteStatepointsIn(F)) + continue; + + auto &DT = FAM.getResult<DominatorTreeAnalysis>(F); + auto &TTI = FAM.getResult<TargetIRAnalysis>(F); + auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); + Changed |= runOnFunction(F, DT, TTI, TLI); + } + if (!Changed) + return PreservedAnalyses::all(); + + // stripNonValidData asserts that shouldRewriteStatepointsIn + // returns true for at least one function in the module. Since at least + // one function changed, we know that the precondition is satisfied. + stripNonValidData(M); + + PreservedAnalyses PA; + PA.preserve<TargetIRAnalysis>(); + PA.preserve<TargetLibraryAnalysis>(); + return PA; +} + namespace { -struct RewriteStatepointsForGC : public ModulePass { + +class RewriteStatepointsForGCLegacyPass : public ModulePass { + RewriteStatepointsForGC Impl; + +public: static char ID; // Pass identification, replacement for typeid - RewriteStatepointsForGC() : ModulePass(ID) { - initializeRewriteStatepointsForGCPass(*PassRegistry::getPassRegistry()); + RewriteStatepointsForGCLegacyPass() : ModulePass(ID), Impl() { + initializeRewriteStatepointsForGCLegacyPassPass( + *PassRegistry::getPassRegistry()); } - bool runOnFunction(Function &F); + bool runOnModule(Module &M) override { bool Changed = false; - for (Function &F : M) - Changed |= runOnFunction(F); - - if (Changed) { - // stripNonValidAttributesAndMetadata asserts that shouldRewriteStatepointsIn - // returns true for at least one function in the module. Since at least - // one function changed, we know that the precondition is satisfied. - stripNonValidAttributesAndMetadata(M); + const TargetLibraryInfo &TLI = + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + for (Function &F : M) { + // Nothing to do for declarations. + if (F.isDeclaration() || F.empty()) + continue; + + // Policy choice says not to rewrite - the most common reason is that + // we're compiling code without a GCStrategy. + if (!shouldRewriteStatepointsIn(F)) + continue; + + TargetTransformInfo &TTI = + getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto &DT = getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); + + Changed |= Impl.runOnFunction(F, DT, TTI, TLI); } - return Changed; + if (!Changed) + return false; + + // stripNonValidData asserts that shouldRewriteStatepointsIn + // returns true for at least one function in the module. Since at least + // one function changed, we know that the precondition is satisfied. + stripNonValidData(M); + return true; } void getAnalysisUsage(AnalysisUsage &AU) const override { @@ -103,46 +207,33 @@ struct RewriteStatepointsForGC : public ModulePass { // else. We could in theory preserve a lot more analyses here. AU.addRequired<DominatorTreeWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); } - - /// The IR fed into RewriteStatepointsForGC may have had attributes and - /// metadata implying dereferenceability that are no longer valid/correct after - /// RewriteStatepointsForGC has run. This is because semantically, after - /// RewriteStatepointsForGC runs, all calls to gc.statepoint "free" the entire - /// heap. stripNonValidAttributesAndMetadata (conservatively) restores - /// correctness by erasing all attributes in the module that externally imply - /// dereferenceability. Similar reasoning also applies to the noalias - /// attributes and metadata. gc.statepoint can touch the entire heap including - /// noalias objects. 
- void stripNonValidAttributesAndMetadata(Module &M); - - // Helpers for stripNonValidAttributesAndMetadata - void stripNonValidAttributesAndMetadataFromBody(Function &F); - void stripNonValidAttributesFromPrototype(Function &F); - // Certain metadata on instructions are invalid after running RS4GC. - // Optimizations that run after RS4GC can incorrectly use this metadata to - // optimize functions. We drop such metadata on the instruction. - void stripInvalidMetadataFromInstruction(Instruction &I); }; -} // namespace -char RewriteStatepointsForGC::ID = 0; +} // end anonymous namespace + +char RewriteStatepointsForGCLegacyPass::ID = 0; -ModulePass *llvm::createRewriteStatepointsForGCPass() { - return new RewriteStatepointsForGC(); +ModulePass *llvm::createRewriteStatepointsForGCLegacyPass() { + return new RewriteStatepointsForGCLegacyPass(); } -INITIALIZE_PASS_BEGIN(RewriteStatepointsForGC, "rewrite-statepoints-for-gc", +INITIALIZE_PASS_BEGIN(RewriteStatepointsForGCLegacyPass, + "rewrite-statepoints-for-gc", "Make relocations explicit at statepoints", false, false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_END(RewriteStatepointsForGC, "rewrite-statepoints-for-gc", +INITIALIZE_PASS_END(RewriteStatepointsForGCLegacyPass, + "rewrite-statepoints-for-gc", "Make relocations explicit at statepoints", false, false) namespace { + struct GCPtrLivenessData { /// Values defined in this block. MapVector<BasicBlock *, SetVector<Value *>> KillSet; + /// Values used in this block (and thus live); does not included values /// killed within this block. MapVector<BasicBlock *, SetVector<Value *>> LiveSet; @@ -166,10 +257,10 @@ struct GCPtrLivenessData { // Generally, after the execution of a full findBasePointer call, only the // base relation will remain. Internally, we add a mixture of the two // types, then update all the second type to the first type -typedef MapVector<Value *, Value *> DefiningValueMapTy; -typedef SetVector<Value *> StatepointLiveSetTy; -typedef MapVector<AssertingVH<Instruction>, AssertingVH<Value>> - RematerializedValueMapTy; +using DefiningValueMapTy = MapVector<Value *, Value *>; +using StatepointLiveSetTy = SetVector<Value *>; +using RematerializedValueMapTy = + MapVector<AssertingVH<Instruction>, AssertingVH<Value>>; struct PartiallyConstructedSafepointRecord { /// The set of values known to be live across this safepoint @@ -191,7 +282,8 @@ struct PartiallyConstructedSafepointRecord { /// Maps rematerialized copy to it's original value. RematerializedValueMapTy RematerializedValues; }; -} + +} // end anonymous namespace static ArrayRef<Use> GetDeoptBundleOperands(ImmutableCallSite CS) { Optional<OperandBundleUse> DeoptBundle = @@ -254,7 +346,7 @@ static bool containsGCPtrType(Type *Ty) { if (ArrayType *AT = dyn_cast<ArrayType>(Ty)) return containsGCPtrType(AT->getElementType()); if (StructType *ST = dyn_cast<StructType>(Ty)) - return any_of(ST->subtypes(), containsGCPtrType); + return llvm::any_of(ST->subtypes(), containsGCPtrType); return false; } @@ -299,7 +391,9 @@ analyzeParsePointLiveness(DominatorTree &DT, } static bool isKnownBaseResult(Value *V); + namespace { + /// A single base defining value - An immediate base defining value for an /// instruction 'Def' is an input to 'Def' whose base is also a base of 'Def'. 
/// For instructions which have multiple pointer [vector] inputs or that @@ -311,9 +405,11 @@ namespace { struct BaseDefiningValueResult { /// Contains the value which is the base defining value. Value * const BDV; + /// True if the base defining value is also known to be an actual base /// pointer. const bool IsKnownBase; + BaseDefiningValueResult(Value *BDV, bool IsKnownBase) : BDV(BDV), IsKnownBase(IsKnownBase) { #ifndef NDEBUG @@ -324,7 +420,8 @@ struct BaseDefiningValueResult { #endif } }; -} + +} // end anonymous namespace static BaseDefiningValueResult findBaseDefiningValue(Value *I); @@ -374,6 +471,11 @@ findBaseDefiningValueOfVector(Value *I) { if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) return findBaseDefiningValue(GEP->getPointerOperand()); + // If the pointer comes through a bitcast of a vector of pointers to + // a vector of another type of pointer, then look through the bitcast + if (auto *BC = dyn_cast<BitCastInst>(I)) + return findBaseDefiningValue(BC->getOperand(0)); + // A PHI or Select is a base defining value. The outer findBasePointer // algorithm is responsible for constructing a base value for this BDV. assert((isa<SelectInst>(I) || isa<PHINode>(I)) && @@ -429,7 +531,6 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) { if (isa<LoadInst>(I)) // The value loaded is an gc base itself return BaseDefiningValueResult(I, true); - if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) // The base of this GEP is the base @@ -442,12 +543,11 @@ static BaseDefiningValueResult findBaseDefiningValue(Value *I) { break; case Intrinsic::experimental_gc_statepoint: llvm_unreachable("statepoints don't produce pointers"); - case Intrinsic::experimental_gc_relocate: { + case Intrinsic::experimental_gc_relocate: // Rerunning safepoint insertion after safepoints are already // inserted is not supported. It could probably be made to work, // but why are you doing this? There's no good reason. llvm_unreachable("repeat safepoint insertion is not supported"); - } case Intrinsic::gcroot: // Currently, this mechanism hasn't been extended to work with gcroot. // There's no reason it couldn't be, but I haven't thought about the @@ -551,6 +651,7 @@ static bool isKnownBaseResult(Value *V) { } namespace { + /// Models the state of a single base defining value in the findBasePointer /// algorithm for determining where a new instruction is needed to propagate /// the base of this BDV. @@ -558,7 +659,7 @@ class BDVState { public: enum Status { Unknown, Base, Conflict }; - BDVState() : Status(Unknown), BaseValue(nullptr) {} + BDVState() : BaseValue(nullptr) {} explicit BDVState(Status Status, Value *BaseValue = nullptr) : Status(Status), BaseValue(BaseValue) { @@ -597,16 +698,17 @@ public: case Conflict: OS << "C"; break; - }; + } OS << " (" << getBaseValue() << " - " << (getBaseValue() ? getBaseValue()->getName() : "nullptr") << "): "; } private: - Status Status; + Status Status = Unknown; AssertingVH<Value> BaseValue; // Non-null only if Status == Base. 
}; -} + +} // end anonymous namespace #ifndef NDEBUG static raw_ostream &operator<<(raw_ostream &OS, const BDVState &State) { @@ -1169,7 +1271,7 @@ static void CreateGCRelocates(ArrayRef<Value *> LiveVariables, return; auto FindIndex = [](ArrayRef<Value *> LiveVec, Value *Val) { - auto ValIt = find(LiveVec, Val); + auto ValIt = llvm::find(LiveVec, Val); assert(ValIt != LiveVec.end() && "Val not found in LiveVec!"); size_t Index = std::distance(LiveVec.begin(), ValIt); assert(Index < LiveVec.size() && "Bug in std::find?"); @@ -1229,7 +1331,7 @@ class DeferredReplacement { AssertingVH<Instruction> New; bool IsDeoptimize = false; - DeferredReplacement() {} + DeferredReplacement() = default; public: static DeferredReplacement createRAUW(Instruction *Old, Instruction *New) { @@ -1286,7 +1388,8 @@ public: OldI->eraseFromParent(); } }; -} + +} // end anonymous namespace static StringRef getDeoptLowering(CallSite CS) { const char *DeoptLowering = "deopt-lowering"; @@ -1304,7 +1407,6 @@ static StringRef getDeoptLowering(CallSite CS) { return "live-through"; } - static void makeStatepointExplicitImpl(const CallSite CS, /* to replace */ const SmallVectorImpl<Value *> &BasePtrs, @@ -1528,7 +1630,6 @@ static void insertRelocationStores(iterator_range<Value::user_iterator> GCRelocs, DenseMap<Value *, Value *> &AllocaMap, DenseSet<Value *> &VisitedLiveValues) { - for (User *U : GCRelocs) { GCRelocateInst *Relocate = dyn_cast<GCRelocateInst>(U); if (!Relocate) @@ -1564,7 +1665,6 @@ static void insertRematerializationStores( const RematerializedValueMapTy &RematerializedValues, DenseMap<Value *, Value *> &AllocaMap, DenseSet<Value *> &VisitedLiveValues) { - for (auto RematerializedValuePair: RematerializedValues) { Instruction *RematerializedValue = RematerializedValuePair.first; Value *OriginalValue = RematerializedValuePair.second; @@ -1830,7 +1930,6 @@ static void findLiveReferences( static Value* findRematerializableChainToBasePointer( SmallVectorImpl<Instruction*> &ChainToBase, Value *CurrentValue) { - if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(CurrentValue)) { ChainToBase.push_back(GEP); return findRematerializableChainToBasePointer(ChainToBase, @@ -1886,7 +1985,6 @@ chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain, } static bool AreEquivalentPhiNodes(PHINode &OrigRootPhi, PHINode &AlternateRootPhi) { - unsigned PhiNum = OrigRootPhi.getNumIncomingValues(); if (PhiNum != AlternateRootPhi.getNumIncomingValues() || OrigRootPhi.getParent() != AlternateRootPhi.getParent()) @@ -1910,7 +2008,6 @@ static bool AreEquivalentPhiNodes(PHINode &OrigRootPhi, PHINode &AlternateRootPh return false; } return true; - } // From the statepoint live set pick values that are cheaper to recompute then @@ -2297,8 +2394,7 @@ static void RemoveNonValidAttrAtIndex(LLVMContext &Ctx, AttrHolder &AH, AH.setAttributes(AH.getAttributes().removeAttributes(Ctx, Index, R)); } -void -RewriteStatepointsForGC::stripNonValidAttributesFromPrototype(Function &F) { +static void stripNonValidAttributesFromPrototype(Function &F) { LLVMContext &Ctx = F.getContext(); for (Argument &A : F.args()) @@ -2310,8 +2406,10 @@ RewriteStatepointsForGC::stripNonValidAttributesFromPrototype(Function &F) { RemoveNonValidAttrAtIndex(Ctx, F, AttributeList::ReturnIndex); } -void RewriteStatepointsForGC::stripInvalidMetadataFromInstruction(Instruction &I) { - +/// Certain metadata on instructions are invalid after running RS4GC. +/// Optimizations that run after RS4GC can incorrectly use this metadata to +/// optimize functions. 
We drop such metadata on the instruction. +static void stripInvalidMetadataFromInstruction(Instruction &I) { if (!isa<LoadInst>(I) && !isa<StoreInst>(I)) return; // These are the attributes that are still valid on loads and stores after @@ -2337,18 +2435,32 @@ void RewriteStatepointsForGC::stripInvalidMetadataFromInstruction(Instruction &I // Drops all metadata on the instruction other than ValidMetadataAfterRS4GC. I.dropUnknownNonDebugMetadata(ValidMetadataAfterRS4GC); - } -void RewriteStatepointsForGC::stripNonValidAttributesAndMetadataFromBody(Function &F) { +static void stripNonValidDataFromBody(Function &F) { if (F.empty()) return; LLVMContext &Ctx = F.getContext(); MDBuilder Builder(Ctx); + // Set of invariantstart instructions that we need to remove. + // Use this to avoid invalidating the instruction iterator. + SmallVector<IntrinsicInst*, 12> InvariantStartInstructions; for (Instruction &I : instructions(F)) { + // invariant.start on memory location implies that the referenced memory + // location is constant and unchanging. This is no longer true after + // RewriteStatepointsForGC runs because there can be calls to gc.statepoint + // which frees the entire heap and the presence of invariant.start allows + // the optimizer to sink the load of a memory location past a statepoint, + // which is incorrect. + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + if (II->getIntrinsicID() == Intrinsic::invariant_start) { + InvariantStartInstructions.push_back(II); + continue; + } + if (const MDNode *MD = I.getMetadata(LLVMContext::MD_tbaa)) { assert(MD->getNumOperands() < 5 && "unrecognized metadata shape!"); bool IsImmutableTBAA = @@ -2378,6 +2490,12 @@ void RewriteStatepointsForGC::stripNonValidAttributesAndMetadataFromBody(Functio RemoveNonValidAttrAtIndex(Ctx, CS, AttributeList::ReturnIndex); } } + + // Delete the invariant.start instructions and RAUW undef. + for (auto *II : InvariantStartInstructions) { + II->replaceAllUsesWith(UndefValue::get(II->getType())); + II->eraseFromParent(); + } } /// Returns true if this function should be rewritten by this pass. The main @@ -2394,35 +2512,28 @@ static bool shouldRewriteStatepointsIn(Function &F) { return false; } -void RewriteStatepointsForGC::stripNonValidAttributesAndMetadata(Module &M) { +static void stripNonValidData(Module &M) { #ifndef NDEBUG - assert(any_of(M, shouldRewriteStatepointsIn) && "precondition!"); + assert(llvm::any_of(M, shouldRewriteStatepointsIn) && "precondition!"); #endif for (Function &F : M) stripNonValidAttributesFromPrototype(F); for (Function &F : M) - stripNonValidAttributesAndMetadataFromBody(F); + stripNonValidDataFromBody(F); } -bool RewriteStatepointsForGC::runOnFunction(Function &F) { - // Nothing to do for declarations. - if (F.isDeclaration() || F.empty()) - return false; - - // Policy choice says not to rewrite - the most common reason is that we're - // compiling code without a GCStrategy. 
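Aside (illustrative sketch, not part of the commit): stripNonValidDataFromBody above first collects the invariant.start calls into a separate vector and only erases them after the instruction walk finishes, so the iterator being walked is never invalidated. The same collect-then-mutate shape with plain std containers:

#include <cassert>
#include <list>
#include <vector>

int main() {
  std::list<int> Insts = {1, 2, 3, 4, 5, 6};

  // Pass 1: walk the container and remember what to delete, without
  // modifying the container under the iterator.
  std::vector<std::list<int>::iterator> ToErase;
  for (auto It = Insts.begin(); It != Insts.end(); ++It)
    if (*It % 2 == 0)            // stand-in for "is an invariant.start call"
      ToErase.push_back(It);

  // Pass 2: the walk is done, so erasing is now safe.
  for (auto It : ToErase)
    Insts.erase(It);

  assert((Insts == std::list<int>{1, 3, 5}));
}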
- if (!shouldRewriteStatepointsIn(F)) - return false; - - DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>(F).getDomTree(); - TargetTransformInfo &TTI = - getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); +bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT, + TargetTransformInfo &TTI, + const TargetLibraryInfo &TLI) { + assert(!F.isDeclaration() && !F.empty() && + "need function body to rewrite statepoints in"); + assert(shouldRewriteStatepointsIn(F) && "mismatch in rewrite decision"); - auto NeedsRewrite = [](Instruction &I) { + auto NeedsRewrite = [&TLI](Instruction &I) { if (ImmutableCallSite CS = ImmutableCallSite(&I)) - return !callsGCLeafFunction(CS) && !isStatepoint(CS); + return !callsGCLeafFunction(CS, TLI) && !isStatepoint(CS); return false; }; @@ -2662,7 +2773,6 @@ static void computeLiveInValues(DominatorTree &DT, Function &F, static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data, StatepointLiveSetTy &Out) { - BasicBlock *BB = Inst->getParent(); // Note: The copy is intentional and required diff --git a/lib/Transforms/Scalar/SCCP.cpp b/lib/Transforms/Scalar/SCCP.cpp index 4822cf7cce0f..e5866b4718da 100644 --- a/lib/Transforms/Scalar/SCCP.cpp +++ b/lib/Transforms/Scalar/SCCP.cpp @@ -18,30 +18,49 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/IPO/SCCP.h" +#include "llvm/Transforms/Scalar/SCCP.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/PointerIntPair.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueLattice.h" +#include "llvm/Analysis/ValueLatticeUtils.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/IPO.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Scalar/SCCP.h" #include "llvm/Transforms/Utils/Local.h" -#include <algorithm> +#include <cassert> +#include <utility> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "sccp" @@ -52,8 +71,11 @@ STATISTIC(NumDeadBlocks , "Number of basic blocks unreachable"); STATISTIC(IPNumInstRemoved, "Number of instructions removed by IPSCCP"); STATISTIC(IPNumArgsElimed ,"Number of arguments constant propagated by IPSCCP"); STATISTIC(IPNumGlobalConst, "Number of globals found to be constant by IPSCCP"); +STATISTIC(IPNumRangeInfoUsed, "Number of times constant range info was used by" + "IPSCCP"); namespace { + /// LatticeVal class - This class represents the different lattice values that /// an LLVM value may occupy. It is a simple class with value semantics. 
/// @@ -88,9 +110,11 @@ public: LatticeVal() : Val(nullptr, unknown) {} bool isUnknown() const { return getLatticeValue() == unknown; } + bool isConstant() const { return getLatticeValue() == constant || getLatticeValue() == forcedconstant; } + bool isOverdefined() const { return getLatticeValue() == overdefined; } Constant *getConstant() const { @@ -153,11 +177,15 @@ public: Val.setInt(forcedconstant); Val.setPointer(V); } -}; -} // end anonymous namespace. - -namespace { + ValueLatticeElement toValueLattice() const { + if (isOverdefined()) + return ValueLatticeElement::getOverdefined(); + if (isConstant()) + return ValueLatticeElement::get(getConstant()); + return ValueLatticeElement(); + } +}; //===----------------------------------------------------------------------===// // @@ -167,37 +195,38 @@ namespace { class SCCPSolver : public InstVisitor<SCCPSolver> { const DataLayout &DL; const TargetLibraryInfo *TLI; - SmallPtrSet<BasicBlock*, 8> BBExecutable; // The BBs that are executable. - DenseMap<Value*, LatticeVal> ValueState; // The state each value is in. + SmallPtrSet<BasicBlock *, 8> BBExecutable; // The BBs that are executable. + DenseMap<Value *, LatticeVal> ValueState; // The state each value is in. + // The state each parameter is in. + DenseMap<Value *, ValueLatticeElement> ParamState; /// StructValueState - This maintains ValueState for values that have /// StructType, for example for formal arguments, calls, insertelement, etc. - /// - DenseMap<std::pair<Value*, unsigned>, LatticeVal> StructValueState; + DenseMap<std::pair<Value *, unsigned>, LatticeVal> StructValueState; /// GlobalValue - If we are tracking any values for the contents of a global /// variable, we keep a mapping from the constant accessor to the element of /// the global, to the currently known value. If the value becomes /// overdefined, it's entry is simply removed from this map. - DenseMap<GlobalVariable*, LatticeVal> TrackedGlobals; + DenseMap<GlobalVariable *, LatticeVal> TrackedGlobals; /// TrackedRetVals - If we are tracking arguments into and the return /// value out of a function, it will have an entry in this map, indicating /// what the known return value for the function is. - DenseMap<Function*, LatticeVal> TrackedRetVals; + DenseMap<Function *, LatticeVal> TrackedRetVals; /// TrackedMultipleRetVals - Same as TrackedRetVals, but used for functions /// that return multiple values. - DenseMap<std::pair<Function*, unsigned>, LatticeVal> TrackedMultipleRetVals; + DenseMap<std::pair<Function *, unsigned>, LatticeVal> TrackedMultipleRetVals; /// MRVFunctionsTracked - Each function in TrackedMultipleRetVals is /// represented here for efficient lookup. - SmallPtrSet<Function*, 16> MRVFunctionsTracked; + SmallPtrSet<Function *, 16> MRVFunctionsTracked; /// TrackingIncomingArguments - This is the set of functions for whose /// arguments we make optimistic assumptions about and try to prove as /// constants. - SmallPtrSet<Function*, 16> TrackingIncomingArguments; + SmallPtrSet<Function *, 16> TrackingIncomingArguments; /// The reason for two worklists is that overdefined is the lowest state /// on the lattice, and moving things to overdefined as fast as possible @@ -206,16 +235,17 @@ class SCCPSolver : public InstVisitor<SCCPSolver> { /// By having a separate worklist, we accomplish this because everything /// possibly overdefined will become overdefined at the soonest possible /// point. 
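Aside (illustrative sketch, not part of the commit): LatticeVal above is a small lattice, and solving only ever moves a value downwards toward overdefined. A stripped-down model of the merge step, with an int standing in for Constant* and the forcedconstant state omitted:

#include <cassert>

// Simplified three-state model of SCCP's lattice value.
struct ToyLatticeVal {
  enum Status { Unknown, Constant, Overdefined };
  Status St;
  int C; // only meaningful when St == Constant

  ToyLatticeVal() : St(Unknown), C(0) {}
  ToyLatticeVal(Status S, int V) : St(S), C(V) {}

  // Merge another value into this one; returns true if our state changed.
  // States only move downwards: unknown -> constant -> overdefined.
  bool mergeIn(const ToyLatticeVal &Other) {
    if (St == Overdefined || Other.St == Unknown)
      return false;                    // nothing can change
    if (St == Unknown) {
      *this = Other;                   // adopt the incoming value
      return true;
    }
    if (Other.St == Constant && Other.C == C)
      return false;                    // agreeing constants: no change
    St = Overdefined;                  // disagreement or incoming overdefined
    return true;
  }
};

int main() {
  ToyLatticeVal V;                                   // starts unknown
  assert(V.mergeIn({ToyLatticeVal::Constant, 7}));   // becomes constant 7
  assert(!V.mergeIn({ToyLatticeVal::Constant, 7}));  // same constant: no-op
  assert(V.mergeIn({ToyLatticeVal::Constant, 9}));   // conflicting constant
  assert(V.St == ToyLatticeVal::Overdefined);
}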
- SmallVector<Value*, 64> OverdefinedInstWorkList; - SmallVector<Value*, 64> InstWorkList; - + SmallVector<Value *, 64> OverdefinedInstWorkList; + SmallVector<Value *, 64> InstWorkList; - SmallVector<BasicBlock*, 64> BBWorkList; // The BasicBlock work list + // The BasicBlock work list + SmallVector<BasicBlock *, 64> BBWorkList; /// KnownFeasibleEdges - Entries in this set are edges which have already had /// PHI nodes retriggered. - typedef std::pair<BasicBlock*, BasicBlock*> Edge; + using Edge = std::pair<BasicBlock *, BasicBlock *>; DenseSet<Edge> KnownFeasibleEdges; + public: SCCPSolver(const DataLayout &DL, const TargetLibraryInfo *tli) : DL(DL), TLI(tli) {} @@ -263,8 +293,13 @@ public: TrackingIncomingArguments.insert(F); } + /// Returns true if the given function is in the solver's set of + /// argument-tracked functions. + bool isArgumentTrackedFunction(Function *F) { + return TrackingIncomingArguments.count(F); + } + /// Solve - Solve for constants and executable blocks. - /// void Solve(); /// ResolvedUndefsIn - While solving the dataflow for a function, we assume @@ -290,14 +325,23 @@ public: return StructValues; } - LatticeVal getLatticeValueFor(Value *V) const { - DenseMap<Value*, LatticeVal>::const_iterator I = ValueState.find(V); - assert(I != ValueState.end() && "V is not in valuemap!"); - return I->second; + ValueLatticeElement getLatticeValueFor(Value *V) { + assert(!V->getType()->isStructTy() && + "Should use getStructLatticeValueFor"); + std::pair<DenseMap<Value*, ValueLatticeElement>::iterator, bool> + PI = ParamState.insert(std::make_pair(V, ValueLatticeElement())); + ValueLatticeElement &LV = PI.first->second; + if (PI.second) { + DenseMap<Value*, LatticeVal>::const_iterator I = ValueState.find(V); + assert(I != ValueState.end() && + "V not found in ValueState nor Paramstate map!"); + LV = I->second.toValueLattice(); + } + + return LV; } /// getTrackedRetVals - Get the inferred return value map. - /// const DenseMap<Function*, LatticeVal> &getTrackedRetVals() { return TrackedRetVals; } @@ -349,7 +393,6 @@ private: // markConstant - Make a value be marked as "constant". If the value // is not already a constant, add it to the instruction work list so that // the users of the instruction are updated later. - // void markConstant(LatticeVal &IV, Value *V, Constant *C) { if (!IV.markConstant(C)) return; DEBUG(dbgs() << "markConstant: " << *C << ": " << *V << '\n'); @@ -369,7 +412,6 @@ private: pushToWorkList(IV, V); } - // markOverdefined - Make a value be marked as "overdefined". If the // value is not already overdefined, add it to the overdefined instruction // work list so that the users of the instruction are updated later. @@ -402,7 +444,6 @@ private: mergeInValue(ValueState[V], V, MergeWithV); } - /// getValueState - Return the LatticeVal object that corresponds to the /// value. This function handles the case when the value hasn't been seen yet /// by properly seeding constants etc. @@ -426,6 +467,18 @@ private: return LV; } + ValueLatticeElement &getParamState(Value *V) { + assert(!V->getType()->isStructTy() && "Should use getStructValueState"); + + std::pair<DenseMap<Value*, ValueLatticeElement>::iterator, bool> + PI = ParamState.insert(std::make_pair(V, ValueLatticeElement())); + ValueLatticeElement &LV = PI.first->second; + if (PI.second) + LV = getValueState(V).toValueLattice(); + + return LV; + } + /// getStructValueState - Return the LatticeVal object that corresponds to the /// value/field pair. 
This function handles the case when the value hasn't /// been seen yet by properly seeding constants etc. @@ -457,7 +510,6 @@ private: return LV; } - /// markEdgeExecutable - Mark a basic block as executable, adding it to the BB /// work list if it is not already executable. void markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest) { @@ -480,18 +532,15 @@ private: // getFeasibleSuccessors - Return a vector of booleans to indicate which // successors are reachable from a given terminator instruction. - // void getFeasibleSuccessors(TerminatorInst &TI, SmallVectorImpl<bool> &Succs); // isEdgeFeasible - Return true if the control flow edge from the 'From' basic // block to the 'To' basic block is currently feasible. - // bool isEdgeFeasible(BasicBlock *From, BasicBlock *To); // OperandChangedState - This method is invoked on all of the users of an // instruction that was just changed state somehow. Based on this // information, we need to update the specified user of this instruction. - // void OperandChangedState(Instruction *I) { if (BBExecutable.count(I->getParent())) // Inst is executable? visit(*I); @@ -506,6 +555,7 @@ private: void visitPHINode(PHINode &I); // Terminators + void visitReturnInst(ReturnInst &I); void visitTerminatorInst(TerminatorInst &TI); @@ -515,26 +565,32 @@ private: void visitCmpInst(CmpInst &I); void visitExtractValueInst(ExtractValueInst &EVI); void visitInsertValueInst(InsertValueInst &IVI); + void visitCatchSwitchInst(CatchSwitchInst &CPI) { markOverdefined(&CPI); visitTerminatorInst(CPI); } // Instructions that cannot be folded away. + void visitStoreInst (StoreInst &I); void visitLoadInst (LoadInst &I); void visitGetElementPtrInst(GetElementPtrInst &I); + void visitCallInst (CallInst &I) { visitCallSite(&I); } + void visitInvokeInst (InvokeInst &II) { visitCallSite(&II); visitTerminatorInst(II); } + void visitCallSite (CallSite CS); void visitResumeInst (TerminatorInst &I) { /*returns void*/ } void visitUnreachableInst(TerminatorInst &I) { /*returns void*/ } void visitFenceInst (FenceInst &I) { /*returns void*/ } + void visitInstruction(Instruction &I) { // All the instructions we don't do any special handling for just // go to overdefined. @@ -545,10 +601,8 @@ private: } // end anonymous namespace - // getFeasibleSuccessors - Return a vector of booleans to indicate which // successors are reachable from a given terminator instruction. -// void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI, SmallVectorImpl<bool> &Succs) { Succs.resize(TI.getNumSuccessors()); @@ -631,10 +685,8 @@ void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI, llvm_unreachable("SCCP: Don't know how to handle this terminator!"); } - // isEdgeFeasible - Return true if the control flow edge from the 'From' basic // block to the 'To' basic block is currently feasible. -// bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) { assert(BBExecutable.count(To) && "Dest should always be alive!"); @@ -710,7 +762,6 @@ bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) { // destination executable // 7. If a conditional branch has a value that is overdefined, make all // successors executable. -// void SCCPSolver::visitPHINode(PHINode &PN) { // If this PN returns a struct, just mark the result overdefined. // TODO: We could do a lot better than this if code actually uses this. @@ -730,7 +781,6 @@ void SCCPSolver::visitPHINode(PHINode &PN) { // constant, and they agree with each other, the PHI becomes the identical // constant. 
If they are constant and don't agree, the PHI is overdefined. // If there are no executable operands, the PHI remains unknown. - // Constant *OperandVal = nullptr; for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) { LatticeVal IV = getValueState(PN.getIncomingValue(i)); @@ -761,7 +811,6 @@ void SCCPSolver::visitPHINode(PHINode &PN) { // arguments that agree with each other(and OperandVal is the constant) or // OperandVal is null because there are no defined incoming arguments. If // this is the case, the PHI remains unknown. - // if (OperandVal) markConstant(&PN, OperandVal); // Acquire operand value } @@ -789,7 +838,6 @@ void SCCPSolver::visitReturnInst(ReturnInst &I) { for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) mergeInValue(TrackedMultipleRetVals[std::make_pair(F, i)], F, getStructValueState(ResultOp, i)); - } } @@ -820,7 +868,6 @@ void SCCPSolver::visitCastInst(CastInst &I) { } } - void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) { // If this returns a struct, mark all elements over defined, we don't track // structs in structs. @@ -969,7 +1016,6 @@ void SCCPSolver::visitBinaryOperator(Instruction &I) { } } - markOverdefined(&I); } @@ -998,7 +1044,6 @@ void SCCPSolver::visitCmpInst(CmpInst &I) { // Handle getelementptr instructions. If all operands are constants then we // can turn this into a getelementptr ConstantExpr. -// void SCCPSolver::visitGetElementPtrInst(GetElementPtrInst &I) { if (ValueState[&I].isOverdefined()) return; @@ -1044,7 +1089,6 @@ void SCCPSolver::visitStoreInst(StoreInst &SI) { TrackedGlobals.erase(I); // No need to keep tracking this! } - // Handle load instructions. If the operand is a constant pointer to a constant // global, we can replace the load with the loaded constant value! void SCCPSolver::visitLoadInst(LoadInst &I) { @@ -1108,7 +1152,6 @@ CallOverdefined: // a declaration, maybe we can constant fold it. if (F && F->isDeclaration() && !I->getType()->isStructTy() && canConstantFoldCallTo(CS, F)) { - SmallVector<Constant*, 8> Operands; for (CallSite::arg_iterator AI = CS.arg_begin(), E = CS.arg_end(); AI != E; ++AI) { @@ -1162,6 +1205,9 @@ CallOverdefined: mergeInValue(getStructValueState(&*AI, i), &*AI, CallArg); } } else { + // Most other parts of the Solver still only use the simpler value + // lattice, so we propagate changes for parameters to both lattices. + getParamState(&*AI).mergeIn(getValueState(*CAI).toValueLattice(), DL); mergeInValue(&*AI, getValueState(*CAI)); } } @@ -1360,7 +1406,6 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // undef & X -> 0. X could be zero. markForcedConstant(&I, Constant::getNullValue(ITy)); return true; - case Instruction::Or: // Both operands undef -> undef if (Op0LV.isUnknown() && Op1LV.isUnknown()) @@ -1368,7 +1413,6 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // undef | X -> -1. X could be -1. markForcedConstant(&I, Constant::getAllOnesValue(ITy)); return true; - case Instruction::Xor: // undef ^ undef -> 0; strictly speaking, this is not strictly // necessary, but we try to be nice to people who expect this @@ -1379,7 +1423,6 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { } // undef ^ X -> undef break; - case Instruction::SDiv: case Instruction::UDiv: case Instruction::SRem: @@ -1397,7 +1440,6 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // undef % X -> 0. X could be 1. markForcedConstant(&I, Constant::getNullValue(ITy)); return true; - case Instruction::AShr: // X >>a undef -> undef. 
if (Op1LV.isUnknown()) break; @@ -1464,7 +1506,7 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { markOverdefined(&I); return true; case Instruction::Call: - case Instruction::Invoke: { + case Instruction::Invoke: // There are two reasons a call can have an undef result // 1. It could be tracked. // 2. It could be constant-foldable. @@ -1478,7 +1520,6 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { // we do not know what return values are valid. markOverdefined(&I); return true; - } default: // If we don't know what should happen here, conservatively mark it // overdefined. @@ -1557,11 +1598,56 @@ bool SCCPSolver::ResolvedUndefsIn(Function &F) { return false; } +static bool tryToReplaceWithConstantRange(SCCPSolver &Solver, Value *V) { + bool Changed = false; + + // Currently we only use range information for integer values. + if (!V->getType()->isIntegerTy()) + return false; + + const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); + if (!IV.isConstantRange()) + return false; + + for (auto UI = V->uses().begin(), E = V->uses().end(); UI != E;) { + const Use &U = *UI++; + auto *Icmp = dyn_cast<ICmpInst>(U.getUser()); + if (!Icmp || !Solver.isBlockExecutable(Icmp->getParent())) + continue; + + auto getIcmpLatticeValue = [&](Value *Op) { + if (auto *C = dyn_cast<Constant>(Op)) + return ValueLatticeElement::get(C); + return Solver.getLatticeValueFor(Op); + }; + + ValueLatticeElement A = getIcmpLatticeValue(Icmp->getOperand(0)); + ValueLatticeElement B = getIcmpLatticeValue(Icmp->getOperand(1)); + + Constant *C = nullptr; + if (A.satisfiesPredicate(Icmp->getPredicate(), B)) + C = ConstantInt::getTrue(Icmp->getType()); + else if (A.satisfiesPredicate(Icmp->getInversePredicate(), B)) + C = ConstantInt::getFalse(Icmp->getType()); + + if (C) { + Icmp->replaceAllUsesWith(C); + DEBUG(dbgs() << "Replacing " << *Icmp << " with " << *C + << ", because of range information " << A << " " << B + << "\n"); + Icmp->eraseFromParent(); + Changed = true; + } + } + return Changed; +} + static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { Constant *Const = nullptr; if (V->getType()->isStructTy()) { std::vector<LatticeVal> IVs = Solver.getStructLatticeValueFor(V); - if (any_of(IVs, [](const LatticeVal &LV) { return LV.isOverdefined(); })) + if (llvm::any_of(IVs, + [](const LatticeVal &LV) { return LV.isOverdefined(); })) return false; std::vector<Constant *> ConstVals; auto *ST = dyn_cast<StructType>(V->getType()); @@ -1573,10 +1659,19 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { } Const = ConstantStruct::get(ST, ConstVals); } else { - LatticeVal IV = Solver.getLatticeValueFor(V); + const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); if (IV.isOverdefined()) return false; - Const = IV.isConstant() ? IV.getConstant() : UndefValue::get(V->getType()); + + if (IV.isConstantRange()) { + if (IV.getConstantRange().isSingleElement()) + Const = + ConstantInt::get(V->getType(), IV.asConstantInteger().getValue()); + else + return false; + } else + Const = + IV.isConstant() ? IV.getConstant() : UndefValue::get(V->getType()); } assert(Const && "Constant is nullptr here!"); DEBUG(dbgs() << " Constant: " << *Const << " = " << *V << '\n'); @@ -1588,7 +1683,6 @@ static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { // runSCCP() - Run the Sparse Conditional Constant Propagation algorithm, // and return true if the function was modified. 
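Aside (illustrative sketch, not part of the commit): tryToReplaceWithConstantRange above folds an icmp when the known ranges of its operands already decide the comparison. A minimal model for unsigned less-than with half-open [Lo, Hi) ranges, standing in for ConstantRange:

#include <cassert>
#include <cstdint>
#include <optional>

// Half-open unsigned range [Lo, Hi); assumed non-empty.
struct Range {
  uint64_t Lo, Hi;
};

// Decide "A <u B" if the ranges alone settle it; otherwise return nullopt.
static std::optional<bool> foldULT(Range A, Range B) {
  assert(A.Lo < A.Hi && B.Lo < B.Hi && "ranges must be non-empty");
  if (A.Hi <= B.Lo)        // largest A (A.Hi - 1) is below smallest B
    return true;
  if (A.Lo >= B.Hi - 1)    // smallest A is already >= largest B (B.Hi - 1)
    return false;
  return std::nullopt;     // ranges overlap: cannot fold
}

int main() {
  assert(*foldULT({0, 10}, {10, 20}) == true);    // [0,9]  <u [10,19] always
  assert(*foldULT({10, 20}, {0, 11}) == false);   // [10,19] <u [0,10]  never
  assert(!foldULT({0, 10}, {5, 15}).has_value()); // overlap: unknown
}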
-// static bool runSCCP(Function &F, const DataLayout &DL, const TargetLibraryInfo *TLI) { DEBUG(dbgs() << "SCCP on function '" << F.getName() << "'\n"); @@ -1628,7 +1722,6 @@ static bool runSCCP(Function &F, const DataLayout &DL, // Iterate over all of the instructions in a function, replacing them with // constants if we have found them to be of constant values. - // for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) { Instruction *Inst = &*BI++; if (Inst->getType()->isVoidTy() || isa<TerminatorInst>(Inst)) @@ -1659,6 +1752,7 @@ PreservedAnalyses SCCPPass::run(Function &F, FunctionAnalysisManager &AM) { } namespace { + //===--------------------------------------------------------------------===// // /// SCCP Class - This class uses the SCCPSolver to implement a per-function @@ -1666,18 +1760,20 @@ namespace { /// class SCCPLegacyPass : public FunctionPass { public: + // Pass identification, replacement for typeid + static char ID; + + SCCPLegacyPass() : FunctionPass(ID) { + initializeSCCPLegacyPassPass(*PassRegistry::getPassRegistry()); + } + void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } - static char ID; // Pass identification, replacement for typeid - SCCPLegacyPass() : FunctionPass(ID) { - initializeSCCPLegacyPassPass(*PassRegistry::getPassRegistry()); - } // runOnFunction - Run the Sparse Conditional Constant Propagation // algorithm, and return true if the function was modified. - // bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; @@ -1687,9 +1783,11 @@ public: return runSCCP(F, DL, TLI); } }; + } // end anonymous namespace char SCCPLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(SCCPLegacyPass, "sccp", "Sparse Conditional Constant Propagation", false, false) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) @@ -1699,38 +1797,11 @@ INITIALIZE_PASS_END(SCCPLegacyPass, "sccp", // createSCCPPass - This is the public interface to this file. FunctionPass *llvm::createSCCPPass() { return new SCCPLegacyPass(); } -static bool AddressIsTaken(const GlobalValue *GV) { - // Delete any dead constantexpr klingons. - GV->removeDeadConstantUsers(); - - for (const Use &U : GV->uses()) { - const User *UR = U.getUser(); - if (const auto *SI = dyn_cast<StoreInst>(UR)) { - if (SI->getOperand(0) == GV || SI->isVolatile()) - return true; // Storing addr of GV. - } else if (isa<InvokeInst>(UR) || isa<CallInst>(UR)) { - // Make sure we are calling the function, not passing the address. - ImmutableCallSite CS(cast<Instruction>(UR)); - if (!CS.isCallee(&U)) - return true; - } else if (const auto *LI = dyn_cast<LoadInst>(UR)) { - if (LI->isVolatile()) - return true; - } else if (isa<BlockAddress>(UR)) { - // blockaddress doesn't take the address of the function, it takes addr - // of label. - } else { - return true; - } - } - return false; -} - static void findReturnsToZap(Function &F, - SmallPtrSet<Function *, 32> &AddressTakenFunctions, - SmallVector<ReturnInst *, 8> &ReturnsToZap) { + SmallVector<ReturnInst *, 8> &ReturnsToZap, + SCCPSolver &Solver) { // We can only do this if we know that nothing else can call the function. 
- if (!F.hasLocalLinkage() || AddressTakenFunctions.count(&F)) + if (!Solver.isArgumentTrackedFunction(&F)) return; for (BasicBlock &BB : F) @@ -1743,39 +1814,22 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, const TargetLibraryInfo *TLI) { SCCPSolver Solver(DL, TLI); - // AddressTakenFunctions - This set keeps track of the address-taken functions - // that are in the input. As IPSCCP runs through and simplifies code, - // functions that were address taken can end up losing their - // address-taken-ness. Because of this, we keep track of their addresses from - // the first pass so we can use them for the later simplification pass. - SmallPtrSet<Function*, 32> AddressTakenFunctions; - // Loop over all functions, marking arguments to those with their addresses // taken or that are external as overdefined. - // for (Function &F : M) { if (F.isDeclaration()) continue; - // If this is an exact definition of this function, then we can propagate - // information about its result into callsites of it. - // Don't touch naked functions. They may contain asm returning a - // value we don't see, so we may end up interprocedurally propagating - // the return value incorrectly. - if (F.hasExactDefinition() && !F.hasFnAttribute(Attribute::Naked)) + // Determine if we can track the function's return values. If so, add the + // function to the solver's set of return-tracked functions. + if (canTrackReturnsInterprocedurally(&F)) Solver.AddTrackedFunction(&F); - // If this function only has direct calls that we can see, we can track its - // arguments and return value aggressively, and can assume it is not called - // unless we see evidence to the contrary. - if (F.hasLocalLinkage()) { - if (F.hasAddressTaken()) { - AddressTakenFunctions.insert(&F); - } - else { - Solver.AddArgumentTrackedFunction(&F); - continue; - } + // Determine if we can track the function's arguments. If so, add the + // function to the solver's set of argument-tracked functions. + if (canTrackArgumentsInterprocedurally(&F)) { + Solver.AddArgumentTrackedFunction(&F); + continue; } // Assume the function is called. @@ -1786,13 +1840,14 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, Solver.markOverdefined(&AI); } - // Loop over global variables. We inform the solver about any internal global - // variables that do not have their 'addresses taken'. If they don't have - // their addresses taken, we can propagate constants through them. - for (GlobalVariable &G : M.globals()) - if (!G.isConstant() && G.hasLocalLinkage() && - G.hasDefinitiveInitializer() && !AddressIsTaken(&G)) + // Determine if we can track any of the module's global variables. If so, add + // the global variables we can track to the solver's set of tracked global + // variables. + for (GlobalVariable &G : M.globals()) { + G.removeDeadConstantUsers(); + if (canTrackGlobalVariableInterprocedurally(&G)) Solver.TrackValueOfGlobalVariable(&G); + } // Solve for constants. bool ResolvedUndefs = true; @@ -1809,7 +1864,6 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, // Iterate over all of the instructions in the module, replacing them with // constants if we have found them to be of constant values. 
- // SmallVector<BasicBlock*, 512> BlocksToErase; for (Function &F : M) { @@ -1818,9 +1872,15 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, if (Solver.isBlockExecutable(&F.front())) for (Function::arg_iterator AI = F.arg_begin(), E = F.arg_end(); AI != E; - ++AI) - if (!AI->use_empty() && tryToReplaceWithConstant(Solver, &*AI)) + ++AI) { + if (!AI->use_empty() && tryToReplaceWithConstant(Solver, &*AI)) { ++IPNumArgsElimed; + continue; + } + + if (!AI->use_empty() && tryToReplaceWithConstantRange(Solver, &*AI)) + ++IPNumRangeInfoUsed; + } for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { if (!Solver.isBlockExecutable(&*BB)) { @@ -1897,7 +1957,7 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, Function *F = I.first; if (I.second.isOverdefined() || F->getReturnType()->isVoidTy()) continue; - findReturnsToZap(*F, AddressTakenFunctions, ReturnsToZap); + findReturnsToZap(*F, ReturnsToZap, Solver); } for (const auto &F : Solver.getMRVFunctionsTracked()) { @@ -1905,7 +1965,7 @@ static bool runIPSCCP(Module &M, const DataLayout &DL, "The return type should be a struct"); StructType *STy = cast<StructType>(F->getReturnType()); if (Solver.isStructLatticeConstant(F, STy)) - findReturnsToZap(*F, AddressTakenFunctions, ReturnsToZap); + findReturnsToZap(*F, ReturnsToZap, Solver); } // Zap all returns which we've identified as zap to change. @@ -1943,6 +2003,7 @@ PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) { } namespace { + //===--------------------------------------------------------------------===// // /// IPSCCP Class - This class implements interprocedural Sparse Conditional @@ -1969,9 +2030,11 @@ public: AU.addRequired<TargetLibraryInfoWrapperPass>(); } }; + } // end anonymous namespace char IPSCCPLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(IPSCCPLegacyPass, "ipsccp", "Interprocedural Sparse Conditional Constant Propagation", false, false) diff --git a/lib/Transforms/Scalar/SROA.cpp b/lib/Transforms/Scalar/SROA.cpp index b9cee5b2ba95..bfe3754f0769 100644 --- a/lib/Transforms/Scalar/SROA.cpp +++ b/lib/Transforms/Scalar/SROA.cpp @@ -24,28 +24,54 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Scalar/SROA.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/PointerIntPair.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/iterator.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/PtrUseVisitor.h" -#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantFolder.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DIBuilder.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/GlobalAlias.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include 
"llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" -#include "llvm/Support/Chrono.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" @@ -55,6 +81,17 @@ #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" +#include <algorithm> +#include <cassert> +#include <chrono> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <iterator> +#include <string> +#include <tuple> +#include <utility> +#include <vector> #ifndef NDEBUG // We only use this for a debug check. @@ -87,11 +124,18 @@ static cl::opt<bool> SROARandomShuffleSlices("sroa-random-shuffle-slices", static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", cl::init(false), cl::Hidden); +/// Hidden option to allow more aggressive splitting. +static cl::opt<bool> +SROASplitNonWholeAllocaSlices("sroa-split-nonwhole-alloca-slices", + cl::init(false), cl::Hidden); + namespace { + /// \brief A custom IRBuilder inserter which prefixes all names, but only in /// Assert builds. class IRBuilderPrefixedInserter : public IRBuilderDefaultInserter { std::string Prefix; + const Twine getNameWithPrefix(const Twine &Name) const { return Name.isTriviallyEmpty() ? Name : Prefix + Name; } @@ -107,11 +151,9 @@ protected: } }; -/// \brief Provide a typedef for IRBuilder that drops names in release builds. -using IRBuilderTy = llvm::IRBuilder<ConstantFolder, IRBuilderPrefixedInserter>; -} +/// \brief Provide a type for IRBuilder that drops names in release builds. +using IRBuilderTy = IRBuilder<ConstantFolder, IRBuilderPrefixedInserter>; -namespace { /// \brief A used slice of an alloca. /// /// This structure represents a slice of an alloca used by some instruction. It @@ -120,17 +162,18 @@ namespace { /// or not when forming partitions of the alloca. class Slice { /// \brief The beginning offset of the range. - uint64_t BeginOffset; + uint64_t BeginOffset = 0; /// \brief The ending offset, not included in the range. - uint64_t EndOffset; + uint64_t EndOffset = 0; /// \brief Storage for both the use of this slice and whether it can be /// split. PointerIntPair<Use *, 1, bool> UseAndIsSplittable; public: - Slice() : BeginOffset(), EndOffset() {} + Slice() = default; + Slice(uint64_t BeginOffset, uint64_t EndOffset, Use *U, bool IsSplittable) : BeginOffset(BeginOffset), EndOffset(EndOffset), UseAndIsSplittable(U, IsSplittable) {} @@ -180,12 +223,15 @@ public: } bool operator!=(const Slice &RHS) const { return !operator==(RHS); } }; + } // end anonymous namespace namespace llvm { + template <typename T> struct isPodLike; template <> struct isPodLike<Slice> { static const bool value = true; }; -} + +} // end namespace llvm /// \brief Representation of the alloca slices. /// @@ -207,13 +253,15 @@ public: /// \brief Support for iterating over the slices. 
/// @{ - typedef SmallVectorImpl<Slice>::iterator iterator; - typedef iterator_range<iterator> range; + using iterator = SmallVectorImpl<Slice>::iterator; + using range = iterator_range<iterator>; + iterator begin() { return Slices.begin(); } iterator end() { return Slices.end(); } - typedef SmallVectorImpl<Slice>::const_iterator const_iterator; - typedef iterator_range<const_iterator> const_range; + using const_iterator = SmallVectorImpl<Slice>::const_iterator; + using const_range = iterator_range<const_iterator>; + const_iterator begin() const { return Slices.begin(); } const_iterator end() const { return Slices.end(); } /// @} @@ -264,6 +312,7 @@ public: private: template <typename DerivedT, typename RetT = void> class BuilderBase; class SliceBuilder; + friend class AllocaSlices::SliceBuilder; #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) @@ -320,7 +369,7 @@ private: friend class AllocaSlices; friend class AllocaSlices::partition_iterator; - typedef AllocaSlices::iterator iterator; + using iterator = AllocaSlices::iterator; /// \brief The beginning and ending offsets of the alloca for this /// partition. @@ -403,12 +452,12 @@ class AllocaSlices::partition_iterator /// \brief We also need to keep track of the maximum split end offset seen. /// FIXME: Do we really? - uint64_t MaxSplitSliceEndOffset; + uint64_t MaxSplitSliceEndOffset = 0; /// \brief Sets the partition to be empty at given iterator, and sets the /// end iterator. partition_iterator(AllocaSlices::iterator SI, AllocaSlices::iterator SE) - : P(SI), SE(SE), MaxSplitSliceEndOffset(0) { + : P(SI), SE(SE) { // If not already at the end, advance our state to form the initial // partition. if (SI != SE) @@ -432,19 +481,21 @@ class AllocaSlices::partition_iterator // Remove the uses which have ended in the prior partition. This // cannot change the max split slice end because we just checked that // the prior partition ended prior to that max. 
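// Illustrative sketch of the erase/remove_if idiom used just below (and
// throughout this patch, now spelled llvm::remove_if, the range-based wrapper
// over the same standard algorithm). remove_if only partitions the container
// and returns the new logical end; the trailing erase actually shrinks it.
#include <algorithm>
#include <cassert>
#include <vector>

struct Tail { unsigned EndOffset; };

int main() {
  std::vector<Tail> SplitTails{{8}, {16}, {32}};
  unsigned PartitionEnd = 16;

  // Drop every split tail that ends at or before the current partition end.
  SplitTails.erase(std::remove_if(SplitTails.begin(), SplitTails.end(),
                                  [&](const Tail &T) {
                                    return T.EndOffset <= PartitionEnd;
                                  }),
                   SplitTails.end());

  assert(SplitTails.size() == 1 && SplitTails.front().EndOffset == 32);
  return 0;
}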
- P.SplitTails.erase( - remove_if(P.SplitTails, - [&](Slice *S) { return S->endOffset() <= P.EndOffset; }), - P.SplitTails.end()); - assert(any_of(P.SplitTails, - [&](Slice *S) { - return S->endOffset() == MaxSplitSliceEndOffset; - }) && + P.SplitTails.erase(llvm::remove_if(P.SplitTails, + [&](Slice *S) { + return S->endOffset() <= + P.EndOffset; + }), + P.SplitTails.end()); + assert(llvm::any_of(P.SplitTails, + [&](Slice *S) { + return S->endOffset() == MaxSplitSliceEndOffset; + }) && "Could not find the current max split slice offset!"); - assert(all_of(P.SplitTails, - [&](Slice *S) { - return S->endOffset() <= MaxSplitSliceEndOffset; - }) && + assert(llvm::all_of(P.SplitTails, + [&](Slice *S) { + return S->endOffset() <= MaxSplitSliceEndOffset; + }) && "Max split slice end offset is not actually the max!"); } } @@ -608,7 +659,8 @@ static Value *foldPHINodeOrSelectInst(Instruction &I) { class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> { friend class PtrUseVisitor<SliceBuilder>; friend class InstVisitor<SliceBuilder>; - typedef PtrUseVisitor<SliceBuilder> Base; + + using Base = PtrUseVisitor<SliceBuilder>; const uint64_t AllocSize; AllocaSlices &AS; @@ -996,8 +1048,9 @@ AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI) return; } - Slices.erase(remove_if(Slices, [](const Slice &S) { return S.isDead(); }), - Slices.end()); + Slices.erase( + llvm::remove_if(Slices, [](const Slice &S) { return S.isDead(); }), + Slices.end()); #ifndef NDEBUG if (SROARandomShuffleSlices) { @@ -1820,11 +1873,12 @@ static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) { // do that until all the backends are known to produce good code for all // integer vector types. if (!HaveCommonEltTy) { - CandidateTys.erase(remove_if(CandidateTys, - [](VectorType *VTy) { - return !VTy->getElementType()->isIntegerTy(); - }), - CandidateTys.end()); + CandidateTys.erase( + llvm::remove_if(CandidateTys, + [](VectorType *VTy) { + return !VTy->getElementType()->isIntegerTy(); + }), + CandidateTys.end()); // If there were no integer vector types, give up. if (CandidateTys.empty()) @@ -2151,8 +2205,9 @@ static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V, class llvm::sroa::AllocaSliceRewriter : public InstVisitor<AllocaSliceRewriter, bool> { // Befriend the base class so it can delegate to private visit methods. - friend class llvm::InstVisitor<AllocaSliceRewriter, bool>; - typedef llvm::InstVisitor<AllocaSliceRewriter, bool> Base; + friend class InstVisitor<AllocaSliceRewriter, bool>; + + using Base = InstVisitor<AllocaSliceRewriter, bool>; const DataLayout &DL; AllocaSlices &AS; @@ -2182,16 +2237,18 @@ class llvm::sroa::AllocaSliceRewriter // The original offset of the slice currently being rewritten relative to // the original alloca. - uint64_t BeginOffset, EndOffset; + uint64_t BeginOffset = 0; + uint64_t EndOffset = 0; + // The new offsets of the slice currently being rewritten relative to the // original alloca. uint64_t NewBeginOffset, NewEndOffset; uint64_t SliceSize; - bool IsSplittable; - bool IsSplit; - Use *OldUse; - Instruction *OldPtr; + bool IsSplittable = false; + bool IsSplit = false; + Use *OldUse = nullptr; + Instruction *OldPtr = nullptr; // Track post-rewrite users which are PHI nodes and Selects. SmallSetVector<PHINode *, 8> &PHIUsers; @@ -2221,8 +2278,7 @@ public: VecTy(PromotableVecTy), ElementTy(VecTy ? VecTy->getElementType() : nullptr), ElementSize(VecTy ? 
DL.getTypeSizeInBits(ElementTy) / 8 : 0), - BeginOffset(), EndOffset(), IsSplittable(), IsSplit(), OldUse(), - OldPtr(), PHIUsers(PHIUsers), SelectUsers(SelectUsers), + PHIUsers(PHIUsers), SelectUsers(SelectUsers), IRB(NewAI.getContext(), ConstantFolder()) { if (VecTy) { assert((DL.getTypeSizeInBits(ElementTy) % 8) == 0 && @@ -2987,6 +3043,7 @@ private: }; namespace { + /// \brief Visitor to rewrite aggregate loads and stores as scalar. /// /// This pass aggressively rewrites all aggregate loads and stores on @@ -2994,7 +3051,7 @@ namespace { /// with scalar loads and stores. class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> { // Befriend the base class so it can delegate to private visit methods. - friend class llvm::InstVisitor<AggLoadStoreRewriter, bool>; + friend class InstVisitor<AggLoadStoreRewriter, bool>; /// Queue of pointer uses to analyze and potentially rewrite. SmallVector<Use *, 8> Queue; @@ -3037,12 +3094,15 @@ private: protected: /// The builder used to form new instructions. IRBuilderTy IRB; + /// The indices which to be used with insert- or extractvalue to select the /// appropriate value within the aggregate. SmallVector<unsigned, 4> Indices; + /// The indices to a GEP instruction which will move Ptr to the correct slot /// within the aggregate. SmallVector<Value *, 4> GEPIndices; + /// The base pointer of the original op, used as a base for GEPing the /// split operations. Value *Ptr; @@ -3193,7 +3253,8 @@ private: return false; } }; -} + +} // end anonymous namespace /// \brief Strip aggregate type wrapping. /// @@ -3485,58 +3546,60 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { // match relative to their starting offset. We have to verify this prior to // any rewriting. Stores.erase( - remove_if(Stores, - [&UnsplittableLoads, &SplitOffsetsMap](StoreInst *SI) { - // Lookup the load we are storing in our map of split - // offsets. - auto *LI = cast<LoadInst>(SI->getValueOperand()); - // If it was completely unsplittable, then we're done, - // and this store can't be pre-split. - if (UnsplittableLoads.count(LI)) - return true; - - auto LoadOffsetsI = SplitOffsetsMap.find(LI); - if (LoadOffsetsI == SplitOffsetsMap.end()) - return false; // Unrelated loads are definitely safe. - auto &LoadOffsets = LoadOffsetsI->second; - - // Now lookup the store's offsets. - auto &StoreOffsets = SplitOffsetsMap[SI]; - - // If the relative offsets of each split in the load and - // store match exactly, then we can split them and we - // don't need to remove them here. - if (LoadOffsets.Splits == StoreOffsets.Splits) - return false; - - DEBUG(dbgs() << " Mismatched splits for load and store:\n" - << " " << *LI << "\n" - << " " << *SI << "\n"); - - // We've found a store and load that we need to split - // with mismatched relative splits. Just give up on them - // and remove both instructions from our list of - // candidates. - UnsplittableLoads.insert(LI); - return true; - }), + llvm::remove_if(Stores, + [&UnsplittableLoads, &SplitOffsetsMap](StoreInst *SI) { + // Lookup the load we are storing in our map of split + // offsets. + auto *LI = cast<LoadInst>(SI->getValueOperand()); + // If it was completely unsplittable, then we're done, + // and this store can't be pre-split. + if (UnsplittableLoads.count(LI)) + return true; + + auto LoadOffsetsI = SplitOffsetsMap.find(LI); + if (LoadOffsetsI == SplitOffsetsMap.end()) + return false; // Unrelated loads are definitely safe. 
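// Illustrative sketch of the check performed in the next few lines: a store of
// a loaded value may only be pre-split if the load and the store were split at
// exactly the same relative offsets. The types below are simplified stand-ins
// for the SplitOffsetsMap entries.
#include <cassert>
#include <set>
#include <vector>

struct SplitOffsets { std::vector<unsigned> Splits; };

// Returns true when the pair must be given up on (and the load remembered as
// unsplittable), mirroring the predicate passed to remove_if here.
static bool mustDropPair(const SplitOffsets &Load, const SplitOffsets &Store,
                         std::set<const SplitOffsets *> &Unsplittable) {
  if (Load.Splits == Store.Splits)
    return false;              // identical split points: safe to pre-split
  Unsplittable.insert(&Load);  // mismatch: neither side may be pre-split
  return true;
}

int main() {
  SplitOffsets L{{4, 8}}, SMatch{{4, 8}}, SMismatch{{4, 12}};
  std::set<const SplitOffsets *> Unsplittable;
  assert(!mustDropPair(L, SMatch, Unsplittable));
  assert(mustDropPair(L, SMismatch, Unsplittable));
  assert(Unsplittable.count(&L) == 1);
  return 0;
}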
+ auto &LoadOffsets = LoadOffsetsI->second; + + // Now lookup the store's offsets. + auto &StoreOffsets = SplitOffsetsMap[SI]; + + // If the relative offsets of each split in the load and + // store match exactly, then we can split them and we + // don't need to remove them here. + if (LoadOffsets.Splits == StoreOffsets.Splits) + return false; + + DEBUG(dbgs() + << " Mismatched splits for load and store:\n" + << " " << *LI << "\n" + << " " << *SI << "\n"); + + // We've found a store and load that we need to split + // with mismatched relative splits. Just give up on them + // and remove both instructions from our list of + // candidates. + UnsplittableLoads.insert(LI); + return true; + }), Stores.end()); // Now we have to go *back* through all the stores, because a later store may // have caused an earlier store's load to become unsplittable and if it is // unsplittable for the later store, then we can't rely on it being split in // the earlier store either. - Stores.erase(remove_if(Stores, - [&UnsplittableLoads](StoreInst *SI) { - auto *LI = cast<LoadInst>(SI->getValueOperand()); - return UnsplittableLoads.count(LI); - }), + Stores.erase(llvm::remove_if(Stores, + [&UnsplittableLoads](StoreInst *SI) { + auto *LI = + cast<LoadInst>(SI->getValueOperand()); + return UnsplittableLoads.count(LI); + }), Stores.end()); // Once we've established all the loads that can't be split for some reason, // filter any that made it into our list out. - Loads.erase(remove_if(Loads, - [&UnsplittableLoads](LoadInst *LI) { - return UnsplittableLoads.count(LI); - }), + Loads.erase(llvm::remove_if(Loads, + [&UnsplittableLoads](LoadInst *LI) { + return UnsplittableLoads.count(LI); + }), Loads.end()); // If no loads or stores are left, there is no pre-splitting to be done for @@ -3804,7 +3867,8 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { } // Remove the killed slices that have ben pre-split. - AS.erase(remove_if(AS, [](const Slice &S) { return S.isDead(); }), AS.end()); + AS.erase(llvm::remove_if(AS, [](const Slice &S) { return S.isDead(); }), + AS.end()); // Insert our new slices. This will sort and merge them into the sorted // sequence. @@ -3819,7 +3883,7 @@ bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { // Finally, don't try to promote any allocas that new require re-splitting. // They have already been added to the worklist above. PromotableAllocas.erase( - remove_if( + llvm::remove_if( PromotableAllocas, [&](AllocaInst *AI) { return ResplitPromotableAllocas.count(AI); }), PromotableAllocas.end()); @@ -3989,27 +4053,58 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { // First try to pre-split loads and stores. Changed |= presplitLoadsAndStores(AI, AS); - // Now that we have identified any pre-splitting opportunities, mark any - // splittable (non-whole-alloca) loads and stores as unsplittable. If we fail - // to split these during pre-splitting, we want to force them to be - // rewritten into a partition. + // Now that we have identified any pre-splitting opportunities, + // mark loads and stores unsplittable except for the following case. + // We leave a slice splittable if all other slices are disjoint or fully + // included in the slice, such as whole-alloca loads and stores. + // If we fail to split these during pre-splitting, we want to force them + // to be rewritten into a partition. bool IsSorted = true; - for (Slice &S : AS) { - if (!S.isSplittable()) - continue; - // FIXME: We currently leave whole-alloca splittable loads and stores. 
This - // used to be the only splittable loads and stores and we need to be - // confident that the above handling of splittable loads and stores is - // completely sufficient before we forcibly disable the remaining handling. - if (S.beginOffset() == 0 && - S.endOffset() >= DL.getTypeAllocSize(AI.getAllocatedType())) - continue; - if (isa<LoadInst>(S.getUse()->getUser()) || - isa<StoreInst>(S.getUse()->getUser())) { - S.makeUnsplittable(); - IsSorted = false; + + uint64_t AllocaSize = DL.getTypeAllocSize(AI.getAllocatedType()); + const uint64_t MaxBitVectorSize = 1024; + if (SROASplitNonWholeAllocaSlices && AllocaSize <= MaxBitVectorSize) { + // If a byte boundary is included in any load or store, a slice starting or + // ending at the boundary is not splittable. + SmallBitVector SplittableOffset(AllocaSize + 1, true); + for (Slice &S : AS) + for (unsigned O = S.beginOffset() + 1; + O < S.endOffset() && O < AllocaSize; O++) + SplittableOffset.reset(O); + + for (Slice &S : AS) { + if (!S.isSplittable()) + continue; + + if ((S.beginOffset() > AllocaSize || SplittableOffset[S.beginOffset()]) && + (S.endOffset() > AllocaSize || SplittableOffset[S.endOffset()])) + continue; + + if (isa<LoadInst>(S.getUse()->getUser()) || + isa<StoreInst>(S.getUse()->getUser())) { + S.makeUnsplittable(); + IsSorted = false; + } } } + else { + // We only allow whole-alloca splittable loads and stores + // for a large alloca to avoid creating too large BitVector. + for (Slice &S : AS) { + if (!S.isSplittable()) + continue; + + if (S.beginOffset() == 0 && S.endOffset() >= AllocaSize) + continue; + + if (isa<LoadInst>(S.getUse()->getUser()) || + isa<StoreInst>(S.getUse()->getUser())) { + S.makeUnsplittable(); + IsSorted = false; + } + } + } + if (!IsSorted) std::sort(AS.begin(), AS.end()); @@ -4044,9 +4139,11 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { // Migrate debug information from the old alloca to the new alloca(s) // and the individual partitions. - if (DbgDeclareInst *DbgDecl = FindAllocaDbgDeclare(&AI)) { - auto *Var = DbgDecl->getVariable(); - auto *Expr = DbgDecl->getExpression(); + TinyPtrVector<DbgInfoIntrinsic *> DbgDeclares = FindDbgAddrUses(&AI); + if (!DbgDeclares.empty()) { + auto *Var = DbgDeclares.front()->getVariable(); + auto *Expr = DbgDeclares.front()->getExpression(); + auto VarSize = Var->getSizeInBits(); DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false); uint64_t AllocaSize = DL.getTypeSizeInBits(AI.getAllocatedType()); for (auto Fragment : Fragments) { @@ -4062,21 +4159,43 @@ bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) { uint64_t Size = Fragment.Size; if (ExprFragment) { uint64_t AbsEnd = - ExprFragment->OffsetInBits + ExprFragment->SizeInBits; + ExprFragment->OffsetInBits + ExprFragment->SizeInBits; if (Start >= AbsEnd) // No need to describe a SROAed padding. continue; Size = std::min(Size, AbsEnd - Start); } - FragmentExpr = DIB.createFragmentExpression(Start, Size); + // The new, smaller fragment is stenciled out from the old fragment. + if (auto OrigFragment = FragmentExpr->getFragmentInfo()) { + assert(Start >= OrigFragment->OffsetInBits && + "new fragment is outside of original fragment"); + Start -= OrigFragment->OffsetInBits; + } + + // The alloca may be larger than the variable. + if (VarSize) { + if (Size > *VarSize) + Size = *VarSize; + if (Size == 0 || Start + Size > *VarSize) + continue; + } + + // Avoid creating a fragment expression that covers the entire variable. 
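// Illustrative sketch (plain arithmetic, no LLVM debug-info types) of how the
// fragment window for one split piece of the alloca is derived just below:
// rebase the piece's start against the original fragment, clamp its size to
// the variable's size, and skip pieces that fall entirely outside the variable.
#include <cassert>
#include <cstdint>
#include <optional>

struct FragmentWindow { uint64_t StartBits, SizeBits; };

static std::optional<FragmentWindow>
computeFragment(uint64_t PieceStart, uint64_t PieceSize,
                std::optional<uint64_t> OrigFragmentStart,
                std::optional<uint64_t> VarSize) {
  if (OrigFragmentStart) {
    assert(PieceStart >= *OrigFragmentStart &&
           "new fragment is outside of original fragment");
    PieceStart -= *OrigFragmentStart; // make the offset relative again
  }
  if (VarSize) {
    if (PieceSize > *VarSize)
      PieceSize = *VarSize;           // the alloca may be larger than the var
    if (PieceSize == 0 || PieceStart + PieceSize > *VarSize)
      return std::nullopt;            // nothing of the variable lives here
  }
  return FragmentWindow{PieceStart, PieceSize};
}

int main() {
  auto W = computeFragment(/*PieceStart=*/96, /*PieceSize=*/64,
                           /*OrigFragmentStart=*/64, /*VarSize=*/128);
  assert(W && W->StartBits == 32 && W->SizeBits == 64);
  assert(!computeFragment(160, 32, 0, 128)); // past the end of the variable
  return 0;
}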
+ if (!VarSize || *VarSize != Size) { + if (auto E = + DIExpression::createFragmentExpression(Expr, Start, Size)) + FragmentExpr = *E; + else + continue; + } } - // Remove any existing dbg.declare intrinsic describing the same alloca. - if (DbgDeclareInst *OldDDI = FindAllocaDbgDeclare(Fragment.Alloca)) - OldDDI->eraseFromParent(); + // Remove any existing intrinsics describing the same alloca. + for (DbgInfoIntrinsic *OldDII : FindDbgAddrUses(Fragment.Alloca)) + OldDII->eraseFromParent(); DIB.insertDeclare(Fragment.Alloca, Var, FragmentExpr, - DbgDecl->getDebugLoc(), &AI); + DbgDeclares.front()->getDebugLoc(), &AI); } } return Changed; @@ -4175,12 +4294,22 @@ bool SROA::runOnAlloca(AllocaInst &AI) { /// /// We also record the alloca instructions deleted here so that they aren't /// subsequently handed to mem2reg to promote. -void SROA::deleteDeadInstructions( +bool SROA::deleteDeadInstructions( SmallPtrSetImpl<AllocaInst *> &DeletedAllocas) { + bool Changed = false; while (!DeadInsts.empty()) { Instruction *I = DeadInsts.pop_back_val(); DEBUG(dbgs() << "Deleting dead instruction: " << *I << "\n"); + // If the instruction is an alloca, find the possible dbg.declare connected + // to it, and remove it too. We must do this before calling RAUW or we will + // not be able to find it. + if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { + DeletedAllocas.insert(AI); + for (DbgInfoIntrinsic *OldDII : FindDbgAddrUses(AI)) + OldDII->eraseFromParent(); + } + I->replaceAllUsesWith(UndefValue::get(I->getType())); for (Use &Operand : I->operands()) @@ -4191,15 +4320,11 @@ void SROA::deleteDeadInstructions( DeadInsts.insert(U); } - if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { - DeletedAllocas.insert(AI); - if (DbgDeclareInst *DbgDecl = FindAllocaDbgDeclare(AI)) - DbgDecl->eraseFromParent(); - } - ++NumDeleted; I->eraseFromParent(); + Changed = true; } + return Changed; } /// \brief Promote the allocas, using the best available technique. @@ -4241,7 +4366,7 @@ PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT, do { while (!Worklist.empty()) { Changed |= runOnAlloca(*Worklist.pop_back_val()); - deleteDeadInstructions(DeletedAllocas); + Changed |= deleteDeadInstructions(DeletedAllocas); // Remove the deleted allocas from various lists so that we don't try to // continue processing them. 
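// Illustrative sketch (simplified, not LLVM's real use/def machinery) of the
// worklist-driven cleanup in deleteDeadInstructions: pop a dead value, detach
// it from its operands, and queue any operand that just lost its last user.
// The function now also reports whether it changed anything, so the caller's
// 'Changed' flag stays accurate even when runOnAlloca itself made no change.
#include <cassert>
#include <map>
#include <set>
#include <vector>

using ValueId = int;

struct UseGraph {
  std::map<ValueId, std::vector<ValueId>> Operands; // value -> values it uses
  std::map<ValueId, int> UseCount;                  // value -> number of users
};

static bool deleteDead(std::vector<ValueId> DeadInsts, UseGraph &G,
                       std::set<ValueId> &Erased) {
  bool Changed = false;
  while (!DeadInsts.empty()) {
    ValueId I = DeadInsts.back();
    DeadInsts.pop_back();
    // Drop I's uses of its operands; operands with no users left become dead.
    for (ValueId Op : G.Operands[I])
      if (--G.UseCount[Op] == 0)
        DeadInsts.push_back(Op);
    Erased.insert(I);
    Changed = true;
  }
  return Changed;
}

int main() {
  // 3 uses 2, 2 uses 1; nothing else uses 1 or 2, so deleting 3 cascades.
  UseGraph G;
  G.Operands[3] = {2};
  G.Operands[2] = {1};
  G.UseCount[2] = 1;
  G.UseCount[1] = 1;
  std::set<ValueId> Erased;
  bool Changed = deleteDead({3}, G, Erased);
  assert(Changed && Erased == (std::set<ValueId>{1, 2, 3}));
  return 0;
}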
@@ -4249,7 +4374,7 @@ PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT, auto IsInSet = [&](AllocaInst *AI) { return DeletedAllocas.count(AI); }; Worklist.remove_if(IsInSet); PostPromotionWorklist.remove_if(IsInSet); - PromotableAllocas.erase(remove_if(PromotableAllocas, IsInSet), + PromotableAllocas.erase(llvm::remove_if(PromotableAllocas, IsInSet), PromotableAllocas.end()); DeletedAllocas.clear(); } @@ -4284,9 +4409,12 @@ class llvm::sroa::SROALegacyPass : public FunctionPass { SROA Impl; public: + static char ID; + SROALegacyPass() : FunctionPass(ID) { initializeSROALegacyPassPass(*PassRegistry::getPassRegistry()); } + bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; @@ -4296,6 +4424,7 @@ public: getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F)); return !PA.areAllPreserved(); } + void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<DominatorTreeWrapperPass>(); @@ -4304,7 +4433,6 @@ public: } StringRef getPassName() const override { return "SROA"; } - static char ID; }; char SROALegacyPass::ID = 0; diff --git a/lib/Transforms/Scalar/Scalar.cpp b/lib/Transforms/Scalar/Scalar.cpp index ce6f93eb0c15..3b99ddff2e06 100644 --- a/lib/Transforms/Scalar/Scalar.cpp +++ b/lib/Transforms/Scalar/Scalar.cpp @@ -35,11 +35,13 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeADCELegacyPassPass(Registry); initializeBDCELegacyPassPass(Registry); initializeAlignmentFromAssumptionsPass(Registry); + initializeCallSiteSplittingLegacyPassPass(Registry); initializeConstantHoistingLegacyPassPass(Registry); initializeConstantPropagationPass(Registry); initializeCorrelatedValuePropagationPass(Registry); initializeDCELegacyPassPass(Registry); initializeDeadInstEliminationPass(Registry); + initializeDivRemPairsLegacyPassPass(Registry); initializeScalarizerPass(Registry); initializeDSELegacyPassPass(Registry); initializeGuardWideningLegacyPassPass(Registry); @@ -73,17 +75,17 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeLowerExpectIntrinsicPass(Registry); initializeLowerGuardIntrinsicLegacyPassPass(Registry); initializeMemCpyOptLegacyPassPass(Registry); + initializeMergeICmpsPass(Registry); initializeMergedLoadStoreMotionLegacyPassPass(Registry); initializeNaryReassociateLegacyPassPass(Registry); initializePartiallyInlineLibCallsLegacyPassPass(Registry); initializeReassociateLegacyPassPass(Registry); initializeRegToMemPass(Registry); - initializeRewriteStatepointsForGCPass(Registry); + initializeRewriteStatepointsForGCLegacyPassPass(Registry); initializeSCCPLegacyPassPass(Registry); initializeIPSCCPLegacyPassPass(Registry); initializeSROALegacyPassPass(Registry); initializeCFGSimplifyPassPass(Registry); - initializeLateCFGSimplifyPassPass(Registry); initializeStructurizeCFGPass(Registry); initializeSimpleLoopUnswitchLegacyPassPass(Registry); initializeSinkingLegacyPassPass(Registry); @@ -98,6 +100,8 @@ void llvm::initializeScalarOpts(PassRegistry &Registry) { initializeLoopLoadEliminationPass(Registry); initializeLoopSimplifyCFGLegacyPassPass(Registry); initializeLoopVersioningPassPass(Registry); + initializeEntryExitInstrumenterPass(Registry); + initializePostInlineEntryExitInstrumenterPass(Registry); } void LLVMInitializeScalarOpts(LLVMPassRegistryRef R) { @@ -117,11 +121,7 @@ void LLVMAddAlignmentFromAssumptionsPass(LLVMPassManagerRef PM) { } void LLVMAddCFGSimplificationPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createCFGSimplificationPass()); -} 
- -void LLVMAddLateCFGSimplificationPass(LLVMPassManagerRef PM) { - unwrap(PM)->add(createLateCFGSimplificationPass()); + unwrap(PM)->add(createCFGSimplificationPass(1, false, false, true)); } void LLVMAddDeadStoreEliminationPass(LLVMPassManagerRef PM) { diff --git a/lib/Transforms/Scalar/Scalarizer.cpp b/lib/Transforms/Scalar/Scalarizer.cpp index d11855f2f3a9..34ed126155be 100644 --- a/lib/Transforms/Scalar/Scalarizer.cpp +++ b/lib/Transforms/Scalar/Scalarizer.cpp @@ -1,4 +1,4 @@ -//===--- Scalarizer.cpp - Scalarize vector operations ---------------------===// +//===- Scalarizer.cpp - Scalarize vector operations -----------------------===// // // The LLVM Compiler Infrastructure // @@ -14,36 +14,59 @@ // //===----------------------------------------------------------------------===// -#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/Options.h" #include "llvm/Transforms/Scalar.h" -#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include <cassert> +#include <cstdint> +#include <iterator> +#include <map> +#include <utility> using namespace llvm; #define DEBUG_TYPE "scalarizer" namespace { + // Used to store the scattered form of a vector. -typedef SmallVector<Value *, 8> ValueVector; +using ValueVector = SmallVector<Value *, 8>; // Used to map a vector Value to its scattered form. We use std::map // because we want iterators to persist across insertion and because the // values are relatively large. -typedef std::map<Value *, ValueVector> ScatterMap; +using ScatterMap = std::map<Value *, ValueVector>; // Lists Instructions that have been replaced with scalar implementations, // along with a pointer to their scattered forms. -typedef SmallVector<std::pair<Instruction *, ValueVector *>, 16> GatherList; +using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>; // Provides a very limited vector-like interface for lazily accessing one // component of a scattered vector or vector pointer. class Scatterer { public: - Scatterer() {} + Scatterer() = default; // Scatter V into Size components. If new instructions are needed, // insert them before BBI in BB. If Cache is nonnull, use it to cache @@ -71,10 +94,12 @@ private: // called Name that compares X and Y in the same way as FCI. struct FCmpSplitter { FCmpSplitter(FCmpInst &fci) : FCI(fci) {} + Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1, const Twine &Name) const { return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name); } + FCmpInst &FCI; }; @@ -82,10 +107,12 @@ struct FCmpSplitter { // called Name that compares X and Y in the same way as ICI. 
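// Illustrative sketch of what the splitter functors in this file exist for:
// the Scalarizer visits a vector instruction, scatters its operands into
// per-lane values, and asks a functor to rebuild the operation one lane at a
// time. Plain std::vector and a lambda stand in for Scatterer and the
// IRBuilder-based FCmp/ICmp/Binary splitters.
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Lane = int;
using ValueVector = std::vector<Lane>; // scattered form of one vector value

static ValueVector splitBinary(const ValueVector &Op0, const ValueVector &Op1,
                               const std::function<Lane(Lane, Lane)> &Split) {
  ValueVector Res(Op0.size());
  for (std::size_t I = 0, E = Op0.size(); I != E; ++I)
    Res[I] = Split(Op0[I], Op1[I]); // one scalar op per vector lane
  return Res;
}

int main() {
  ValueVector A{1, 2, 3, 4}, B{10, 20, 30, 40};
  ValueVector Sum = splitBinary(A, B, [](Lane X, Lane Y) { return X + Y; });
  assert(Sum == (ValueVector{11, 22, 33, 44}));
  return 0;
}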
struct ICmpSplitter { ICmpSplitter(ICmpInst &ici) : ICI(ici) {} + Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1, const Twine &Name) const { return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name); } + ICmpInst &ICI; }; @@ -93,16 +120,18 @@ struct ICmpSplitter { // a binary operator like BO called Name with operands X and Y. struct BinarySplitter { BinarySplitter(BinaryOperator &bo) : BO(bo) {} + Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1, const Twine &Name) const { return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name); } + BinaryOperator &BO; }; // Information about a load or store that we're scalarizing. struct VectorLayout { - VectorLayout() : VecTy(nullptr), ElemTy(nullptr), VecAlign(0), ElemSize(0) {} + VectorLayout() = default; // Return the alignment of element I. uint64_t getElemAlign(unsigned I) { @@ -110,16 +139,16 @@ struct VectorLayout { } // The type of the vector. - VectorType *VecTy; + VectorType *VecTy = nullptr; // The type of each element. - Type *ElemTy; + Type *ElemTy = nullptr; // The alignment of the vector. - uint64_t VecAlign; + uint64_t VecAlign = 0; // The size of each element. - uint64_t ElemSize; + uint64_t ElemSize = 0; }; class Scalarizer : public FunctionPass, @@ -127,8 +156,7 @@ class Scalarizer : public FunctionPass, public: static char ID; - Scalarizer() : - FunctionPass(ID) { + Scalarizer() : FunctionPass(ID) { initializeScalarizerPass(*PassRegistry::getPassRegistry()); } @@ -137,19 +165,19 @@ public: // InstVisitor methods. They return true if the instruction was scalarized, // false if nothing changed. - bool visitInstruction(Instruction &) { return false; } + bool visitInstruction(Instruction &I) { return false; } bool visitSelectInst(SelectInst &SI); - bool visitICmpInst(ICmpInst &); - bool visitFCmpInst(FCmpInst &); - bool visitBinaryOperator(BinaryOperator &); - bool visitGetElementPtrInst(GetElementPtrInst &); - bool visitCastInst(CastInst &); - bool visitBitCastInst(BitCastInst &); - bool visitShuffleVectorInst(ShuffleVectorInst &); - bool visitPHINode(PHINode &); - bool visitLoadInst(LoadInst &); - bool visitStoreInst(StoreInst &); - bool visitCallInst(CallInst &I); + bool visitICmpInst(ICmpInst &ICI); + bool visitFCmpInst(FCmpInst &FCI); + bool visitBinaryOperator(BinaryOperator &BO); + bool visitGetElementPtrInst(GetElementPtrInst &GEPI); + bool visitCastInst(CastInst &CI); + bool visitBitCastInst(BitCastInst &BCI); + bool visitShuffleVectorInst(ShuffleVectorInst &SVI); + bool visitPHINode(PHINode &PHI); + bool visitLoadInst(LoadInst &LI); + bool visitStoreInst(StoreInst &SI); + bool visitCallInst(CallInst &ICI); static void registerOptions() { // This is disabled by default because having separate loads and stores @@ -162,11 +190,12 @@ public: } private: - Scatterer scatter(Instruction *, Value *); - void gather(Instruction *, const ValueVector &); + Scatterer scatter(Instruction *Point, Value *V); + void gather(Instruction *Op, const ValueVector &CV); bool canTransferMetadata(unsigned Kind); - void transferMetadata(Instruction *, const ValueVector &); - bool getVectorLayout(Type *, unsigned, VectorLayout &, const DataLayout &); + void transferMetadata(Instruction *Op, const ValueVector &CV); + bool getVectorLayout(Type *Ty, unsigned Alignment, VectorLayout &Layout, + const DataLayout &DL); bool finish(); template<typename T> bool splitBinary(Instruction &, const T &); @@ -179,9 +208,10 @@ private: bool ScalarizeLoadStore; }; -char Scalarizer::ID = 0; } // end anonymous namespace +char 
Scalarizer::ID = 0; + INITIALIZE_PASS_WITH_OPTIONS(Scalarizer, "scalarizer", "Scalarize vector operations", false, false) @@ -222,7 +252,7 @@ Value *Scatterer::operator[](unsigned I) { // Search through a chain of InsertElementInsts looking for element I. // Record other elements in the cache. The new V is still suitable // for all uncached indices. - for (;;) { + while (true) { InsertElementInst *Insert = dyn_cast<InsertElementInst>(V); if (!Insert) break; diff --git a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp index 84675f41cdd5..209821ff21d7 100644 --- a/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp +++ b/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -1,4 +1,4 @@ -//===-- SeparateConstOffsetFromGEP.cpp - ------------------------*- C++ -*-===// +//===- SeparateConstOffsetFromGEP.cpp -------------------------------------===// // // The LLVM Compiler Infrastructure // @@ -156,27 +156,44 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/CodeGen/TargetSubtargetInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" -#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" -#include "llvm/Target/TargetSubtargetInfo.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" +#include <cassert> +#include <cstdint> +#include <string> using namespace llvm; using namespace llvm::PatternMatch; @@ -185,6 +202,7 @@ static cl::opt<bool> DisableSeparateConstOffsetFromGEP( "disable-separate-const-offset-from-gep", cl::init(false), cl::desc("Do not separate the constant offset from a GEP instruction"), cl::Hidden); + // Setting this flag may emit false positives when the input module already // contains dead instructions. Therefore, we set it only in unit tests that are // free of dead code. @@ -219,6 +237,7 @@ public: /// garbage-collect unused instructions in UserChain. static Value *Extract(Value *Idx, GetElementPtrInst *GEP, User *&UserChainTail, const DominatorTree *DT); + /// Looks for a constant offset from the given GEP index without extracting /// it. It returns the numeric value of the extracted constant offset (0 if /// failed). The meaning of the arguments are the same as Extract. 
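// Illustrative sketch (plain C++ pointer arithmetic, no IR) of the rewrite
// this extractor enables: an index of the form 'i + 5' is split so that the
// variadic part '&a[i]' can be CSE'd across several accesses that differ only
// in their constant offsets.
#include <cassert>

int main() {
  long a[64] = {};
  long i = 17;

  long *Direct = &a[i + 5];   // original form: index carries the constant
  long *Base   = &a[i];       // variadic base, reusable by other accesses
  long *Split  = Base + 5;    // constant folded into a trailing offset

  assert(Direct == Split);
  assert(&a[i + 9] == Base + 9); // a second access reuses the same base
  return 0;
}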
@@ -229,6 +248,7 @@ private: ConstantOffsetExtractor(Instruction *InsertionPt, const DominatorTree *DT) : IP(InsertionPt), DL(InsertionPt->getModule()->getDataLayout()), DT(DT) { } + /// Searches the expression that computes V for a non-zero constant C s.t. /// V can be reassociated into the form V' + C. If the searching is /// successful, returns C and update UserChain as a def-use chain from C to V; @@ -244,9 +264,11 @@ private: /// non-negative. Levaraging this, we can better split /// inbounds GEPs. APInt find(Value *V, bool SignExtended, bool ZeroExtended, bool NonNegative); + /// A helper function to look into both operands of a binary operator. APInt findInEitherOperand(BinaryOperator *BO, bool SignExtended, bool ZeroExtended); + /// After finding the constant offset C from the GEP index I, we build a new /// index I' s.t. I' + C = I. This function builds and returns the new /// index I' according to UserChain produced by function "find". @@ -263,6 +285,7 @@ private: /// (sext(a) + sext(b)) + 5. /// Given this form, we know I' is sext(a) + sext(b). Value *rebuildWithoutConstOffset(); + /// After the first step of rebuilding the GEP index without the constant /// offset, distribute s/zext to the operands of all operators in UserChain. /// e.g., zext(sext(a + (b + 5)) (assuming no overflow) => @@ -279,8 +302,10 @@ private: /// UserChain.size() - 1, and is decremented during /// the recursion. Value *distributeExtsAndCloneChain(unsigned ChainIndex); + /// Reassociates the GEP index to the form I' + C and returns I'. Value *removeConstOffset(unsigned ChainIndex); + /// A helper function to apply ExtInsts, a list of s/zext, to value V. /// e.g., if ExtInsts = [sext i32 to i64, zext i16 to i32], this function /// returns "sext i32 (zext i16 V to i32) to i64". @@ -303,10 +328,14 @@ private: /// /// This path helps to rebuild the new GEP index. SmallVector<User *, 8> UserChain; + /// A data structure used in rebuildWithoutConstOffset. Contains all /// sext/zext instructions along UserChain. SmallVector<CastInst *, 16> ExtInsts; - Instruction *IP; /// Insertion position of cloned instructions. + + /// Insertion position of cloned instructions. + Instruction *IP; + const DataLayout &DL; const DominatorTree *DT; }; @@ -317,9 +346,10 @@ private: class SeparateConstOffsetFromGEP : public FunctionPass { public: static char ID; + SeparateConstOffsetFromGEP(const TargetMachine *TM = nullptr, bool LowerGEP = false) - : FunctionPass(ID), DL(nullptr), DT(nullptr), TM(TM), LowerGEP(LowerGEP) { + : FunctionPass(ID), TM(TM), LowerGEP(LowerGEP) { initializeSeparateConstOffsetFromGEPPass(*PassRegistry::getPassRegistry()); } @@ -336,12 +366,14 @@ public: DL = &M.getDataLayout(); return false; } + bool runOnFunction(Function &F) override; private: /// Tries to split the given GEP into a variadic base and a constant offset, /// and returns true if the splitting succeeds. bool splitGEP(GetElementPtrInst *GEP); + /// Lower a GEP with multiple indices into multiple GEPs with a single index. /// Function splitGEP already split the original GEP into a variadic part and /// a constant offset (i.e., AccumulativeByteOffset). This function lowers the @@ -351,6 +383,7 @@ private: /// \p AccumulativeByteOffset The constant offset. void lowerToSingleIndexGEPs(GetElementPtrInst *Variadic, int64_t AccumulativeByteOffset); + /// Lower a GEP with multiple indices into ptrtoint+arithmetics+inttoptr form. 
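// Illustrative sketch of the ptrtoint+arithmetic+inttoptr lowering described
// in the surrounding comments, using std::uintptr_t in place of IR integers:
// the address of a multi-index GEP is the base address plus each index scaled
// by its element size, plus the accumulated constant byte offset.
#include <cassert>
#include <cstddef>
#include <cstdint>

struct Elem { std::int32_t A[4]; std::int64_t B; };

int main() {
  Elem Data[8] = {};
  std::uint64_t I = 3, J = 2;

  // Source form: &Data[I].A[J] (a GEP with a structure and an array index).
  std::int32_t *Direct = &Data[I].A[J];

  // Lowered form: ptrtoint, explicit scaling/offset arithmetic, inttoptr.
  std::uintptr_t Addr = reinterpret_cast<std::uintptr_t>(&Data[0]);
  Addr += I * sizeof(Elem);              // variadic index scaled by elem size
  Addr += offsetof(Elem, A);             // constant struct-field offset
  Addr += J * sizeof(std::int32_t);      // second variadic index
  std::int32_t *Lowered = reinterpret_cast<std::int32_t *>(Addr);

  assert(Direct == Lowered);
  return 0;
}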
/// Function splitGEP already split the original GEP into a variadic part and /// a constant offset (i.e., AccumulativeByteOffset). This function lowers the @@ -360,12 +393,14 @@ private: /// \p AccumulativeByteOffset The constant offset. void lowerToArithmetics(GetElementPtrInst *Variadic, int64_t AccumulativeByteOffset); + /// Finds the constant offset within each index and accumulates them. If /// LowerGEP is true, it finds in indices of both sequential and structure /// types, otherwise it only finds in sequential indices. The output /// NeedsExtraction indicates whether we successfully find a non-zero constant /// offset. int64_t accumulateByteOffset(GetElementPtrInst *GEP, bool &NeedsExtraction); + /// Canonicalize array indices to pointer-size integers. This helps to /// simplify the logic of splitting a GEP. For example, if a + b is a /// pointer-size integer, we have @@ -382,6 +417,7 @@ private: /// /// Verified in @i32_add in split-gep.ll bool canonicalizeArrayIndicesToPointerSize(GetElementPtrInst *GEP); + /// Optimize sext(a)+sext(b) to sext(a+b) when a+b can't sign overflow. /// SeparateConstOffsetFromGEP distributes a sext to leaves before extracting /// the constant offset. After extraction, it becomes desirable to reunion the @@ -392,8 +428,10 @@ private: /// => constant extraction &a[sext(i) + sext(j)] + 5 /// => reunion &a[sext(i +nsw j)] + 5 bool reuniteExts(Function &F); + /// A helper that reunites sexts in an instruction. bool reuniteExts(Instruction *I); + /// Find the closest dominator of <Dominatee> that is equivalent to <Key>. Instruction *findClosestMatchingDominator(const SCEV *Key, Instruction *Dominatee); @@ -401,27 +439,33 @@ private: void verifyNoDeadCode(Function &F); bool hasMoreThanOneUseInLoop(Value *v, Loop *L); + // Swap the index operand of two GEP. void swapGEPOperand(GetElementPtrInst *First, GetElementPtrInst *Second); + // Check if it is safe to swap operand of two GEP. bool isLegalToSwapOperand(GetElementPtrInst *First, GetElementPtrInst *Second, Loop *CurLoop); - const DataLayout *DL; - DominatorTree *DT; + const DataLayout *DL = nullptr; + DominatorTree *DT = nullptr; ScalarEvolution *SE; const TargetMachine *TM; LoopInfo *LI; TargetLibraryInfo *TLI; + /// Whether to lower a GEP with multiple indices into arithmetic operations or /// multiple GEPs with a single index. 
bool LowerGEP; + DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingExprs; }; -} // anonymous namespace + +} // end anonymous namespace char SeparateConstOffsetFromGEP::ID = 0; + INITIALIZE_PASS_BEGIN( SeparateConstOffsetFromGEP, "separate-const-offset-from-gep", "Split GEPs to a variadic base and a constant offset for better CSE", false, diff --git a/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp index aaab5857e0f1..3d0fca0bc3a5 100644 --- a/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp +++ b/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp @@ -17,6 +17,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/ADT/Twine.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CodeMetrics.h" #include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopPass.h" @@ -28,6 +29,7 @@ #include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Use.h" #include "llvm/IR/Value.h" #include "llvm/Pass.h" @@ -36,11 +38,15 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/GenericDomTree.h" #include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> #include <cassert> #include <iterator> +#include <numeric> #include <utility> #define DEBUG_TYPE "simple-loop-unswitch" @@ -51,6 +57,15 @@ STATISTIC(NumBranches, "Number of branches unswitched"); STATISTIC(NumSwitches, "Number of switches unswitched"); STATISTIC(NumTrivial, "Number of unswitches that are trivial"); +static cl::opt<bool> EnableNonTrivialUnswitch( + "enable-nontrivial-unswitch", cl::init(false), cl::Hidden, + cl::desc("Forcibly enables non-trivial loop unswitching rather than " + "following the configuration passed into the pass.")); + +static cl::opt<int> + UnswitchThreshold("unswitch-threshold", cl::init(50), cl::Hidden, + cl::desc("The cost threshold for unswitching a loop.")); + static void replaceLoopUsesWithConstant(Loop &L, Value &LIC, Constant &Replacement) { assert(!isa<Constant>(LIC) && "Why are we unswitching on a constant?"); @@ -68,24 +83,95 @@ static void replaceLoopUsesWithConstant(Loop &L, Value &LIC, } } -/// Update the dominator tree after removing one exiting predecessor of a loop -/// exit block. -static void updateLoopExitIDom(BasicBlock *LoopExitBB, Loop &L, - DominatorTree &DT) { - assert(pred_begin(LoopExitBB) != pred_end(LoopExitBB) && - "Cannot have empty predecessors of the loop exit block if we split " - "off a block to unswitch!"); +/// Update the IDom for a basic block whose predecessor set has changed. +/// +/// This routine is designed to work when the domtree update is relatively +/// localized by leveraging a known common dominator, often a loop header. +/// +/// FIXME: Should consider hand-rolling a slightly more efficient non-DFS +/// approach here as we can do that easily by persisting the candidate IDom's +/// dominating set between each predecessor. +/// +/// FIXME: Longer term, many uses of this can be replaced by an incremental +/// domtree update strategy that starts from a known dominating block and +/// rebuilds that subtree. 
+static bool updateIDomWithKnownCommonDominator(BasicBlock *BB, + BasicBlock *KnownDominatingBB, + DominatorTree &DT) { + assert(pred_begin(BB) != pred_end(BB) && + "This routine does not handle unreachable blocks!"); + + BasicBlock *OrigIDom = DT[BB]->getIDom()->getBlock(); + + BasicBlock *IDom = *pred_begin(BB); + assert(DT.dominates(KnownDominatingBB, IDom) && + "Bad known dominating block!"); - BasicBlock *IDom = *pred_begin(LoopExitBB); // Walk all of the other predecessors finding the nearest common dominator // until all predecessors are covered or we reach the loop header. The loop // header necessarily dominates all loop exit blocks in loop simplified form // so we can early-exit the moment we hit that block. - for (auto PI = std::next(pred_begin(LoopExitBB)), PE = pred_end(LoopExitBB); - PI != PE && IDom != L.getHeader(); ++PI) + for (auto PI = std::next(pred_begin(BB)), PE = pred_end(BB); + PI != PE && IDom != KnownDominatingBB; ++PI) { + assert(DT.dominates(KnownDominatingBB, *PI) && + "Bad known dominating block!"); IDom = DT.findNearestCommonDominator(IDom, *PI); + } + + if (IDom == OrigIDom) + return false; + + DT.changeImmediateDominator(BB, IDom); + return true; +} - DT.changeImmediateDominator(LoopExitBB, IDom); +// Note that we don't currently use the IDFCalculator here for two reasons: +// 1) It computes dominator tree levels for the entire function on each run +// of 'compute'. While this isn't terrible, given that we expect to update +// relatively small subtrees of the domtree, it isn't necessarily the right +// tradeoff. +// 2) The interface doesn't fit this usage well. It doesn't operate in +// append-only, and builds several sets that we don't need. +// +// FIXME: Neither of these issues are a big deal and could be addressed with +// some amount of refactoring of IDFCalculator. That would allow us to share +// the core logic here (which is solving the same core problem). +static void appendDomFrontier(DomTreeNode *Node, + SmallSetVector<BasicBlock *, 4> &Worklist, + SmallVectorImpl<DomTreeNode *> &DomNodes, + SmallPtrSetImpl<BasicBlock *> &DomSet) { + assert(DomNodes.empty() && "Must start with no dominator nodes."); + assert(DomSet.empty() && "Must start with an empty dominator set."); + + // First flatten this subtree into sequence of nodes by doing a pre-order + // walk. + DomNodes.push_back(Node); + // We intentionally re-evaluate the size as each node can add new children. + // Because this is a tree walk, this cannot add any duplicates. + for (int i = 0; i < (int)DomNodes.size(); ++i) + DomNodes.insert(DomNodes.end(), DomNodes[i]->begin(), DomNodes[i]->end()); + + // Now create a set of the basic blocks so we can quickly test for + // dominated successors. We could in theory use the DFS numbers of the + // dominator tree for this, but we want this to remain predictably fast + // even while we mutate the dominator tree in ways that would invalidate + // the DFS numbering. + for (DomTreeNode *InnerN : DomNodes) + DomSet.insert(InnerN->getBlock()); + + // Now re-walk the nodes, appending every successor of every node that isn't + // in the set. Note that we don't append the node itself, even though if it + // is a successor it does not strictly dominate itself and thus it would be + // part of the dominance frontier. The reason we don't append it is that + // the node passed in came *from* the worklist and so it has already been + // processed. 
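// Illustrative sketch (simple adjacency maps, no DomTreeNode) of the dominance
// frontier query the loop below performs: flatten the dominator subtree rooted
// at a node, then collect every CFG successor of a dominated block that is not
// itself dominated. Those are exactly the blocks whose idom may need to move
// after unswitching.
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <vector>

using Block = std::string;

static std::set<Block>
domFrontier(const Block &Root,
            const std::map<Block, std::vector<Block>> &DomChildren,
            const std::map<Block, std::vector<Block>> &Succs) {
  // Pre-order flatten of the dominator subtree rooted at Root.
  std::vector<Block> Nodes{Root};
  for (std::size_t I = 0; I < Nodes.size(); ++I) {
    auto It = DomChildren.find(Nodes[I]);
    if (It != DomChildren.end())
      Nodes.insert(Nodes.end(), It->second.begin(), It->second.end());
  }
  std::set<Block> DomSet(Nodes.begin(), Nodes.end());

  // Any successor outside the dominated set is in the frontier.
  std::set<Block> Frontier;
  for (const Block &B : Nodes) {
    auto It = Succs.find(B);
    if (It == Succs.end())
      continue;
    for (const Block &S : It->second)
      if (!DomSet.count(S))
        Frontier.insert(S);
  }
  return Frontier;
}

int main() {
  // A dominates B and C; B branches to C and to X, which A does not dominate.
  std::map<Block, std::vector<Block>> DomChildren{{"A", {"B", "C"}}};
  std::map<Block, std::vector<Block>> Succs{{"A", {"B"}}, {"B", {"C", "X"}}};
  assert(domFrontier("A", DomChildren, Succs) == (std::set<Block>{"X"}));
  return 0;
}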
+ for (DomTreeNode *InnerN : DomNodes) + for (BasicBlock *SuccBB : successors(InnerN->getBlock())) + if (!DomSet.count(SuccBB)) + Worklist.insert(SuccBB); + + DomNodes.clear(); + DomSet.clear(); } /// Update the dominator tree after unswitching a particular former exit block. @@ -127,58 +213,14 @@ static void updateDTAfterUnswitch(BasicBlock *UnswitchedBB, BasicBlock *OldPH, // dominator frontier to see if it additionally should move up the dominator // tree. This lambda appends the dominator frontier for a node on the // worklist. - // - // Note that we don't currently use the IDFCalculator here for two reasons: - // 1) It computes dominator tree levels for the entire function on each run - // of 'compute'. While this isn't terrible, given that we expect to update - // relatively small subtrees of the domtree, it isn't necessarily the right - // tradeoff. - // 2) The interface doesn't fit this usage well. It doesn't operate in - // append-only, and builds several sets that we don't need. - // - // FIXME: Neither of these issues are a big deal and could be addressed with - // some amount of refactoring of IDFCalculator. That would allow us to share - // the core logic here (which is solving the same core problem). SmallSetVector<BasicBlock *, 4> Worklist; + + // Scratch data structures reused by domfrontier finding. SmallVector<DomTreeNode *, 4> DomNodes; SmallPtrSet<BasicBlock *, 4> DomSet; - auto AppendDomFrontier = [&](DomTreeNode *Node) { - assert(DomNodes.empty() && "Must start with no dominator nodes."); - assert(DomSet.empty() && "Must start with an empty dominator set."); - - // First flatten this subtree into sequence of nodes by doing a pre-order - // walk. - DomNodes.push_back(Node); - // We intentionally re-evaluate the size as each node can add new children. - // Because this is a tree walk, this cannot add any duplicates. - for (int i = 0; i < (int)DomNodes.size(); ++i) - DomNodes.insert(DomNodes.end(), DomNodes[i]->begin(), DomNodes[i]->end()); - - // Now create a set of the basic blocks so we can quickly test for - // dominated successors. We could in theory use the DFS numbers of the - // dominator tree for this, but we want this to remain predictably fast - // even while we mutate the dominator tree in ways that would invalidate - // the DFS numbering. - for (DomTreeNode *InnerN : DomNodes) - DomSet.insert(InnerN->getBlock()); - - // Now re-walk the nodes, appending every successor of every node that isn't - // in the set. Note that we don't append the node itself, even though if it - // is a successor it does not strictly dominate itself and thus it would be - // part of the dominance frontier. The reason we don't append it is that - // the node passed in came *from* the worklist and so it has already been - // processed. - for (DomTreeNode *InnerN : DomNodes) - for (BasicBlock *SuccBB : successors(InnerN->getBlock())) - if (!DomSet.count(SuccBB)) - Worklist.insert(SuccBB); - - DomNodes.clear(); - DomSet.clear(); - }; // Append the initial dom frontier nodes. - AppendDomFrontier(UnswitchedNode); + appendDomFrontier(UnswitchedNode, Worklist, DomNodes, DomSet); // Walk the worklist. We grow the list in the loop and so must recompute size. for (int i = 0; i < (int)Worklist.size(); ++i) { @@ -197,7 +239,7 @@ static void updateDTAfterUnswitch(BasicBlock *UnswitchedBB, BasicBlock *OldPH, DT.changeImmediateDominator(Node, OldPHNode); // Now add this node's dominator frontier to the worklist as well. 
- AppendDomFrontier(Node); + appendDomFrontier(Node, Worklist, DomNodes, DomSet); } } @@ -395,7 +437,7 @@ static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, // one of the predecessors for the loop exit block and may need to update its // idom. if (UnswitchedBB != LoopExitBB) - updateLoopExitIDom(LoopExitBB, L, DT); + updateIDomWithKnownCommonDominator(LoopExitBB, L.getHeader(), DT); // Since this is an i1 condition we can also trivially replace uses of it // within the loop with a constant. @@ -540,7 +582,7 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, SplitBlock(DefaultExitBB, &DefaultExitBB->front(), &DT, &LI); rewritePHINodesForExitAndUnswitchedBlocks(*DefaultExitBB, *SplitBB, *ParentBB, *OldPH); - updateLoopExitIDom(DefaultExitBB, L, DT); + updateIDomWithKnownCommonDominator(DefaultExitBB, L.getHeader(), DT); DefaultExitBB = SplitExitBBMap[DefaultExitBB] = SplitBB; } } @@ -567,7 +609,7 @@ static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, SplitExitBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI); rewritePHINodesForExitAndUnswitchedBlocks(*ExitBB, *SplitExitBB, *ParentBB, *OldPH); - updateLoopExitIDom(ExitBB, L, DT); + updateIDomWithKnownCommonDominator(ExitBB, L.getHeader(), DT); } // Update the case pair to point to the split block. CasePair.second = SplitExitBB; @@ -708,15 +750,1172 @@ static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, return Changed; } +/// Build the cloned blocks for an unswitched copy of the given loop. +/// +/// The cloned blocks are inserted before the loop preheader (`LoopPH`) and +/// after the split block (`SplitBB`) that will be used to select between the +/// cloned and original loop. +/// +/// This routine handles cloning all of the necessary loop blocks and exit +/// blocks including rewriting their instructions and the relevant PHI nodes. +/// It skips loop and exit blocks that are not necessary based on the provided +/// set. It also correctly creates the unconditional branch in the cloned +/// unswitched parent block to only point at the unswitched successor. +/// +/// This does not handle most of the necessary updates to `LoopInfo`. Only exit +/// block splitting is correctly reflected in `LoopInfo`, essentially all of +/// the cloned blocks (and their loops) are left without full `LoopInfo` +/// updates. This also doesn't fully update `DominatorTree`. It adds the cloned +/// blocks to them but doesn't create the cloned `DominatorTree` structure and +/// instead the caller must recompute an accurate DT. It *does* correctly +/// update the `AssumptionCache` provided in `AC`. +static BasicBlock *buildClonedLoopBlocks( + Loop &L, BasicBlock *LoopPH, BasicBlock *SplitBB, + ArrayRef<BasicBlock *> ExitBlocks, BasicBlock *ParentBB, + BasicBlock *UnswitchedSuccBB, BasicBlock *ContinueSuccBB, + const SmallPtrSetImpl<BasicBlock *> &SkippedLoopAndExitBlocks, + ValueToValueMapTy &VMap, AssumptionCache &AC, DominatorTree &DT, + LoopInfo &LI) { + SmallVector<BasicBlock *, 4> NewBlocks; + NewBlocks.reserve(L.getNumBlocks() + ExitBlocks.size()); + + // We will need to clone a bunch of blocks, wrap up the clone operation in + // a helper. + auto CloneBlock = [&](BasicBlock *OldBB) { + // Clone the basic block and insert it before the new preheader. + BasicBlock *NewBB = CloneBasicBlock(OldBB, VMap, ".us", OldBB->getParent()); + NewBB->moveBefore(LoopPH); + + // Record this block and the mapping. 
+ NewBlocks.push_back(NewBB); + VMap[OldBB] = NewBB; + + // Add the block to the domtree. We'll move it to the correct position + // below. + DT.addNewBlock(NewBB, SplitBB); + + return NewBB; + }; + + // First, clone the preheader. + auto *ClonedPH = CloneBlock(LoopPH); + + // Then clone all the loop blocks, skipping the ones that aren't necessary. + for (auto *LoopBB : L.blocks()) + if (!SkippedLoopAndExitBlocks.count(LoopBB)) + CloneBlock(LoopBB); + + // Split all the loop exit edges so that when we clone the exit blocks, if + // any of the exit blocks are *also* a preheader for some other loop, we + // don't create multiple predecessors entering the loop header. + for (auto *ExitBB : ExitBlocks) { + if (SkippedLoopAndExitBlocks.count(ExitBB)) + continue; + + // When we are going to clone an exit, we don't need to clone all the + // instructions in the exit block and we want to ensure we have an easy + // place to merge the CFG, so split the exit first. This is always safe to + // do because there cannot be any non-loop predecessors of a loop exit in + // loop simplified form. + auto *MergeBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI); + + // Rearrange the names to make it easier to write test cases by having the + // exit block carry the suffix rather than the merge block carrying the + // suffix. + MergeBB->takeName(ExitBB); + ExitBB->setName(Twine(MergeBB->getName()) + ".split"); + + // Now clone the original exit block. + auto *ClonedExitBB = CloneBlock(ExitBB); + assert(ClonedExitBB->getTerminator()->getNumSuccessors() == 1 && + "Exit block should have been split to have one successor!"); + assert(ClonedExitBB->getTerminator()->getSuccessor(0) == MergeBB && + "Cloned exit block has the wrong successor!"); + + // Move the merge block's idom to be the split point as one exit is + // dominated by one header, and the other by another, so we know the split + // point dominates both. While the dominator tree isn't fully accurate, we + // want sub-trees within the original loop to be correctly reflect + // dominance within that original loop (at least) and that requires moving + // the merge block out of that subtree. + // FIXME: This is very brittle as we essentially have a partial contract on + // the dominator tree. We really need to instead update it and keep it + // valid or stop relying on it. + DT.changeImmediateDominator(MergeBB, SplitBB); + + // Remap any cloned instructions and create a merge phi node for them. + for (auto ZippedInsts : llvm::zip_first( + llvm::make_range(ExitBB->begin(), std::prev(ExitBB->end())), + llvm::make_range(ClonedExitBB->begin(), + std::prev(ClonedExitBB->end())))) { + Instruction &I = std::get<0>(ZippedInsts); + Instruction &ClonedI = std::get<1>(ZippedInsts); + + // The only instructions in the exit block should be PHI nodes and + // potentially a landing pad. + assert( + (isa<PHINode>(I) || isa<LandingPadInst>(I) || isa<CatchPadInst>(I)) && + "Bad instruction in exit block!"); + // We should have a value map between the instruction and its clone. + assert(VMap.lookup(&I) == &ClonedI && "Mismatch in the value map!"); + + auto *MergePN = + PHINode::Create(I.getType(), /*NumReservedValues*/ 2, ".us-phi", + &*MergeBB->getFirstInsertionPt()); + I.replaceAllUsesWith(MergePN); + MergePN->addIncoming(&I, ExitBB); + MergePN->addIncoming(&ClonedI, ClonedExitBB); + } + } + + // Rewrite the instructions in the cloned blocks to refer to the instructions + // in the cloned blocks. We have to do this as a second pass so that we have + // everything available. 
Also, we have inserted new instructions which may + // include assume intrinsics, so we update the assumption cache while + // processing this. + for (auto *ClonedBB : NewBlocks) + for (Instruction &I : *ClonedBB) { + RemapInstruction(&I, VMap, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + if (II->getIntrinsicID() == Intrinsic::assume) + AC.registerAssumption(II); + } + + // Remove the cloned parent as a predecessor of the cloned continue successor + // if we did in fact clone it. + auto *ClonedParentBB = cast<BasicBlock>(VMap.lookup(ParentBB)); + if (auto *ClonedContinueSuccBB = + cast_or_null<BasicBlock>(VMap.lookup(ContinueSuccBB))) + ClonedContinueSuccBB->removePredecessor(ClonedParentBB, + /*DontDeleteUselessPHIs*/ true); + // Replace the cloned branch with an unconditional branch to the cloneed + // unswitched successor. + auto *ClonedSuccBB = cast<BasicBlock>(VMap.lookup(UnswitchedSuccBB)); + ClonedParentBB->getTerminator()->eraseFromParent(); + BranchInst::Create(ClonedSuccBB, ClonedParentBB); + + // Update any PHI nodes in the cloned successors of the skipped blocks to not + // have spurious incoming values. + for (auto *LoopBB : L.blocks()) + if (SkippedLoopAndExitBlocks.count(LoopBB)) + for (auto *SuccBB : successors(LoopBB)) + if (auto *ClonedSuccBB = cast_or_null<BasicBlock>(VMap.lookup(SuccBB))) + for (PHINode &PN : ClonedSuccBB->phis()) + PN.removeIncomingValue(LoopBB, /*DeletePHIIfEmpty*/ false); + + return ClonedPH; +} + +/// Recursively clone the specified loop and all of its children. +/// +/// The target parent loop for the clone should be provided, or can be null if +/// the clone is a top-level loop. While cloning, all the blocks are mapped +/// with the provided value map. The entire original loop must be present in +/// the value map. The cloned loop is returned. +static Loop *cloneLoopNest(Loop &OrigRootL, Loop *RootParentL, + const ValueToValueMapTy &VMap, LoopInfo &LI) { + auto AddClonedBlocksToLoop = [&](Loop &OrigL, Loop &ClonedL) { + assert(ClonedL.getBlocks().empty() && "Must start with an empty loop!"); + ClonedL.reserveBlocks(OrigL.getNumBlocks()); + for (auto *BB : OrigL.blocks()) { + auto *ClonedBB = cast<BasicBlock>(VMap.lookup(BB)); + ClonedL.addBlockEntry(ClonedBB); + if (LI.getLoopFor(BB) == &OrigL) { + assert(!LI.getLoopFor(ClonedBB) && + "Should not have an existing loop for this block!"); + LI.changeLoopFor(ClonedBB, &ClonedL); + } + } + }; + + // We specially handle the first loop because it may get cloned into + // a different parent and because we most commonly are cloning leaf loops. + Loop *ClonedRootL = LI.AllocateLoop(); + if (RootParentL) + RootParentL->addChildLoop(ClonedRootL); + else + LI.addTopLevelLoop(ClonedRootL); + AddClonedBlocksToLoop(OrigRootL, *ClonedRootL); + + if (OrigRootL.empty()) + return ClonedRootL; + + // If we have a nest, we can quickly clone the entire loop nest using an + // iterative approach because it is a tree. We keep the cloned parent in the + // data structure to avoid repeatedly querying through a map to find it. + SmallVector<std::pair<Loop *, Loop *>, 16> LoopsToClone; + // Build up the loops to clone in reverse order as we'll clone them from the + // back. 
+ for (Loop *ChildL : llvm::reverse(OrigRootL)) + LoopsToClone.push_back({ClonedRootL, ChildL}); + do { + Loop *ClonedParentL, *L; + std::tie(ClonedParentL, L) = LoopsToClone.pop_back_val(); + Loop *ClonedL = LI.AllocateLoop(); + ClonedParentL->addChildLoop(ClonedL); + AddClonedBlocksToLoop(*L, *ClonedL); + for (Loop *ChildL : llvm::reverse(*L)) + LoopsToClone.push_back({ClonedL, ChildL}); + } while (!LoopsToClone.empty()); + + return ClonedRootL; +} + +/// Build the cloned loops of an original loop from unswitching. +/// +/// Because unswitching simplifies the CFG of the loop, this isn't a trivial +/// operation. We need to re-verify that there even is a loop (as the backedge +/// may not have been cloned), and even if there are remaining backedges the +/// backedge set may be different. However, we know that each child loop is +/// undisturbed, we only need to find where to place each child loop within +/// either any parent loop or within a cloned version of the original loop. +/// +/// Because child loops may end up cloned outside of any cloned version of the +/// original loop, multiple cloned sibling loops may be created. All of them +/// are returned so that the newly introduced loop nest roots can be +/// identified. +static Loop *buildClonedLoops(Loop &OrigL, ArrayRef<BasicBlock *> ExitBlocks, + const ValueToValueMapTy &VMap, LoopInfo &LI, + SmallVectorImpl<Loop *> &NonChildClonedLoops) { + Loop *ClonedL = nullptr; + + auto *OrigPH = OrigL.getLoopPreheader(); + auto *OrigHeader = OrigL.getHeader(); + + auto *ClonedPH = cast<BasicBlock>(VMap.lookup(OrigPH)); + auto *ClonedHeader = cast<BasicBlock>(VMap.lookup(OrigHeader)); + + // We need to know the loops of the cloned exit blocks to even compute the + // accurate parent loop. If we only clone exits to some parent of the + // original parent, we want to clone into that outer loop. We also keep track + // of the loops that our cloned exit blocks participate in. + Loop *ParentL = nullptr; + SmallVector<BasicBlock *, 4> ClonedExitsInLoops; + SmallDenseMap<BasicBlock *, Loop *, 16> ExitLoopMap; + ClonedExitsInLoops.reserve(ExitBlocks.size()); + for (auto *ExitBB : ExitBlocks) + if (auto *ClonedExitBB = cast_or_null<BasicBlock>(VMap.lookup(ExitBB))) + if (Loop *ExitL = LI.getLoopFor(ExitBB)) { + ExitLoopMap[ClonedExitBB] = ExitL; + ClonedExitsInLoops.push_back(ClonedExitBB); + if (!ParentL || (ParentL != ExitL && ParentL->contains(ExitL))) + ParentL = ExitL; + } + assert((!ParentL || ParentL == OrigL.getParentLoop() || + ParentL->contains(OrigL.getParentLoop())) && + "The computed parent loop should always contain (or be) the parent of " + "the original loop."); + + // We build the set of blocks dominated by the cloned header from the set of + // cloned blocks out of the original loop. While not all of these will + // necessarily be in the cloned loop, it is enough to establish that they + // aren't in unreachable cycles, etc. + SmallSetVector<BasicBlock *, 16> ClonedLoopBlocks; + for (auto *BB : OrigL.blocks()) + if (auto *ClonedBB = cast_or_null<BasicBlock>(VMap.lookup(BB))) + ClonedLoopBlocks.insert(ClonedBB); + + // Rebuild the set of blocks that will end up in the cloned loop. We may have + // skipped cloning some region of this loop which can in turn skip some of + // the backedges so we have to rebuild the blocks in the loop based on the + // backedges that remain after cloning. 
+ SmallVector<BasicBlock *, 16> Worklist; + SmallPtrSet<BasicBlock *, 16> BlocksInClonedLoop; + for (auto *Pred : predecessors(ClonedHeader)) { + // The only possible non-loop header predecessor is the preheader because + // we know we cloned the loop in simplified form. + if (Pred == ClonedPH) + continue; + + // Because the loop was in simplified form, the only non-loop predecessor + // should be the preheader. + assert(ClonedLoopBlocks.count(Pred) && "Found a predecessor of the loop " + "header other than the preheader " + "that is not part of the loop!"); + + // Insert this block into the loop set and on the first visit (and if it + // isn't the header we're currently walking) put it into the worklist to + // recurse through. + if (BlocksInClonedLoop.insert(Pred).second && Pred != ClonedHeader) + Worklist.push_back(Pred); + } + + // If we had any backedges then there *is* a cloned loop. Put the header into + // the loop set and then walk the worklist backwards to find all the blocks + // that remain within the loop after cloning. + if (!BlocksInClonedLoop.empty()) { + BlocksInClonedLoop.insert(ClonedHeader); + + while (!Worklist.empty()) { + BasicBlock *BB = Worklist.pop_back_val(); + assert(BlocksInClonedLoop.count(BB) && + "Didn't put block into the loop set!"); + + // Insert any predecessors that are in the possible set into the cloned + // set, and if the insert is successful, add them to the worklist. Note + // that we filter on the blocks that are definitely reachable via the + // backedge to the loop header so we may prune out dead code within the + // cloned loop. + for (auto *Pred : predecessors(BB)) + if (ClonedLoopBlocks.count(Pred) && + BlocksInClonedLoop.insert(Pred).second) + Worklist.push_back(Pred); + } + + ClonedL = LI.AllocateLoop(); + if (ParentL) { + ParentL->addBasicBlockToLoop(ClonedPH, LI); + ParentL->addChildLoop(ClonedL); + } else { + LI.addTopLevelLoop(ClonedL); + } + + ClonedL->reserveBlocks(BlocksInClonedLoop.size()); + // We don't want to just add the cloned loop blocks based on how we + // discovered them. The original order of blocks was carefully built in + // a way that doesn't rely on predecessor ordering. Rather than re-invent + // that logic, we just re-walk the original blocks (and those of the child + // loops) and filter them as we add them into the cloned loop. + for (auto *BB : OrigL.blocks()) { + auto *ClonedBB = cast_or_null<BasicBlock>(VMap.lookup(BB)); + if (!ClonedBB || !BlocksInClonedLoop.count(ClonedBB)) + continue; + + // Directly add the blocks that are only in this loop. + if (LI.getLoopFor(BB) == &OrigL) { + ClonedL->addBasicBlockToLoop(ClonedBB, LI); + continue; + } + + // We want to manually add it to this loop and parents. + // Registering it with LoopInfo will happen when we clone the top + // loop for this block. + for (Loop *PL = ClonedL; PL; PL = PL->getParentLoop()) + PL->addBlockEntry(ClonedBB); + } + + // Now add each child loop whose header remains within the cloned loop. All + // of the blocks within the loop must satisfy the same constraints as the + // header so once we pass the header checks we can just clone the entire + // child loop nest. + for (Loop *ChildL : OrigL) { + auto *ClonedChildHeader = + cast_or_null<BasicBlock>(VMap.lookup(ChildL->getHeader())); + if (!ClonedChildHeader || !BlocksInClonedLoop.count(ClonedChildHeader)) + continue; + +#ifndef NDEBUG + // We should never have a cloned child loop header but fail to have + // all of the blocks for that child loop. 
+ for (auto *ChildLoopBB : ChildL->blocks()) + assert(BlocksInClonedLoop.count( + cast<BasicBlock>(VMap.lookup(ChildLoopBB))) && + "Child cloned loop has a header within the cloned outer " + "loop but not all of its blocks!"); +#endif + + cloneLoopNest(*ChildL, ClonedL, VMap, LI); + } + } + + // Now that we've handled all the components of the original loop that were + // cloned into a new loop, we still need to handle anything from the original + // loop that wasn't in a cloned loop. + + // Figure out what blocks are left to place within any loop nest containing + // the unswitched loop. If we never formed a loop, the cloned PH is one of + // them. + SmallPtrSet<BasicBlock *, 16> UnloopedBlockSet; + if (BlocksInClonedLoop.empty()) + UnloopedBlockSet.insert(ClonedPH); + for (auto *ClonedBB : ClonedLoopBlocks) + if (!BlocksInClonedLoop.count(ClonedBB)) + UnloopedBlockSet.insert(ClonedBB); + + // Copy the cloned exits and sort them in ascending loop depth, we'll work + // backwards across these to process them inside out. The order shouldn't + // matter as we're just trying to build up the map from inside-out; we use + // the map in a more stably ordered way below. + auto OrderedClonedExitsInLoops = ClonedExitsInLoops; + std::sort(OrderedClonedExitsInLoops.begin(), OrderedClonedExitsInLoops.end(), + [&](BasicBlock *LHS, BasicBlock *RHS) { + return ExitLoopMap.lookup(LHS)->getLoopDepth() < + ExitLoopMap.lookup(RHS)->getLoopDepth(); + }); + + // Populate the existing ExitLoopMap with everything reachable from each + // exit, starting from the inner most exit. + while (!UnloopedBlockSet.empty() && !OrderedClonedExitsInLoops.empty()) { + assert(Worklist.empty() && "Didn't clear worklist!"); + + BasicBlock *ExitBB = OrderedClonedExitsInLoops.pop_back_val(); + Loop *ExitL = ExitLoopMap.lookup(ExitBB); + + // Walk the CFG back until we hit the cloned PH adding everything reachable + // and in the unlooped set to this exit block's loop. + Worklist.push_back(ExitBB); + do { + BasicBlock *BB = Worklist.pop_back_val(); + // We can stop recursing at the cloned preheader (if we get there). + if (BB == ClonedPH) + continue; + + for (BasicBlock *PredBB : predecessors(BB)) { + // If this pred has already been moved to our set or is part of some + // (inner) loop, no update needed. + if (!UnloopedBlockSet.erase(PredBB)) { + assert( + (BlocksInClonedLoop.count(PredBB) || ExitLoopMap.count(PredBB)) && + "Predecessor not mapped to a loop!"); + continue; + } + + // We just insert into the loop set here. We'll add these blocks to the + // exit loop after we build up the set in an order that doesn't rely on + // predecessor order (which in turn relies on use list order). + bool Inserted = ExitLoopMap.insert({PredBB, ExitL}).second; + (void)Inserted; + assert(Inserted && "Should only visit an unlooped block once!"); + + // And recurse through to its predecessors. + Worklist.push_back(PredBB); + } + } while (!Worklist.empty()); + } + + // Now that the ExitLoopMap gives as mapping for all the non-looping cloned + // blocks to their outer loops, walk the cloned blocks and the cloned exits + // in their original order adding them to the correct loop. + + // We need a stable insertion order. We use the order of the original loop + // order and map into the correct parent loop. 
+ for (auto *BB : llvm::concat<BasicBlock *const>( + makeArrayRef(ClonedPH), ClonedLoopBlocks, ClonedExitsInLoops)) + if (Loop *OuterL = ExitLoopMap.lookup(BB)) + OuterL->addBasicBlockToLoop(BB, LI); + +#ifndef NDEBUG + for (auto &BBAndL : ExitLoopMap) { + auto *BB = BBAndL.first; + auto *OuterL = BBAndL.second; + assert(LI.getLoopFor(BB) == OuterL && + "Failed to put all blocks into outer loops!"); + } +#endif + + // Now that all the blocks are placed into the correct containing loop in the + // absence of child loops, find all the potentially cloned child loops and + // clone them into whatever outer loop we placed their header into. + for (Loop *ChildL : OrigL) { + auto *ClonedChildHeader = + cast_or_null<BasicBlock>(VMap.lookup(ChildL->getHeader())); + if (!ClonedChildHeader || BlocksInClonedLoop.count(ClonedChildHeader)) + continue; + +#ifndef NDEBUG + for (auto *ChildLoopBB : ChildL->blocks()) + assert(VMap.count(ChildLoopBB) && + "Cloned a child loop header but not all of that loops blocks!"); +#endif + + NonChildClonedLoops.push_back(cloneLoopNest( + *ChildL, ExitLoopMap.lookup(ClonedChildHeader), VMap, LI)); + } + + // Return the main cloned loop if any. + return ClonedL; +} + +static void deleteDeadBlocksFromLoop(Loop &L, BasicBlock *DeadSubtreeRoot, + SmallVectorImpl<BasicBlock *> &ExitBlocks, + DominatorTree &DT, LoopInfo &LI) { + // Walk the dominator tree to build up the set of blocks we will delete here. + // The order is designed to allow us to always delete bottom-up and avoid any + // dangling uses. + SmallSetVector<BasicBlock *, 16> DeadBlocks; + DeadBlocks.insert(DeadSubtreeRoot); + for (int i = 0; i < (int)DeadBlocks.size(); ++i) + for (DomTreeNode *ChildN : *DT[DeadBlocks[i]]) { + // FIXME: This assert should pass and that means we don't change nearly + // as much below! Consider rewriting all of this to avoid deleting + // blocks. They are always cloned before being deleted, and so instead + // could just be moved. + // FIXME: This in turn means that we might actually be more able to + // update the domtree. + assert((L.contains(ChildN->getBlock()) || + llvm::find(ExitBlocks, ChildN->getBlock()) != ExitBlocks.end()) && + "Should never reach beyond the loop and exits when deleting!"); + DeadBlocks.insert(ChildN->getBlock()); + } + + // Filter out the dead blocks from the exit blocks list so that it can be + // used in the caller. + llvm::erase_if(ExitBlocks, + [&](BasicBlock *BB) { return DeadBlocks.count(BB); }); + + // Remove these blocks from their successors. + for (auto *BB : DeadBlocks) + for (BasicBlock *SuccBB : successors(BB)) + SuccBB->removePredecessor(BB, /*DontDeleteUselessPHIs*/ true); + + // Walk from this loop up through its parents removing all of the dead blocks. + for (Loop *ParentL = &L; ParentL; ParentL = ParentL->getParentLoop()) { + for (auto *BB : DeadBlocks) + ParentL->getBlocksSet().erase(BB); + llvm::erase_if(ParentL->getBlocksVector(), + [&](BasicBlock *BB) { return DeadBlocks.count(BB); }); + } + + // Now delete the dead child loops. This raw delete will clear them + // recursively. + llvm::erase_if(L.getSubLoopsVector(), [&](Loop *ChildL) { + if (!DeadBlocks.count(ChildL->getHeader())) + return false; + + assert(llvm::all_of(ChildL->blocks(), + [&](BasicBlock *ChildBB) { + return DeadBlocks.count(ChildBB); + }) && + "If the child loop header is dead all blocks in the child loop must " + "be dead as well!"); + LI.destroy(ChildL); + return true; + }); + + // Remove the mappings for the dead blocks. 
+ for (auto *BB : DeadBlocks) + LI.changeLoopFor(BB, nullptr); + + // Drop all the references from these blocks to others to handle cyclic + // references as we start deleting the blocks themselves. + for (auto *BB : DeadBlocks) + BB->dropAllReferences(); + + for (auto *BB : llvm::reverse(DeadBlocks)) { + DT.eraseNode(BB); + BB->eraseFromParent(); + } +} + +/// Recompute the set of blocks in a loop after unswitching. +/// +/// This walks from the original headers predecessors to rebuild the loop. We +/// take advantage of the fact that new blocks can't have been added, and so we +/// filter by the original loop's blocks. This also handles potentially +/// unreachable code that we don't want to explore but might be found examining +/// the predecessors of the header. +/// +/// If the original loop is no longer a loop, this will return an empty set. If +/// it remains a loop, all the blocks within it will be added to the set +/// (including those blocks in inner loops). +static SmallPtrSet<const BasicBlock *, 16> recomputeLoopBlockSet(Loop &L, + LoopInfo &LI) { + SmallPtrSet<const BasicBlock *, 16> LoopBlockSet; + + auto *PH = L.getLoopPreheader(); + auto *Header = L.getHeader(); + + // A worklist to use while walking backwards from the header. + SmallVector<BasicBlock *, 16> Worklist; + + // First walk the predecessors of the header to find the backedges. This will + // form the basis of our walk. + for (auto *Pred : predecessors(Header)) { + // Skip the preheader. + if (Pred == PH) + continue; + + // Because the loop was in simplified form, the only non-loop predecessor + // is the preheader. + assert(L.contains(Pred) && "Found a predecessor of the loop header other " + "than the preheader that is not part of the " + "loop!"); + + // Insert this block into the loop set and on the first visit and, if it + // isn't the header we're currently walking, put it into the worklist to + // recurse through. + if (LoopBlockSet.insert(Pred).second && Pred != Header) + Worklist.push_back(Pred); + } + + // If no backedges were found, we're done. + if (LoopBlockSet.empty()) + return LoopBlockSet; + + // Add the loop header to the set. + LoopBlockSet.insert(Header); + + // We found backedges, recurse through them to identify the loop blocks. + while (!Worklist.empty()) { + BasicBlock *BB = Worklist.pop_back_val(); + assert(LoopBlockSet.count(BB) && "Didn't put block into the loop set!"); + + // Because we know the inner loop structure remains valid we can use the + // loop structure to jump immediately across the entire nested loop. + // Further, because it is in loop simplified form, we can directly jump + // to its preheader afterward. + if (Loop *InnerL = LI.getLoopFor(BB)) + if (InnerL != &L) { + assert(L.contains(InnerL) && + "Should not reach a loop *outside* this loop!"); + // The preheader is the only possible predecessor of the loop so + // insert it into the set and check whether it was already handled. + auto *InnerPH = InnerL->getLoopPreheader(); + assert(L.contains(InnerPH) && "Cannot contain an inner loop block " + "but not contain the inner loop " + "preheader!"); + if (!LoopBlockSet.insert(InnerPH).second) + // The only way to reach the preheader is through the loop body + // itself so if it has been visited the loop is already handled. + continue; + + // Insert all of the blocks (other than those already present) into + // the loop set. 
The only block we expect to already be in the set is + // the one we used to find this loop as we immediately handle the + // others the first time we encounter the loop. + for (auto *InnerBB : InnerL->blocks()) { + if (InnerBB == BB) { + assert(LoopBlockSet.count(InnerBB) && + "Block should already be in the set!"); + continue; + } + + bool Inserted = LoopBlockSet.insert(InnerBB).second; + (void)Inserted; + assert(Inserted && "Should only insert an inner loop once!"); + } + + // Add the preheader to the worklist so we will continue past the + // loop body. + Worklist.push_back(InnerPH); + continue; + } + + // Insert any predecessors that were in the original loop into the new + // set, and if the insert is successful, add them to the worklist. + for (auto *Pred : predecessors(BB)) + if (L.contains(Pred) && LoopBlockSet.insert(Pred).second) + Worklist.push_back(Pred); + } + + // We've found all the blocks participating in the loop, return our completed + // set. + return LoopBlockSet; +} + +/// Rebuild a loop after unswitching removes some subset of blocks and edges. +/// +/// The removal may have removed some child loops entirely but cannot have +/// disturbed any remaining child loops. However, they may need to be hoisted +/// to the parent loop (or to be top-level loops). The original loop may be +/// completely removed. +/// +/// The sibling loops resulting from this update are returned. If the original +/// loop remains a valid loop, it will be the first entry in this list with all +/// of the newly sibling loops following it. +/// +/// Returns true if the loop remains a loop after unswitching, and false if it +/// is no longer a loop after unswitching (and should not continue to be +/// referenced). +static bool rebuildLoopAfterUnswitch(Loop &L, ArrayRef<BasicBlock *> ExitBlocks, + LoopInfo &LI, + SmallVectorImpl<Loop *> &HoistedLoops) { + auto *PH = L.getLoopPreheader(); + + // Compute the actual parent loop from the exit blocks. Because we may have + // pruned some exits the loop may be different from the original parent. + Loop *ParentL = nullptr; + SmallVector<Loop *, 4> ExitLoops; + SmallVector<BasicBlock *, 4> ExitsInLoops; + ExitsInLoops.reserve(ExitBlocks.size()); + for (auto *ExitBB : ExitBlocks) + if (Loop *ExitL = LI.getLoopFor(ExitBB)) { + ExitLoops.push_back(ExitL); + ExitsInLoops.push_back(ExitBB); + if (!ParentL || (ParentL != ExitL && ParentL->contains(ExitL))) + ParentL = ExitL; + } + + // Recompute the blocks participating in this loop. This may be empty if it + // is no longer a loop. + auto LoopBlockSet = recomputeLoopBlockSet(L, LI); + + // If we still have a loop, we need to re-set the loop's parent as the exit + // block set changing may have moved it within the loop nest. Note that this + // can only happen when this loop has a parent as it can only hoist the loop + // *up* the nest. + if (!LoopBlockSet.empty() && L.getParentLoop() != ParentL) { + // Remove this loop's (original) blocks from all of the intervening loops. + for (Loop *IL = L.getParentLoop(); IL != ParentL; + IL = IL->getParentLoop()) { + IL->getBlocksSet().erase(PH); + for (auto *BB : L.blocks()) + IL->getBlocksSet().erase(BB); + llvm::erase_if(IL->getBlocksVector(), [&](BasicBlock *BB) { + return BB == PH || L.contains(BB); + }); + } + + LI.changeLoopFor(PH, ParentL); + L.getParentLoop()->removeChildLoop(&L); + if (ParentL) + ParentL->addChildLoop(&L); + else + LI.addTopLevelLoop(&L); + } + + // Now we update all the blocks which are no longer within the loop. 
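// Illustrative sketch (not from the patch): the stable_partition split used
// a few lines below to separate blocks that stay in the loop from those that
// fell out, without disturbing their relative order. Block ids are ints here.
#include <algorithm>
#include <set>
#include <vector>

// Returns the blocks that left the loop; Blocks keeps only the members.
std::vector<int> splitOffUnloopedBlocks(std::vector<int> &Blocks,
                                        const std::set<int> &LoopBlockSet) {
  auto SplitIt = std::stable_partition(
      Blocks.begin(), Blocks.end(),
      [&](int BB) { return LoopBlockSet.count(BB) != 0; });
  std::vector<int> Unlooped(SplitIt, Blocks.end());
  Blocks.erase(SplitIt, Blocks.end());
  return Unlooped;
}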
+ auto &Blocks = L.getBlocksVector(); + auto BlocksSplitI = + LoopBlockSet.empty() + ? Blocks.begin() + : std::stable_partition( + Blocks.begin(), Blocks.end(), + [&](BasicBlock *BB) { return LoopBlockSet.count(BB); }); + + // Before we erase the list of unlooped blocks, build a set of them. + SmallPtrSet<BasicBlock *, 16> UnloopedBlocks(BlocksSplitI, Blocks.end()); + if (LoopBlockSet.empty()) + UnloopedBlocks.insert(PH); + + // Now erase these blocks from the loop. + for (auto *BB : make_range(BlocksSplitI, Blocks.end())) + L.getBlocksSet().erase(BB); + Blocks.erase(BlocksSplitI, Blocks.end()); + + // Sort the exits in ascending loop depth, we'll work backwards across these + // to process them inside out. + std::stable_sort(ExitsInLoops.begin(), ExitsInLoops.end(), + [&](BasicBlock *LHS, BasicBlock *RHS) { + return LI.getLoopDepth(LHS) < LI.getLoopDepth(RHS); + }); + + // We'll build up a set for each exit loop. + SmallPtrSet<BasicBlock *, 16> NewExitLoopBlocks; + Loop *PrevExitL = L.getParentLoop(); // The deepest possible exit loop. + + auto RemoveUnloopedBlocksFromLoop = + [](Loop &L, SmallPtrSetImpl<BasicBlock *> &UnloopedBlocks) { + for (auto *BB : UnloopedBlocks) + L.getBlocksSet().erase(BB); + llvm::erase_if(L.getBlocksVector(), [&](BasicBlock *BB) { + return UnloopedBlocks.count(BB); + }); + }; + + SmallVector<BasicBlock *, 16> Worklist; + while (!UnloopedBlocks.empty() && !ExitsInLoops.empty()) { + assert(Worklist.empty() && "Didn't clear worklist!"); + assert(NewExitLoopBlocks.empty() && "Didn't clear loop set!"); + + // Grab the next exit block, in decreasing loop depth order. + BasicBlock *ExitBB = ExitsInLoops.pop_back_val(); + Loop &ExitL = *LI.getLoopFor(ExitBB); + assert(ExitL.contains(&L) && "Exit loop must contain the inner loop!"); + + // Erase all of the unlooped blocks from the loops between the previous + // exit loop and this exit loop. This works because the ExitInLoops list is + // sorted in increasing order of loop depth and thus we visit loops in + // decreasing order of loop depth. + for (; PrevExitL != &ExitL; PrevExitL = PrevExitL->getParentLoop()) + RemoveUnloopedBlocksFromLoop(*PrevExitL, UnloopedBlocks); + + // Walk the CFG back until we hit the cloned PH adding everything reachable + // and in the unlooped set to this exit block's loop. + Worklist.push_back(ExitBB); + do { + BasicBlock *BB = Worklist.pop_back_val(); + // We can stop recursing at the cloned preheader (if we get there). + if (BB == PH) + continue; + + for (BasicBlock *PredBB : predecessors(BB)) { + // If this pred has already been moved to our set or is part of some + // (inner) loop, no update needed. + if (!UnloopedBlocks.erase(PredBB)) { + assert((NewExitLoopBlocks.count(PredBB) || + ExitL.contains(LI.getLoopFor(PredBB))) && + "Predecessor not in a nested loop (or already visited)!"); + continue; + } + + // We just insert into the loop set here. We'll add these blocks to the + // exit loop after we build up the set in a deterministic order rather + // than the predecessor-influenced visit order. + bool Inserted = NewExitLoopBlocks.insert(PredBB).second; + (void)Inserted; + assert(Inserted && "Should only visit an unlooped block once!"); + + // And recurse through to its predecessors. + Worklist.push_back(PredBB); + } + } while (!Worklist.empty()); + + // If blocks in this exit loop were directly part of the original loop (as + // opposed to a child loop) update the map to point to this exit loop. This + // just updates a map and so the fact that the order is unstable is fine. 
+ for (auto *BB : NewExitLoopBlocks) + if (Loop *BBL = LI.getLoopFor(BB)) + if (BBL == &L || !L.contains(BBL)) + LI.changeLoopFor(BB, &ExitL); + + // We will remove the remaining unlooped blocks from this loop in the next + // iteration or below. + NewExitLoopBlocks.clear(); + } + + // Any remaining unlooped blocks are no longer part of any loop unless they + // are part of some child loop. + for (; PrevExitL; PrevExitL = PrevExitL->getParentLoop()) + RemoveUnloopedBlocksFromLoop(*PrevExitL, UnloopedBlocks); + for (auto *BB : UnloopedBlocks) + if (Loop *BBL = LI.getLoopFor(BB)) + if (BBL == &L || !L.contains(BBL)) + LI.changeLoopFor(BB, nullptr); + + // Sink all the child loops whose headers are no longer in the loop set to + // the parent (or to be top level loops). We reach into the loop and directly + // update its subloop vector to make this batch update efficient. + auto &SubLoops = L.getSubLoopsVector(); + auto SubLoopsSplitI = + LoopBlockSet.empty() + ? SubLoops.begin() + : std::stable_partition( + SubLoops.begin(), SubLoops.end(), [&](Loop *SubL) { + return LoopBlockSet.count(SubL->getHeader()); + }); + for (auto *HoistedL : make_range(SubLoopsSplitI, SubLoops.end())) { + HoistedLoops.push_back(HoistedL); + HoistedL->setParentLoop(nullptr); + + // To compute the new parent of this hoisted loop we look at where we + // placed the preheader above. We can't lookup the header itself because we + // retained the mapping from the header to the hoisted loop. But the + // preheader and header should have the exact same new parent computed + // based on the set of exit blocks from the original loop as the preheader + // is a predecessor of the header and so reached in the reverse walk. And + // because the loops were all in simplified form the preheader of the + // hoisted loop can't be part of some *other* loop. + if (auto *NewParentL = LI.getLoopFor(HoistedL->getLoopPreheader())) + NewParentL->addChildLoop(HoistedL); + else + LI.addTopLevelLoop(HoistedL); + } + SubLoops.erase(SubLoopsSplitI, SubLoops.end()); + + // Actually delete the loop if nothing remained within it. + if (Blocks.empty()) { + assert(SubLoops.empty() && + "Failed to remove all subloops from the original loop!"); + if (Loop *ParentL = L.getParentLoop()) + ParentL->removeChildLoop(llvm::find(*ParentL, &L)); + else + LI.removeLoop(llvm::find(LI, &L)); + LI.destroy(&L); + return false; + } + + return true; +} + +/// Helper to visit a dominator subtree, invoking a callable on each node. +/// +/// Returning false at any point will stop walking past that node of the tree. +template <typename CallableT> +void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) { + SmallVector<DomTreeNode *, 4> DomWorklist; + DomWorklist.push_back(DT[BB]); +#ifndef NDEBUG + SmallPtrSet<DomTreeNode *, 4> Visited; + Visited.insert(DT[BB]); +#endif + do { + DomTreeNode *N = DomWorklist.pop_back_val(); + + // Visit this node. + if (!Callable(N->getBlock())) + continue; + + // Accumulate the child nodes. + for (DomTreeNode *ChildN : *N) { + assert(Visited.insert(ChildN).second && + "Cannot visit a node twice when walking a tree!"); + DomWorklist.push_back(ChildN); + } + } while (!DomWorklist.empty()); +} + +/// Take an invariant branch that has been determined to be safe and worthwhile +/// to unswitch despite being non-trivial to do so and perform the unswitch. +/// +/// This directly updates the CFG to hoist the predicate out of the loop, and +/// clone the necessary parts of the loop to maintain behavior. 
+/// +/// It also updates both dominator tree and loopinfo based on the unswitching. +/// +/// Once unswitching has been performed it runs the provided callback to report +/// the new loops and no-longer valid loops to the caller. +static bool unswitchInvariantBranch( + Loop &L, BranchInst &BI, DominatorTree &DT, LoopInfo &LI, + AssumptionCache &AC, + function_ref<void(bool, ArrayRef<Loop *>)> NonTrivialUnswitchCB) { + assert(BI.isConditional() && "Can only unswitch a conditional branch!"); + assert(L.isLoopInvariant(BI.getCondition()) && + "Can only unswitch an invariant branch condition!"); + + // Constant and BBs tracking the cloned and continuing successor. + const int ClonedSucc = 0; + auto *ParentBB = BI.getParent(); + auto *UnswitchedSuccBB = BI.getSuccessor(ClonedSucc); + auto *ContinueSuccBB = BI.getSuccessor(1 - ClonedSucc); + + assert(UnswitchedSuccBB != ContinueSuccBB && + "Should not unswitch a branch that always goes to the same place!"); + + // The branch should be in this exact loop. Any inner loop's invariant branch + // should be handled by unswitching that inner loop. The caller of this + // routine should filter out any candidates that remain (but were skipped for + // whatever reason). + assert(LI.getLoopFor(ParentBB) == &L && "Branch in an inner loop!"); + + SmallVector<BasicBlock *, 4> ExitBlocks; + L.getUniqueExitBlocks(ExitBlocks); + + // We cannot unswitch if exit blocks contain a cleanuppad instruction as we + // don't know how to split those exit blocks. + // FIXME: We should teach SplitBlock to handle this and remove this + // restriction. + for (auto *ExitBB : ExitBlocks) + if (isa<CleanupPadInst>(ExitBB->getFirstNonPHI())) + return false; + + SmallPtrSet<BasicBlock *, 4> ExitBlockSet(ExitBlocks.begin(), + ExitBlocks.end()); + + // Compute the parent loop now before we start hacking on things. + Loop *ParentL = L.getParentLoop(); + + // Compute the outer-most loop containing one of our exit blocks. This is the + // furthest up our loopnest which can be mutated, which we will use below to + // update things. + Loop *OuterExitL = &L; + for (auto *ExitBB : ExitBlocks) { + Loop *NewOuterExitL = LI.getLoopFor(ExitBB); + if (!NewOuterExitL) { + // We exited the entire nest with this block, so we're done. + OuterExitL = nullptr; + break; + } + if (NewOuterExitL != OuterExitL && NewOuterExitL->contains(OuterExitL)) + OuterExitL = NewOuterExitL; + } + + // If the edge we *aren't* cloning in the unswitch (the continuing edge) + // dominates its target, we can skip cloning the dominated region of the loop + // and its exits. We compute this as a set of nodes to be skipped. + SmallPtrSet<BasicBlock *, 4> SkippedLoopAndExitBlocks; + if (ContinueSuccBB->getUniquePredecessor() || + llvm::all_of(predecessors(ContinueSuccBB), [&](BasicBlock *PredBB) { + return PredBB == ParentBB || DT.dominates(ContinueSuccBB, PredBB); + })) { + visitDomSubTree(DT, ContinueSuccBB, [&](BasicBlock *BB) { + SkippedLoopAndExitBlocks.insert(BB); + return true; + }); + } + // Similarly, if the edge we *are* cloning in the unswitch (the unswitched + // edge) dominates its target, we will end up with dead nodes in the original + // loop and its exits that will need to be deleted. Here, we just retain that + // the property holds and will compute the deleted set later. 
+ bool DeleteUnswitchedSucc = + UnswitchedSuccBB->getUniquePredecessor() || + llvm::all_of(predecessors(UnswitchedSuccBB), [&](BasicBlock *PredBB) { + return PredBB == ParentBB || DT.dominates(UnswitchedSuccBB, PredBB); + }); + + // Split the preheader, so that we know that there is a safe place to insert + // the conditional branch. We will change the preheader to have a conditional + // branch on LoopCond. The original preheader will become the split point + // between the unswitched versions, and we will have a new preheader for the + // original loop. + BasicBlock *SplitBB = L.getLoopPreheader(); + BasicBlock *LoopPH = SplitEdge(SplitBB, L.getHeader(), &DT, &LI); + + // Keep a mapping for the cloned values. + ValueToValueMapTy VMap; + + // Build the cloned blocks from the loop. + auto *ClonedPH = buildClonedLoopBlocks( + L, LoopPH, SplitBB, ExitBlocks, ParentBB, UnswitchedSuccBB, + ContinueSuccBB, SkippedLoopAndExitBlocks, VMap, AC, DT, LI); + + // Build the cloned loop structure itself. This may be substantially + // different from the original structure due to the simplified CFG. This also + // handles inserting all the cloned blocks into the correct loops. + SmallVector<Loop *, 4> NonChildClonedLoops; + Loop *ClonedL = + buildClonedLoops(L, ExitBlocks, VMap, LI, NonChildClonedLoops); + + // Remove the parent as a predecessor of the unswitched successor. + UnswitchedSuccBB->removePredecessor(ParentBB, /*DontDeleteUselessPHIs*/ true); + + // Now splice the branch from the original loop and use it to select between + // the two loops. + SplitBB->getTerminator()->eraseFromParent(); + SplitBB->getInstList().splice(SplitBB->end(), ParentBB->getInstList(), BI); + BI.setSuccessor(ClonedSucc, ClonedPH); + BI.setSuccessor(1 - ClonedSucc, LoopPH); + + // Create a new unconditional branch to the continuing block (as opposed to + // the one cloned). + BranchInst::Create(ContinueSuccBB, ParentBB); + + // Delete anything that was made dead in the original loop due to + // unswitching. + if (DeleteUnswitchedSucc) + deleteDeadBlocksFromLoop(L, UnswitchedSuccBB, ExitBlocks, DT, LI); + + SmallVector<Loop *, 4> HoistedLoops; + bool IsStillLoop = rebuildLoopAfterUnswitch(L, ExitBlocks, LI, HoistedLoops); + + // This will have completely invalidated the dominator tree. We can't easily + // bound how much is invalid because in some cases we will refine the + // predecessor set of exit blocks of the loop which can move large unrelated + // regions of code into a new subtree. + // + // FIXME: Eventually, we should use an incremental update utility that + // leverages the existing information in the dominator tree (and potentially + // the nature of the change) to more efficiently update things. + DT.recalculate(*SplitBB->getParent()); + + // We can change which blocks are exit blocks of all the cloned sibling + // loops, the current loop, and any parent loops which shared exit blocks + // with the current loop. As a consequence, we need to re-form LCSSA for + // them. But we shouldn't need to re-form LCSSA for any child loops. + // FIXME: This could be made more efficient by tracking which exit blocks are + // new, and focusing on them, but that isn't likely to be necessary. + // + // In order to reasonably rebuild LCSSA we need to walk inside-out across the + // loop nest and update every loop that could have had its exits changed. We + // also need to cover any intervening loops. We add all of these loops to + // a list and sort them by loop depth to achieve this without updating + // unnecessary loops. 
+ auto UpdateLCSSA = [&](Loop &UpdateL) { +#ifndef NDEBUG + for (Loop *ChildL : UpdateL) + assert(ChildL->isRecursivelyLCSSAForm(DT, LI) && + "Perturbed a child loop's LCSSA form!"); +#endif + formLCSSA(UpdateL, DT, &LI, nullptr); + }; + + // For non-child cloned loops and hoisted loops, we just need to update LCSSA + // and we can do it in any order as they don't nest relative to each other. + for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops)) + UpdateLCSSA(*UpdatedL); + + // If the original loop had exit blocks, walk up through the outer most loop + // of those exit blocks to update LCSSA and form updated dedicated exits. + if (OuterExitL != &L) { + SmallVector<Loop *, 4> OuterLoops; + // We start with the cloned loop and the current loop if they are loops and + // move toward OuterExitL. Also, if either the cloned loop or the current + // loop have become top level loops we need to walk all the way out. + if (ClonedL) { + OuterLoops.push_back(ClonedL); + if (!ClonedL->getParentLoop()) + OuterExitL = nullptr; + } + if (IsStillLoop) { + OuterLoops.push_back(&L); + if (!L.getParentLoop()) + OuterExitL = nullptr; + } + // Grab all of the enclosing loops now. + for (Loop *OuterL = ParentL; OuterL != OuterExitL; + OuterL = OuterL->getParentLoop()) + OuterLoops.push_back(OuterL); + + // Finally, update our list of outer loops. This is nicely ordered to work + // inside-out. + for (Loop *OuterL : OuterLoops) { + // First build LCSSA for this loop so that we can preserve it when + // forming dedicated exits. We don't want to perturb some other loop's + // LCSSA while doing that CFG edit. + UpdateLCSSA(*OuterL); + + // For loops reached by this loop's original exit blocks we may + // introduced new, non-dedicated exits. At least try to re-form dedicated + // exits for these loops. This may fail if they couldn't have dedicated + // exits to start with. + formDedicatedExitBlocks(OuterL, &DT, &LI, /*PreserveLCSSA*/ true); + } + } + +#ifndef NDEBUG + // Verify the entire loop structure to catch any incorrect updates before we + // progress in the pass pipeline. + LI.verify(DT); +#endif + + // Now that we've unswitched something, make callbacks to report the changes. + // For that we need to merge together the updated loops and the cloned loops + // and check whether the original loop survived. + SmallVector<Loop *, 4> SibLoops; + for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops)) + if (UpdatedL->getParentLoop() == ParentL) + SibLoops.push_back(UpdatedL); + NonTrivialUnswitchCB(IsStillLoop, SibLoops); + + ++NumBranches; + return true; +} + +/// Recursively compute the cost of a dominator subtree based on the per-block +/// cost map provided. +/// +/// The recursive computation is memozied into the provided DT-indexed cost map +/// to allow querying it for most nodes in the domtree without it becoming +/// quadratic. +static int +computeDomSubtreeCost(DomTreeNode &N, + const SmallDenseMap<BasicBlock *, int, 4> &BBCostMap, + SmallDenseMap<DomTreeNode *, int, 4> &DTCostMap) { + // Don't accumulate cost (or recurse through) blocks not in our block cost + // map and thus not part of the duplication cost being considered. + auto BBCostIt = BBCostMap.find(N.getBlock()); + if (BBCostIt == BBCostMap.end()) + return 0; + + // Lookup this node to see if we already computed its cost. + auto DTCostIt = DTCostMap.find(&N); + if (DTCostIt != DTCostMap.end()) + return DTCostIt->second; + + // If not, we have to compute it. 
We can't use insert above and update + // because computing the cost may insert more things into the map. + int Cost = std::accumulate( + N.begin(), N.end(), BBCostIt->second, [&](int Sum, DomTreeNode *ChildN) { + return Sum + computeDomSubtreeCost(*ChildN, BBCostMap, DTCostMap); + }); + bool Inserted = DTCostMap.insert({&N, Cost}).second; + (void)Inserted; + assert(Inserted && "Should not insert a node while visiting children!"); + return Cost; +} + /// Unswitch control flow predicated on loop invariant conditions. /// /// This first hoists all branches or switches which are trivial (IE, do not /// require duplicating any part of the loop) out of the loop body. It then /// looks at other loop invariant control flows and tries to unswitch those as /// well by cloning the loop if the result is small enough. -static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, - AssumptionCache &AC) { - assert(L.isLCSSAForm(DT) && +static bool +unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC, + TargetTransformInfo &TTI, bool NonTrivial, + function_ref<void(bool, ArrayRef<Loop *>)> NonTrivialUnswitchCB) { + assert(L.isRecursivelyLCSSAForm(DT, LI) && "Loops must be in LCSSA form before unswitching."); bool Changed = false; @@ -727,7 +1926,136 @@ static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, // Try trivial unswitch first before loop over other basic blocks in the loop. Changed |= unswitchAllTrivialConditions(L, DT, LI); - // FIXME: Add support for non-trivial unswitching by cloning the loop. + // If we're not doing non-trivial unswitching, we're done. We both accept + // a parameter but also check a local flag that can be used for testing + // a debugging. + if (!NonTrivial && !EnableNonTrivialUnswitch) + return Changed; + + // Collect all remaining invariant branch conditions within this loop (as + // opposed to an inner loop which would be handled when visiting that inner + // loop). + SmallVector<TerminatorInst *, 4> UnswitchCandidates; + for (auto *BB : L.blocks()) + if (LI.getLoopFor(BB) == &L) + if (auto *BI = dyn_cast<BranchInst>(BB->getTerminator())) + if (BI->isConditional() && L.isLoopInvariant(BI->getCondition()) && + BI->getSuccessor(0) != BI->getSuccessor(1)) + UnswitchCandidates.push_back(BI); + + // If we didn't find any candidates, we're done. + if (UnswitchCandidates.empty()) + return Changed; + + DEBUG(dbgs() << "Considering " << UnswitchCandidates.size() + << " non-trivial loop invariant conditions for unswitching.\n"); + + // Given that unswitching these terminators will require duplicating parts of + // the loop, so we need to be able to model that cost. Compute the ephemeral + // values and set up a data structure to hold per-BB costs. We cache each + // block's cost so that we don't recompute this when considering different + // subsets of the loop for duplication during unswitching. + SmallPtrSet<const Value *, 4> EphValues; + CodeMetrics::collectEphemeralValues(&L, &AC, EphValues); + SmallDenseMap<BasicBlock *, int, 4> BBCostMap; + + // Compute the cost of each block, as well as the total loop cost. Also, bail + // out if we see instructions which are incompatible with loop unswitching + // (convergent, noduplicate, or cross-basic-block tokens). + // FIXME: We might be able to safely handle some of these in non-duplicated + // regions. 
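// Illustrative sketch (not from the patch): memoizing per-subtree costs so
// repeated queries stay cheap, the same shape as computeDomSubtreeCost above
// but on a toy tree type.
#include <map>
#include <vector>

struct CostNode { int Cost; std::vector<CostNode *> Children; };

int subtreeCost(const CostNode *N, std::map<const CostNode *, int> &Memo) {
  auto It = Memo.find(N);
  if (It != Memo.end())
    return It->second;
  int Sum = N->Cost;
  for (const CostNode *C : N->Children)
    Sum += subtreeCost(C, Memo);
  // Record the result only after recursing; the recursion grows the map.
  Memo[N] = Sum;
  return Sum;
}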
+ int LoopCost = 0; + for (auto *BB : L.blocks()) { + int Cost = 0; + for (auto &I : *BB) { + if (EphValues.count(&I)) + continue; + + if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB)) + return Changed; + if (auto CS = CallSite(&I)) + if (CS.isConvergent() || CS.cannotDuplicate()) + return Changed; + + Cost += TTI.getUserCost(&I); + } + assert(Cost >= 0 && "Must not have negative costs!"); + LoopCost += Cost; + assert(LoopCost >= 0 && "Must not have negative loop costs!"); + BBCostMap[BB] = Cost; + } + DEBUG(dbgs() << " Total loop cost: " << LoopCost << "\n"); + + // Now we find the best candidate by searching for the one with the following + // properties in order: + // + // 1) An unswitching cost below the threshold + // 2) The smallest number of duplicated unswitch candidates (to avoid + // creating redundant subsequent unswitching) + // 3) The smallest cost after unswitching. + // + // We prioritize reducing fanout of unswitch candidates provided the cost + // remains below the threshold because this has a multiplicative effect. + // + // This requires memoizing each dominator subtree to avoid redundant work. + // + // FIXME: Need to actually do the number of candidates part above. + SmallDenseMap<DomTreeNode *, int, 4> DTCostMap; + // Given a terminator which might be unswitched, computes the non-duplicated + // cost for that terminator. + auto ComputeUnswitchedCost = [&](TerminatorInst *TI) { + BasicBlock &BB = *TI->getParent(); + SmallPtrSet<BasicBlock *, 4> Visited; + + int Cost = LoopCost; + for (BasicBlock *SuccBB : successors(&BB)) { + // Don't count successors more than once. + if (!Visited.insert(SuccBB).second) + continue; + + // This successor's domtree will not need to be duplicated after + // unswitching if the edge to the successor dominates it (and thus the + // entire tree). This essentially means there is no other path into this + // subtree and so it will end up live in only one clone of the loop. + if (SuccBB->getUniquePredecessor() || + llvm::all_of(predecessors(SuccBB), [&](BasicBlock *PredBB) { + return PredBB == &BB || DT.dominates(SuccBB, PredBB); + })) { + Cost -= computeDomSubtreeCost(*DT[SuccBB], BBCostMap, DTCostMap); + assert(Cost >= 0 && + "Non-duplicated cost should never exceed total loop cost!"); + } + } + + // Now scale the cost by the number of unique successors minus one. We + // subtract one because there is already at least one copy of the entire + // loop. This is computing the new cost of unswitching a condition. 
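// Illustrative sketch (not from the patch): the cost model on concrete
// numbers. Suppose the loop costs 100 and the branch's "true" successor
// heads a region costing 30 that is reachable only through this edge: that
// region ends up in a single clone, so the duplicated part is 100 - 30 = 70,
// and with two distinct successors one extra copy is made, for a cost of 70.
#include <cassert>

int unswitchCost(int LoopCost, int DominatedSuccCost, int NumUniqueSuccs) {
  int Duplicated = LoopCost - DominatedSuccCost;
  assert(Duplicated >= 0 && NumUniqueSuccs > 1);
  return Duplicated * (NumUniqueSuccs - 1);  // e.g. (100 - 30) * (2 - 1) == 70
}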
+ assert(Visited.size() > 1 && + "Cannot unswitch a condition without multiple distinct successors!"); + return Cost * (Visited.size() - 1); + }; + TerminatorInst *BestUnswitchTI = nullptr; + int BestUnswitchCost; + for (TerminatorInst *CandidateTI : UnswitchCandidates) { + int CandidateCost = ComputeUnswitchedCost(CandidateTI); + DEBUG(dbgs() << " Computed cost of " << CandidateCost + << " for unswitch candidate: " << *CandidateTI << "\n"); + if (!BestUnswitchTI || CandidateCost < BestUnswitchCost) { + BestUnswitchTI = CandidateTI; + BestUnswitchCost = CandidateCost; + } + } + + if (BestUnswitchCost < UnswitchThreshold) { + DEBUG(dbgs() << " Trying to unswitch non-trivial (cost = " + << BestUnswitchCost << ") branch: " << *BestUnswitchTI + << "\n"); + Changed |= unswitchInvariantBranch(L, cast<BranchInst>(*BestUnswitchTI), DT, + LI, AC, NonTrivialUnswitchCB); + } else { + DEBUG(dbgs() << "Cannot unswitch, lowest cost found: " << BestUnswitchCost + << "\n"); + } return Changed; } @@ -740,7 +2068,25 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L << "\n"); - if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC)) + // Save the current loop name in a variable so that we can report it even + // after it has been deleted. + std::string LoopName = L.getName(); + + auto NonTrivialUnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid, + ArrayRef<Loop *> NewLoops) { + // If we did a non-trivial unswitch, we have added new (cloned) loops. + U.addSiblingLoops(NewLoops); + + // If the current loop remains valid, we should revisit it to catch any + // other unswitch opportunities. Otherwise, we need to mark it as deleted. + if (CurrentLoopValid) + U.revisitCurrentLoop(); + else + U.markLoopAsDeleted(L, LoopName); + }; + + if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, + NonTrivialUnswitchCB)) return PreservedAnalyses::all(); #ifndef NDEBUG @@ -754,10 +2100,13 @@ PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, namespace { class SimpleLoopUnswitchLegacyPass : public LoopPass { + bool NonTrivial; + public: static char ID; // Pass ID, replacement for typeid - explicit SimpleLoopUnswitchLegacyPass() : LoopPass(ID) { + explicit SimpleLoopUnswitchLegacyPass(bool NonTrivial = false) + : LoopPass(ID), NonTrivial(NonTrivial) { initializeSimpleLoopUnswitchLegacyPassPass( *PassRegistry::getPassRegistry()); } @@ -766,6 +2115,7 @@ public: void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); getLoopAnalysisUsage(AU); } }; @@ -783,8 +2133,29 @@ bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + + auto NonTrivialUnswitchCB = [&L, &LPM](bool CurrentLoopValid, + ArrayRef<Loop *> NewLoops) { + // If we did a non-trivial unswitch, we have added new (cloned) loops. + for (auto *NewL : NewLoops) + LPM.addLoop(*NewL); + + // If the current loop remains valid, re-add it to the queue. This is + // a little wasteful as we'll finish processing the current loop as well, + // but it is the best we can do in the old PM. 
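// Illustrative sketch (not from the patch): the callback handshake both pass
// drivers use. The transform reports the new sibling loops and whether the
// original survived; each driver translates that into its own queue. The
// types here (WorkQueue, int loop ids, std::function) are stand-ins.
#include <functional>
#include <vector>

struct WorkQueue {
  std::vector<int> Pending;
  void add(int LoopId) { Pending.push_back(LoopId); }
};

using UnswitchCallback =
    std::function<void(bool /*CurrentLoopValid*/, const std::vector<int> &/*NewLoops*/)>;

void runTransform(int LoopId, const UnswitchCallback &CB) {
  std::vector<int> NewLoops = {LoopId + 100};  // pretend one loop was cloned
  bool CurrentLoopValid = true;                // pretend the original survived
  CB(CurrentLoopValid, NewLoops);
}

void driver(WorkQueue &Q) {
  int LoopId = 1;
  runTransform(LoopId, [&](bool Valid, const std::vector<int> &NewLoops) {
    for (int NewL : NewLoops)
      Q.add(NewL);     // cloned loops need to be visited as well
    if (Valid)
      Q.add(LoopId);   // revisit the original for more opportunities
  });
}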
+ if (CurrentLoopValid) + LPM.addLoop(*L); + else + LPM.markLoopAsDeleted(*L); + }; + + bool Changed = + unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, NonTrivialUnswitchCB); - bool Changed = unswitchLoop(*L, DT, LI, AC); + // If anything was unswitched, also clear any cached information about this + // loop. + LPM.deleteSimpleAnalysisLoop(L); #ifndef NDEBUG // Historically this pass has had issues with the dominator tree so verify it @@ -798,11 +2169,13 @@ char SimpleLoopUnswitchLegacyPass::ID = 0; INITIALIZE_PASS_BEGIN(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch", "Simple unswitch loops", false, false) INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(LoopPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch", "Simple unswitch loops", false, false) -Pass *llvm::createSimpleLoopUnswitchLegacyPass() { - return new SimpleLoopUnswitchLegacyPass(); +Pass *llvm::createSimpleLoopUnswitchLegacyPass(bool NonTrivial) { + return new SimpleLoopUnswitchLegacyPass(NonTrivial); } diff --git a/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/lib/Transforms/Scalar/SimplifyCFGPass.cpp index 8754c714c5b2..1522170dc3b9 100644 --- a/lib/Transforms/Scalar/SimplifyCFGPass.cpp +++ b/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -45,9 +45,26 @@ using namespace llvm; #define DEBUG_TYPE "simplifycfg" -static cl::opt<unsigned> -UserBonusInstThreshold("bonus-inst-threshold", cl::Hidden, cl::init(1), - cl::desc("Control the number of bonus instructions (default = 1)")); +static cl::opt<unsigned> UserBonusInstThreshold( + "bonus-inst-threshold", cl::Hidden, cl::init(1), + cl::desc("Control the number of bonus instructions (default = 1)")); + +static cl::opt<bool> UserKeepLoops( + "keep-loops", cl::Hidden, cl::init(true), + cl::desc("Preserve canonical loop structure (default = true)")); + +static cl::opt<bool> UserSwitchToLookup( + "switch-to-lookup", cl::Hidden, cl::init(false), + cl::desc("Convert switches to lookup tables (default = false)")); + +static cl::opt<bool> UserForwardSwitchCond( + "forward-switch-cond", cl::Hidden, cl::init(false), + cl::desc("Forward switch condition to phi ops (default = false)")); + +static cl::opt<bool> UserSinkCommonInsts( + "sink-common-insts", cl::Hidden, cl::init(false), + cl::desc("Sink common instructions (default = false)")); + STATISTIC(NumSimpl, "Number of blocks simplified"); @@ -129,9 +146,7 @@ static bool mergeEmptyReturnBlocks(Function &F) { /// Call SimplifyCFG on all the blocks in the function, /// iterating until no more changes are made. static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, - AssumptionCache *AC, - unsigned BonusInstThreshold, - bool LateSimplifyCFG) { + const SimplifyCFGOptions &Options) { bool Changed = false; bool LocalChange = true; @@ -146,7 +161,7 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, // Loop over all of the basic blocks and remove them if they are unneeded. 
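// Illustrative sketch (not from the patch): the sweep-to-a-fixed-point driver
// shape of iterativelySimplifyCFG: keep visiting every block until one full
// sweep changes nothing. SimplifyOneBlock is a hypothetical stand-in.
#include <functional>
#include <vector>

bool iterativelySimplify(std::vector<int> &Blocks,
                         const std::function<bool(int)> &SimplifyOneBlock) {
  bool Changed = false;
  bool LocalChange = true;
  while (LocalChange) {
    LocalChange = false;
    for (int BB : Blocks)
      if (SimplifyOneBlock(BB))
        LocalChange = true;    // something changed, so sweep once more
    Changed |= LocalChange;
  }
  return Changed;
}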
for (Function::iterator BBIt = F.begin(); BBIt != F.end(); ) { - if (SimplifyCFG(&*BBIt++, TTI, BonusInstThreshold, AC, &LoopHeaders, LateSimplifyCFG)) { + if (simplifyCFG(&*BBIt++, TTI, Options, &LoopHeaders)) { LocalChange = true; ++NumSimpl; } @@ -157,12 +172,10 @@ static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI, } static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI, - AssumptionCache *AC, int BonusInstThreshold, - bool LateSimplifyCFG) { + const SimplifyCFGOptions &Options) { bool EverChanged = removeUnreachableBlocks(F); EverChanged |= mergeEmptyReturnBlocks(F); - EverChanged |= iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold, - LateSimplifyCFG); + EverChanged |= iterativelySimplifyCFG(F, TTI, Options); // If neither pass changed anything, we're done. if (!EverChanged) return false; @@ -176,28 +189,37 @@ static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI, return true; do { - EverChanged = iterativelySimplifyCFG(F, TTI, AC, BonusInstThreshold, - LateSimplifyCFG); + EverChanged = iterativelySimplifyCFG(F, TTI, Options); EverChanged |= removeUnreachableBlocks(F); } while (EverChanged); return true; } -SimplifyCFGPass::SimplifyCFGPass() - : BonusInstThreshold(UserBonusInstThreshold), - LateSimplifyCFG(true) {} - -SimplifyCFGPass::SimplifyCFGPass(int BonusInstThreshold, bool LateSimplifyCFG) - : BonusInstThreshold(BonusInstThreshold), - LateSimplifyCFG(LateSimplifyCFG) {} +// Command-line settings override compile-time settings. +SimplifyCFGPass::SimplifyCFGPass(const SimplifyCFGOptions &Opts) { + Options.BonusInstThreshold = UserBonusInstThreshold.getNumOccurrences() + ? UserBonusInstThreshold + : Opts.BonusInstThreshold; + Options.ForwardSwitchCondToPhi = UserForwardSwitchCond.getNumOccurrences() + ? UserForwardSwitchCond + : Opts.ForwardSwitchCondToPhi; + Options.ConvertSwitchToLookupTable = UserSwitchToLookup.getNumOccurrences() + ? UserSwitchToLookup + : Opts.ConvertSwitchToLookupTable; + Options.NeedCanonicalLoop = UserKeepLoops.getNumOccurrences() + ? UserKeepLoops + : Opts.NeedCanonicalLoop; + Options.SinkCommonInsts = UserSinkCommonInsts.getNumOccurrences() + ? UserSinkCommonInsts + : Opts.SinkCommonInsts; +} PreservedAnalyses SimplifyCFGPass::run(Function &F, FunctionAnalysisManager &AM) { auto &TTI = AM.getResult<TargetIRAnalysis>(F); - auto &AC = AM.getResult<AssumptionAnalysis>(F); - - if (!simplifyFunctionCFG(F, TTI, &AC, BonusInstThreshold, LateSimplifyCFG)) + Options.AC = &AM.getResult<AssumptionAnalysis>(F); + if (!simplifyFunctionCFG(F, TTI, Options)) return PreservedAnalyses::all(); PreservedAnalyses PA; PA.preserve<GlobalsAA>(); @@ -205,55 +227,54 @@ PreservedAnalyses SimplifyCFGPass::run(Function &F, } namespace { -struct BaseCFGSimplifyPass : public FunctionPass { - unsigned BonusInstThreshold; +struct CFGSimplifyPass : public FunctionPass { + static char ID; + SimplifyCFGOptions Options; std::function<bool(const Function &)> PredicateFtor; - bool LateSimplifyCFG; - - BaseCFGSimplifyPass(int T, bool LateSimplifyCFG, - std::function<bool(const Function &)> Ftor, - char &ID) - : FunctionPass(ID), PredicateFtor(std::move(Ftor)), - LateSimplifyCFG(LateSimplifyCFG) { - BonusInstThreshold = (T == -1) ? 
UserBonusInstThreshold : unsigned(T); + + CFGSimplifyPass(unsigned Threshold = 1, bool ForwardSwitchCond = false, + bool ConvertSwitch = false, bool KeepLoops = true, + bool SinkCommon = false, + std::function<bool(const Function &)> Ftor = nullptr) + : FunctionPass(ID), PredicateFtor(std::move(Ftor)) { + + initializeCFGSimplifyPassPass(*PassRegistry::getPassRegistry()); + + // Check for command-line overrides of options for debug/customization. + Options.BonusInstThreshold = UserBonusInstThreshold.getNumOccurrences() + ? UserBonusInstThreshold + : Threshold; + + Options.ForwardSwitchCondToPhi = UserForwardSwitchCond.getNumOccurrences() + ? UserForwardSwitchCond + : ForwardSwitchCond; + + Options.ConvertSwitchToLookupTable = UserSwitchToLookup.getNumOccurrences() + ? UserSwitchToLookup + : ConvertSwitch; + + Options.NeedCanonicalLoop = + UserKeepLoops.getNumOccurrences() ? UserKeepLoops : KeepLoops; + + Options.SinkCommonInsts = UserSinkCommonInsts.getNumOccurrences() + ? UserSinkCommonInsts + : SinkCommon; } + bool runOnFunction(Function &F) override { if (skipFunction(F) || (PredicateFtor && !PredicateFtor(F))) return false; - AssumptionCache *AC = - &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); - const TargetTransformInfo &TTI = - getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - return simplifyFunctionCFG(F, TTI, AC, BonusInstThreshold, LateSimplifyCFG); + Options.AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + return simplifyFunctionCFG(F, TTI, Options); } - void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<AssumptionCacheTracker>(); AU.addRequired<TargetTransformInfoWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } }; - -struct CFGSimplifyPass : public BaseCFGSimplifyPass { - static char ID; // Pass identification, replacement for typeid - - CFGSimplifyPass(int T = -1, - std::function<bool(const Function &)> Ftor = nullptr) - : BaseCFGSimplifyPass(T, false, Ftor, ID) { - initializeCFGSimplifyPassPass(*PassRegistry::getPassRegistry()); - } -}; - -struct LateCFGSimplifyPass : public BaseCFGSimplifyPass { - static char ID; // Pass identification, replacement for typeid - - LateCFGSimplifyPass(int T = -1, - std::function<bool(const Function &)> Ftor = nullptr) - : BaseCFGSimplifyPass(T, true, Ftor, ID) { - initializeLateCFGSimplifyPassPass(*PassRegistry::getPassRegistry()); - } -}; } char CFGSimplifyPass::ID = 0; @@ -264,24 +285,12 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) INITIALIZE_PASS_END(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false, false) -char LateCFGSimplifyPass::ID = 0; -INITIALIZE_PASS_BEGIN(LateCFGSimplifyPass, "latesimplifycfg", - "Simplify the CFG more aggressively", false, false) -INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) -INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) -INITIALIZE_PASS_END(LateCFGSimplifyPass, "latesimplifycfg", - "Simplify the CFG more aggressively", false, false) - // Public interface to the CFGSimplification pass FunctionPass * -llvm::createCFGSimplificationPass(int Threshold, - std::function<bool(const Function &)> Ftor) { - return new CFGSimplifyPass(Threshold, std::move(Ftor)); -} - -// Public interface to the LateCFGSimplification pass -FunctionPass * -llvm::createLateCFGSimplificationPass(int Threshold, +llvm::createCFGSimplificationPass(unsigned Threshold, bool ForwardSwitchCond, + bool ConvertSwitch, bool KeepLoops, + bool SinkCommon, 
std::function<bool(const Function &)> Ftor) { - return new LateCFGSimplifyPass(Threshold, std::move(Ftor)); + return new CFGSimplifyPass(Threshold, ForwardSwitchCond, ConvertSwitch, + KeepLoops, SinkCommon, std::move(Ftor)); } diff --git a/lib/Transforms/Scalar/Sink.cpp b/lib/Transforms/Scalar/Sink.cpp index 5210f165b874..cfb8a062299f 100644 --- a/lib/Transforms/Scalar/Sink.cpp +++ b/lib/Transforms/Scalar/Sink.cpp @@ -68,7 +68,7 @@ static bool isSafeToMove(Instruction *Inst, AliasAnalysis &AA, if (LoadInst *L = dyn_cast<LoadInst>(Inst)) { MemoryLocation Loc = MemoryLocation::get(L); for (Instruction *S : Stores) - if (AA.getModRefInfo(S, Loc) & MRI_Mod) + if (isModSet(AA.getModRefInfo(S, Loc))) return false; } @@ -83,7 +83,7 @@ static bool isSafeToMove(Instruction *Inst, AliasAnalysis &AA, return false; for (Instruction *S : Stores) - if (AA.getModRefInfo(S, CS) & MRI_Mod) + if (isModSet(AA.getModRefInfo(S, CS))) return false; } diff --git a/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp b/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp new file mode 100644 index 000000000000..23156d5a4d83 --- /dev/null +++ b/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp @@ -0,0 +1,811 @@ +//===- SpeculateAroundPHIs.cpp --------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/SpeculateAroundPHIs.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "spec-phis" + +STATISTIC(NumPHIsSpeculated, "Number of PHI nodes we speculated around"); +STATISTIC(NumEdgesSplit, + "Number of critical edges which were split for speculation"); +STATISTIC(NumSpeculatedInstructions, + "Number of instructions we speculated around the PHI nodes"); +STATISTIC(NumNewRedundantInstructions, + "Number of new, redundant instructions inserted"); + +/// Check wether speculating the users of a PHI node around the PHI +/// will be safe. +/// +/// This checks both that all of the users are safe and also that all of their +/// operands are either recursively safe or already available along an incoming +/// edge to the PHI. +/// +/// This routine caches both all the safe nodes explored in `PotentialSpecSet` +/// and the chain of nodes that definitively reach any unsafe node in +/// `UnsafeSet`. By preserving these between repeated calls to this routine for +/// PHIs in the same basic block, the exploration here can be reused. However, +/// these caches must no be reused for PHIs in a different basic block as they +/// reflect what is available along incoming edges. +static bool +isSafeToSpeculatePHIUsers(PHINode &PN, DominatorTree &DT, + SmallPtrSetImpl<Instruction *> &PotentialSpecSet, + SmallPtrSetImpl<Instruction *> &UnsafeSet) { + auto *PhiBB = PN.getParent(); + SmallPtrSet<Instruction *, 4> Visited; + SmallVector<std::pair<Instruction *, User::value_op_iterator>, 16> DFSStack; + + // Walk each user of the PHI node. 
+ for (Use &U : PN.uses()) { + auto *UI = cast<Instruction>(U.getUser()); + + // Ensure the use post-dominates the PHI node. This ensures that, in the + // absence of unwinding, the use will actually be reached. + // FIXME: We use a blunt hammer of requiring them to be in the same basic + // block. We should consider using actual post-dominance here in the + // future. + if (UI->getParent() != PhiBB) { + DEBUG(dbgs() << " Unsafe: use in a different BB: " << *UI << "\n"); + return false; + } + + // FIXME: This check is much too conservative. We're not going to move these + // instructions onto new dynamic paths through the program unless there is + // a call instruction between the use and the PHI node. And memory isn't + // changing unless there is a store in that same sequence. We should + // probably change this to do at least a limited scan of the intervening + // instructions and allow handling stores in easily proven safe cases. + if (mayBeMemoryDependent(*UI)) { + DEBUG(dbgs() << " Unsafe: can't speculate use: " << *UI << "\n"); + return false; + } + + // Now do a depth-first search of everything these users depend on to make + // sure they are transitively safe. This is a depth-first search, but we + // check nodes in preorder to minimize the amount of checking. + Visited.insert(UI); + DFSStack.push_back({UI, UI->value_op_begin()}); + do { + User::value_op_iterator OpIt; + std::tie(UI, OpIt) = DFSStack.pop_back_val(); + + while (OpIt != UI->value_op_end()) { + auto *OpI = dyn_cast<Instruction>(*OpIt); + // Increment to the next operand for whenever we continue. + ++OpIt; + // No need to visit non-instructions, which can't form dependencies. + if (!OpI) + continue; + + // Now do the main pre-order checks that this operand is a viable + // dependency of something we want to speculate. + + // First do a few checks for instructions that won't require + // speculation at all because they are trivially available on the + // incoming edge (either through dominance or through an incoming value + // to a PHI). + // + // The cases in the current block will be trivially dominated by the + // edge. + auto *ParentBB = OpI->getParent(); + if (ParentBB == PhiBB) { + if (isa<PHINode>(OpI)) { + // We can trivially map through phi nodes in the same block. + continue; + } + } else if (DT.dominates(ParentBB, PhiBB)) { + // Instructions from dominating blocks are already available. + continue; + } + + // Once we know that we're considering speculating the operand, check + // if we've already explored this subgraph and found it to be safe. + if (PotentialSpecSet.count(OpI)) + continue; + + // If we've already explored this subgraph and found it unsafe, bail. + // If when we directly test whether this is safe it fails, bail. + if (UnsafeSet.count(OpI) || ParentBB != PhiBB || + mayBeMemoryDependent(*OpI)) { + DEBUG(dbgs() << " Unsafe: can't speculate transitive use: " << *OpI + << "\n"); + // Record the stack of instructions which reach this node as unsafe + // so we prune subsequent searches. + UnsafeSet.insert(OpI); + for (auto &StackPair : DFSStack) { + Instruction *I = StackPair.first; + UnsafeSet.insert(I); + } + return false; + } + + // Skip any operands we're already recursively checking. + if (!Visited.insert(OpI).second) + continue; + + // Push onto the stack and descend. We can directly continue this + // loop when ascending. + DFSStack.push_back({UI, OpIt}); + UI = OpI; + OpIt = OpI->value_op_begin(); + } + + // This node and all its operands are safe. 
Go ahead and cache that for + // reuse later. + PotentialSpecSet.insert(UI); + + // Continue with the next node on the stack. + } while (!DFSStack.empty()); + } + +#ifndef NDEBUG + // Every visited operand should have been marked as safe for speculation at + // this point. Verify this and return success. + for (auto *I : Visited) + assert(PotentialSpecSet.count(I) && + "Failed to mark a visited instruction as safe!"); +#endif + return true; +} + +/// Check whether, in isolation, a given PHI node is both safe and profitable +/// to speculate users around. +/// +/// This handles checking whether there are any constant operands to a PHI +/// which could represent a useful speculation candidate, whether the users of +/// the PHI are safe to speculate including all their transitive dependencies, +/// and whether after speculation there will be some cost savings (profit) to +/// folding the operands into the users of the PHI node. Returns true if both +/// safe and profitable with relevant cost savings updated in the map and with +/// an update to the `PotentialSpecSet`. Returns false if either safety or +/// profitability are absent. Some new entries may be made to the +/// `PotentialSpecSet` even when this routine returns false, but they remain +/// conservatively correct. +/// +/// The profitability check here is a local one, but it checks this in an +/// interesting way. Beyond checking that the total cost of materializing the +/// constants will be less than the cost of folding them into their users, it +/// also checks that no one incoming constant will have a higher cost when +/// folded into its users rather than materialized. This higher cost could +/// result in a dynamic *path* that is more expensive even when the total cost +/// is lower. Currently, all of the interesting cases where this optimization +/// should fire are ones where it is a no-loss operation in this sense. If we +/// ever want to be more aggressive here, we would need to balance the +/// different incoming edges' cost by looking at their respective +/// probabilities. +static bool isSafeAndProfitableToSpeculateAroundPHI( + PHINode &PN, SmallDenseMap<PHINode *, int, 16> &CostSavingsMap, + SmallPtrSetImpl<Instruction *> &PotentialSpecSet, + SmallPtrSetImpl<Instruction *> &UnsafeSet, DominatorTree &DT, + TargetTransformInfo &TTI) { + // First see whether there is any cost savings to speculating around this + // PHI, and build up a map of the constant inputs to how many times they + // occur. + bool NonFreeMat = false; + struct CostsAndCount { + int MatCost = TargetTransformInfo::TCC_Free; + int FoldedCost = TargetTransformInfo::TCC_Free; + int Count = 0; + }; + SmallDenseMap<ConstantInt *, CostsAndCount, 16> CostsAndCounts; + SmallPtrSet<BasicBlock *, 16> IncomingConstantBlocks; + for (int i : llvm::seq<int>(0, PN.getNumIncomingValues())) { + auto *IncomingC = dyn_cast<ConstantInt>(PN.getIncomingValue(i)); + if (!IncomingC) + continue; + + // Only visit each incoming edge with a constant input once. + if (!IncomingConstantBlocks.insert(PN.getIncomingBlock(i)).second) + continue; + + auto InsertResult = CostsAndCounts.insert({IncomingC, {}}); + // Count how many edges share a given incoming costant. + ++InsertResult.first->second.Count; + // Only compute the cost the first time we see a particular constant. 
+ if (!InsertResult.second) + continue; + + int &MatCost = InsertResult.first->second.MatCost; + MatCost = TTI.getIntImmCost(IncomingC->getValue(), IncomingC->getType()); + NonFreeMat |= MatCost != TTI.TCC_Free; + } + if (!NonFreeMat) { + DEBUG(dbgs() << " Free: " << PN << "\n"); + // No profit in free materialization. + return false; + } + + // Now check that the uses of this PHI can actually be speculated, + // otherwise we'll still have to materialize the PHI value. + if (!isSafeToSpeculatePHIUsers(PN, DT, PotentialSpecSet, UnsafeSet)) { + DEBUG(dbgs() << " Unsafe PHI: " << PN << "\n"); + return false; + } + + // Compute how much (if any) savings are available by speculating around this + // PHI. + for (Use &U : PN.uses()) { + auto *UserI = cast<Instruction>(U.getUser()); + // Now check whether there is any savings to folding the incoming constants + // into this use. + unsigned Idx = U.getOperandNo(); + + // If we have a binary operator that is commutative, an actual constant + // operand would end up on the RHS, so pretend the use of the PHI is on the + // RHS. + // + // Technically, this is a bit weird if *both* operands are PHIs we're + // speculating. But if that is the case, giving an "optimistic" cost isn't + // a bad thing because after speculation it will constant fold. And + // moreover, such cases should likely have been constant folded already by + // some other pass, so we shouldn't worry about "modeling" them terribly + // accurately here. Similarly, if the other operand is a constant, it still + // seems fine to be "optimistic" in our cost modeling, because when the + // incoming operand from the PHI node is also a constant, we will end up + // constant folding. + if (UserI->isBinaryOp() && UserI->isCommutative() && Idx != 1) + // Assume we will commute the constant to the RHS to be canonical. + Idx = 1; + + // Get the intrinsic ID if this user is an instrinsic. + Intrinsic::ID IID = Intrinsic::not_intrinsic; + if (auto *UserII = dyn_cast<IntrinsicInst>(UserI)) + IID = UserII->getIntrinsicID(); + + for (auto &IncomingConstantAndCostsAndCount : CostsAndCounts) { + ConstantInt *IncomingC = IncomingConstantAndCostsAndCount.first; + int MatCost = IncomingConstantAndCostsAndCount.second.MatCost; + int &FoldedCost = IncomingConstantAndCostsAndCount.second.FoldedCost; + if (IID) + FoldedCost += TTI.getIntImmCost(IID, Idx, IncomingC->getValue(), + IncomingC->getType()); + else + FoldedCost += + TTI.getIntImmCost(UserI->getOpcode(), Idx, IncomingC->getValue(), + IncomingC->getType()); + + // If we accumulate more folded cost for this incoming constant than + // materialized cost, then we'll regress any edge with this constant so + // just bail. We're only interested in cases where folding the incoming + // constants is at least break-even on all paths. + if (FoldedCost > MatCost) { + DEBUG(dbgs() << " Not profitable to fold imm: " << *IncomingC << "\n" + " Materializing cost: " << MatCost << "\n" + " Accumulated folded cost: " << FoldedCost << "\n"); + return false; + } + } + } + + // Compute the total cost savings afforded by this PHI node. 
+ int TotalMatCost = TTI.TCC_Free, TotalFoldedCost = TTI.TCC_Free; + for (auto IncomingConstantAndCostsAndCount : CostsAndCounts) { + int MatCost = IncomingConstantAndCostsAndCount.second.MatCost; + int FoldedCost = IncomingConstantAndCostsAndCount.second.FoldedCost; + int Count = IncomingConstantAndCostsAndCount.second.Count; + + TotalMatCost += MatCost * Count; + TotalFoldedCost += FoldedCost * Count; + } + assert(TotalFoldedCost <= TotalMatCost && "If each constant's folded cost is " + "less that its materialized cost, " + "the sum must be as well."); + + DEBUG(dbgs() << " Cost savings " << (TotalMatCost - TotalFoldedCost) + << ": " << PN << "\n"); + CostSavingsMap[&PN] = TotalMatCost - TotalFoldedCost; + return true; +} + +/// Simple helper to walk all the users of a list of phis depth first, and call +/// a visit function on each one in post-order. +/// +/// All of the PHIs should be in the same basic block, and this is primarily +/// used to make a single depth-first walk across their collective users +/// without revisiting any subgraphs. Callers should provide a fast, idempotent +/// callable to test whether a node has been visited and the more important +/// callable to actually visit a particular node. +/// +/// Depth-first and postorder here refer to the *operand* graph -- we start +/// from a collection of users of PHI nodes and walk "up" the operands +/// depth-first. +template <typename IsVisitedT, typename VisitT> +static void visitPHIUsersAndDepsInPostOrder(ArrayRef<PHINode *> PNs, + IsVisitedT IsVisited, + VisitT Visit) { + SmallVector<std::pair<Instruction *, User::value_op_iterator>, 16> DFSStack; + for (auto *PN : PNs) + for (Use &U : PN->uses()) { + auto *UI = cast<Instruction>(U.getUser()); + if (IsVisited(UI)) + // Already visited this user, continue across the roots. + continue; + + // Otherwise, walk the operand graph depth-first and visit each + // dependency in postorder. + DFSStack.push_back({UI, UI->value_op_begin()}); + do { + User::value_op_iterator OpIt; + std::tie(UI, OpIt) = DFSStack.pop_back_val(); + while (OpIt != UI->value_op_end()) { + auto *OpI = dyn_cast<Instruction>(*OpIt); + // Increment to the next operand for whenever we continue. + ++OpIt; + // No need to visit non-instructions, which can't form dependencies, + // or instructions outside of our potential dependency set that we + // were given. Finally, if we've already visited the node, continue + // to the next. + if (!OpI || IsVisited(OpI)) + continue; + + // Push onto the stack and descend. We can directly continue this + // loop when ascending. + DFSStack.push_back({UI, OpIt}); + UI = OpI; + OpIt = OpI->value_op_begin(); + } + + // Finished visiting children, visit this node. + assert(!IsVisited(UI) && "Should not have already visited a node!"); + Visit(UI); + } while (!DFSStack.empty()); + } +} + +/// Find profitable PHIs to speculate. +/// +/// For a PHI node to be profitable, we need the cost of speculating its users +/// (and their dependencies) to not exceed the savings of folding the PHI's +/// constant operands into the speculated users. +/// +/// Computing this is surprisingly challenging. Because users of two different +/// PHI nodes can depend on each other or on common other instructions, it may +/// be profitable to speculate two PHI nodes together even though neither one +/// in isolation is profitable. The straightforward way to find all the +/// profitable PHIs would be to check each combination of PHIs' cost, but this +/// is exponential in complexity. 
+/// +/// Even if we assume that we only care about cases where we can consider each +/// PHI node in isolation (rather than considering cases where none are +/// profitable in isolation but some subset are profitable as a set), we still +/// have a challenge. The obvious way to find all individually profitable PHIs +/// is to iterate until reaching a fixed point, but this will be quadratic in +/// complexity. =/ +/// +/// This code currently uses a linear-to-compute order for a greedy approach. +/// It won't find cases where a set of PHIs must be considered together, but it +/// handles most cases of order dependence without quadratic iteration. The +/// specific order used is the post-order across the operand DAG. When the last +/// user of a PHI is visited in this postorder walk, we check it for +/// profitability. +/// +/// There is an orthogonal extra complexity to all of this: computing the cost +/// itself can easily become a linear computation making everything again (at +/// best) quadratic. Using a postorder over the operand graph makes it +/// particularly easy to avoid this through dynamic programming. As we do the +/// postorder walk, we build the transitive cost of that subgraph. It is also +/// straightforward to then update these costs when we mark a PHI for +/// speculation so that subsequent PHIs don't re-pay the cost of already +/// speculated instructions. +static SmallVector<PHINode *, 16> +findProfitablePHIs(ArrayRef<PHINode *> PNs, + const SmallDenseMap<PHINode *, int, 16> &CostSavingsMap, + const SmallPtrSetImpl<Instruction *> &PotentialSpecSet, + int NumPreds, DominatorTree &DT, TargetTransformInfo &TTI) { + SmallVector<PHINode *, 16> SpecPNs; + + // First, establish a reverse mapping from immediate users of the PHI nodes + // to the nodes themselves, and count how many users each PHI node has in + // a way we can update while processing them. + SmallDenseMap<Instruction *, TinyPtrVector<PHINode *>, 16> UserToPNMap; + SmallDenseMap<PHINode *, int, 16> PNUserCountMap; + SmallPtrSet<Instruction *, 16> UserSet; + for (auto *PN : PNs) { + assert(UserSet.empty() && "Must start with an empty user set!"); + for (Use &U : PN->uses()) + UserSet.insert(cast<Instruction>(U.getUser())); + PNUserCountMap[PN] = UserSet.size(); + for (auto *UI : UserSet) + UserToPNMap.insert({UI, {}}).first->second.push_back(PN); + UserSet.clear(); + } + + // Now do a DFS across the operand graph of the users, computing cost as we + // go and when all costs for a given PHI are known, checking that PHI for + // profitability. + SmallDenseMap<Instruction *, int, 16> SpecCostMap; + visitPHIUsersAndDepsInPostOrder( + PNs, + /*IsVisited*/ + [&](Instruction *I) { + // We consider anything that isn't potentially speculated to be + // "visited" as it is already handled. Similarly, anything that *is* + // potentially speculated but for which we have an entry in our cost + // map, we're done. + return !PotentialSpecSet.count(I) || SpecCostMap.count(I); + }, + /*Visit*/ + [&](Instruction *I) { + // We've fully visited the operands, so sum their cost with this node + // and update the cost map. 
+ int Cost = TTI.TCC_Free; + for (Value *OpV : I->operand_values()) + if (auto *OpI = dyn_cast<Instruction>(OpV)) { + auto CostMapIt = SpecCostMap.find(OpI); + if (CostMapIt != SpecCostMap.end()) + Cost += CostMapIt->second; + } + Cost += TTI.getUserCost(I); + bool Inserted = SpecCostMap.insert({I, Cost}).second; + (void)Inserted; + assert(Inserted && "Must not re-insert a cost during the DFS!"); + + // Now check if this node had a corresponding PHI node using it. If so, + // we need to decrement the outstanding user count for it. + auto UserPNsIt = UserToPNMap.find(I); + if (UserPNsIt == UserToPNMap.end()) + return; + auto &UserPNs = UserPNsIt->second; + auto UserPNsSplitIt = std::stable_partition( + UserPNs.begin(), UserPNs.end(), [&](PHINode *UserPN) { + int &PNUserCount = PNUserCountMap.find(UserPN)->second; + assert( + PNUserCount > 0 && + "Should never re-visit a PN after its user count hits zero!"); + --PNUserCount; + return PNUserCount != 0; + }); + + // FIXME: Rather than one at a time, we should sum the savings as the + // cost will be completely shared. + SmallVector<Instruction *, 16> SpecWorklist; + for (auto *PN : llvm::make_range(UserPNsSplitIt, UserPNs.end())) { + int SpecCost = TTI.TCC_Free; + for (Use &U : PN->uses()) + SpecCost += + SpecCostMap.find(cast<Instruction>(U.getUser()))->second; + SpecCost *= (NumPreds - 1); + // When the user count of a PHI node hits zero, we should check its + // profitability. If profitable, we should mark it for speculation + // and zero out the cost of everything it depends on. + int CostSavings = CostSavingsMap.find(PN)->second; + if (SpecCost > CostSavings) { + DEBUG(dbgs() << " Not profitable, speculation cost: " << *PN << "\n" + " Cost savings: " << CostSavings << "\n" + " Speculation cost: " << SpecCost << "\n"); + continue; + } + + // We're going to speculate this user-associated PHI. Copy it out and + // add its users to the worklist to update their cost. + SpecPNs.push_back(PN); + for (Use &U : PN->uses()) { + auto *UI = cast<Instruction>(U.getUser()); + auto CostMapIt = SpecCostMap.find(UI); + if (CostMapIt->second == 0) + continue; + // Zero out this cost entry to avoid duplicates. + CostMapIt->second = 0; + SpecWorklist.push_back(UI); + } + } + + // Now walk all the operands of the users in the worklist transitively + // to zero out all the memoized costs. + while (!SpecWorklist.empty()) { + Instruction *SpecI = SpecWorklist.pop_back_val(); + assert(SpecCostMap.find(SpecI)->second == 0 && + "Didn't zero out a cost!"); + + // Walk the operands recursively to zero out their cost as well. + for (auto *OpV : SpecI->operand_values()) { + auto *OpI = dyn_cast<Instruction>(OpV); + if (!OpI) + continue; + auto CostMapIt = SpecCostMap.find(OpI); + if (CostMapIt == SpecCostMap.end() || CostMapIt->second == 0) + continue; + CostMapIt->second = 0; + SpecWorklist.push_back(OpI); + } + } + }); + + return SpecPNs; +} + +/// Speculate users around a set of PHI nodes. +/// +/// This routine does the actual speculation around a set of PHI nodes where we +/// have determined this to be both safe and profitable. +/// +/// This routine handles any spliting of critical edges necessary to create +/// a safe block to speculate into as well as cloning the instructions and +/// rewriting all uses. 
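An illustrative aside before the transformation itself (speculatePHIs, defined next): a standalone plain-C++ model of the profitability comparison described above, in which the savings for a PHI are the materialization cost minus the folded cost, summed per incoming constant and weighted by how many edges carry it, while the speculation cost is the memoized cost of the speculated users multiplied by NumPreds - 1. All names and costs below are invented for illustration; the real pass queries TargetTransformInfo.

#include <cstdio>

int main() {
  // One PHI with two incoming edges, both carrying the same non-free constant.
  int MatCost = 2;    // illustrative cost of materializing the immediate
  int FoldedCost = 0; // illustrative cost once folded into its single user
  int Count = 2;      // incoming edges carrying this constant
  int CostSavings = (MatCost - FoldedCost) * Count;

  int UserCost = 1;   // illustrative cost of the one instruction to be copied
  int NumPreds = 2;   // predecessors that would each receive a copy
  int SpecCost = UserCost * (NumPreds - 1);

  // Mirrors the "SpecCost > CostSavings" bail-out in findProfitablePHIs above.
  std::printf("savings=%d cost=%d -> %s\n", CostSavings, SpecCost,
              SpecCost > CostSavings ? "skip" : "speculate");
  return 0;
}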
+static void speculatePHIs(ArrayRef<PHINode *> SpecPNs, + SmallPtrSetImpl<Instruction *> &PotentialSpecSet, + SmallSetVector<BasicBlock *, 16> &PredSet, + DominatorTree &DT) { + DEBUG(dbgs() << " Speculating around " << SpecPNs.size() << " PHIs!\n"); + NumPHIsSpeculated += SpecPNs.size(); + + // Split any critical edges so that we have a block to hoist into. + auto *ParentBB = SpecPNs[0]->getParent(); + SmallVector<BasicBlock *, 16> SpecPreds; + SpecPreds.reserve(PredSet.size()); + for (auto *PredBB : PredSet) { + auto *NewPredBB = SplitCriticalEdge( + PredBB, ParentBB, + CriticalEdgeSplittingOptions(&DT).setMergeIdenticalEdges()); + if (NewPredBB) { + ++NumEdgesSplit; + DEBUG(dbgs() << " Split critical edge from: " << PredBB->getName() + << "\n"); + SpecPreds.push_back(NewPredBB); + } else { + assert(PredBB->getSingleSuccessor() == ParentBB && + "We need a non-critical predecessor to speculate into."); + assert(!isa<InvokeInst>(PredBB->getTerminator()) && + "Cannot have a non-critical invoke!"); + + // Already non-critical, use existing pred. + SpecPreds.push_back(PredBB); + } + } + + SmallPtrSet<Instruction *, 16> SpecSet; + SmallVector<Instruction *, 16> SpecList; + visitPHIUsersAndDepsInPostOrder(SpecPNs, + /*IsVisited*/ + [&](Instruction *I) { + // This is visited if we don't need to + // speculate it or we already have + // speculated it. + return !PotentialSpecSet.count(I) || + SpecSet.count(I); + }, + /*Visit*/ + [&](Instruction *I) { + // All operands scheduled, schedule this + // node. + SpecSet.insert(I); + SpecList.push_back(I); + }); + + int NumSpecInsts = SpecList.size() * SpecPreds.size(); + int NumRedundantInsts = NumSpecInsts - SpecList.size(); + DEBUG(dbgs() << " Inserting " << NumSpecInsts << " speculated instructions, " + << NumRedundantInsts << " redundancies\n"); + NumSpeculatedInstructions += NumSpecInsts; + NumNewRedundantInstructions += NumRedundantInsts; + + // Each predecessor is numbered by its index in `SpecPreds`, so for each + // instruction we speculate, the speculated instruction is stored in that + // index of the vector asosciated with the original instruction. We also + // store the incoming values for each predecessor from any PHIs used. + SmallDenseMap<Instruction *, SmallVector<Value *, 2>, 16> SpeculatedValueMap; + + // Inject the synthetic mappings to rewrite PHIs to the appropriate incoming + // value. This handles both the PHIs we are speculating around and any other + // PHIs that happen to be used. + for (auto *OrigI : SpecList) + for (auto *OpV : OrigI->operand_values()) { + auto *OpPN = dyn_cast<PHINode>(OpV); + if (!OpPN || OpPN->getParent() != ParentBB) + continue; + + auto InsertResult = SpeculatedValueMap.insert({OpPN, {}}); + if (!InsertResult.second) + continue; + + auto &SpeculatedVals = InsertResult.first->second; + + // Populating our structure for mapping is particularly annoying because + // finding an incoming value for a particular predecessor block in a PHI + // node is a linear time operation! To avoid quadratic behavior, we build + // a map for this PHI node's incoming values and then translate it into + // the more compact representation used below. + SmallDenseMap<BasicBlock *, Value *, 16> IncomingValueMap; + for (int i : llvm::seq<int>(0, OpPN->getNumIncomingValues())) + IncomingValueMap[OpPN->getIncomingBlock(i)] = OpPN->getIncomingValue(i); + + for (auto *PredBB : SpecPreds) + SpeculatedVals.push_back(IncomingValueMap.find(PredBB)->second); + } + + // Speculate into each predecessor. 
+ for (int PredIdx : llvm::seq<int>(0, SpecPreds.size())) { + auto *PredBB = SpecPreds[PredIdx]; + assert(PredBB->getSingleSuccessor() == ParentBB && + "We need a non-critical predecessor to speculate into."); + + for (auto *OrigI : SpecList) { + auto *NewI = OrigI->clone(); + NewI->setName(Twine(OrigI->getName()) + "." + Twine(PredIdx)); + NewI->insertBefore(PredBB->getTerminator()); + + // Rewrite all the operands to the previously speculated instructions. + // Because we're walking in-order, the defs must precede the uses and we + // should already have these mappings. + for (Use &U : NewI->operands()) { + auto *OpI = dyn_cast<Instruction>(U.get()); + if (!OpI) + continue; + auto MapIt = SpeculatedValueMap.find(OpI); + if (MapIt == SpeculatedValueMap.end()) + continue; + const auto &SpeculatedVals = MapIt->second; + assert(SpeculatedVals[PredIdx] && + "Must have a speculated value for this predecessor!"); + assert(SpeculatedVals[PredIdx]->getType() == OpI->getType() && + "Speculated value has the wrong type!"); + + // Rewrite the use to this predecessor's speculated instruction. + U.set(SpeculatedVals[PredIdx]); + } + + // Commute instructions which now have a constant in the LHS but not the + // RHS. + if (NewI->isBinaryOp() && NewI->isCommutative() && + isa<Constant>(NewI->getOperand(0)) && + !isa<Constant>(NewI->getOperand(1))) + NewI->getOperandUse(0).swap(NewI->getOperandUse(1)); + + SpeculatedValueMap[OrigI].push_back(NewI); + assert(SpeculatedValueMap[OrigI][PredIdx] == NewI && + "Mismatched speculated instruction index!"); + } + } + + // Walk the speculated instruction list and if they have uses, insert a PHI + // for them from the speculated versions, and replace the uses with the PHI. + // Then erase the instructions as they have been fully speculated. The walk + // needs to be in reverse so that we don't think there are users when we'll + // actually eventually remove them later. + IRBuilder<> IRB(SpecPNs[0]); + for (auto *OrigI : llvm::reverse(SpecList)) { + // Check if we need a PHI for any remaining users and if so, insert it. + if (!OrigI->use_empty()) { + auto *SpecIPN = IRB.CreatePHI(OrigI->getType(), SpecPreds.size(), + Twine(OrigI->getName()) + ".phi"); + // Add the incoming values we speculated. + auto &SpeculatedVals = SpeculatedValueMap.find(OrigI)->second; + for (int PredIdx : llvm::seq<int>(0, SpecPreds.size())) + SpecIPN->addIncoming(SpeculatedVals[PredIdx], SpecPreds[PredIdx]); + + // And replace the uses with the PHI node. + OrigI->replaceAllUsesWith(SpecIPN); + } + + // It is important to immediately erase this so that it stops using other + // instructions. This avoids inserting needless PHIs of them. + OrigI->eraseFromParent(); + } + + // All of the uses of the speculated phi nodes should be removed at this + // point, so erase them. + for (auto *SpecPN : SpecPNs) { + assert(SpecPN->use_empty() && "All users should have been speculated!"); + SpecPN->eraseFromParent(); + } +} + +/// Try to speculate around a series of PHIs from a single basic block. +/// +/// This routine checks whether any of these PHIs are profitable to speculate +/// users around. If safe and profitable, it does the speculation. It returns +/// true when at least some speculation occurs. +static bool tryToSpeculatePHIs(SmallVectorImpl<PHINode *> &PNs, + DominatorTree &DT, TargetTransformInfo &TTI) { + DEBUG(dbgs() << "Evaluating phi nodes for speculation:\n"); + + // Savings in cost from speculating around a PHI node. 
+ SmallDenseMap<PHINode *, int, 16> CostSavingsMap;
+
+ // Remember the set of instructions that are candidates for speculation so
+ // that we can quickly walk things within that space. This prunes out
+ // instructions already available along edges, etc.
+ SmallPtrSet<Instruction *, 16> PotentialSpecSet;
+
+ // Remember the set of instructions that are (transitively) unsafe to
+ // speculate into the incoming edges of this basic block. This avoids
+ // recomputing them for each PHI node we check. This set is specific to this
+ // block though as things are pruned out of it based on what is available
+ // along incoming edges.
+ SmallPtrSet<Instruction *, 16> UnsafeSet;
+
+ // For each PHI node in this block, check whether there are immediate folding
+ // opportunities from speculation, and whether that speculation will be
+ // valid. This determines the set of safe PHIs to speculate.
+ PNs.erase(llvm::remove_if(PNs,
+ [&](PHINode *PN) {
+ return !isSafeAndProfitableToSpeculateAroundPHI(
+ *PN, CostSavingsMap, PotentialSpecSet,
+ UnsafeSet, DT, TTI);
+ }),
+ PNs.end());
+ // If no PHIs were profitable, skip.
+ if (PNs.empty()) {
+ DEBUG(dbgs() << " No safe and profitable PHIs found!\n");
+ return false;
+ }
+
+ // We need to know how much speculation will cost, which is determined by how
+ // many incoming edges will need a copy of each speculated instruction.
+ SmallSetVector<BasicBlock *, 16> PredSet;
+ for (auto *PredBB : PNs[0]->blocks()) {
+ if (!PredSet.insert(PredBB))
+ continue;
+
+ // We cannot speculate when a predecessor is an indirect branch.
+ // FIXME: We also can't reliably create a non-critical edge block for
+ // speculation if the predecessor is an invoke. This doesn't seem
+ // fundamental and we should probably be splitting critical edges
+ // differently.
+ if (isa<IndirectBrInst>(PredBB->getTerminator()) ||
+ isa<InvokeInst>(PredBB->getTerminator())) {
+ DEBUG(dbgs() << " Invalid: predecessor terminator: " << PredBB->getName()
+ << "\n");
+ return false;
+ }
+ }
+ if (PredSet.size() < 2) {
+ DEBUG(dbgs() << " Unimportant: phi with only one predecessor\n");
+ return false;
+ }
+
+ SmallVector<PHINode *, 16> SpecPNs = findProfitablePHIs(
+ PNs, CostSavingsMap, PotentialSpecSet, PredSet.size(), DT, TTI);
+ if (SpecPNs.empty())
+ // Nothing to do. 
+ return false; + + speculatePHIs(SpecPNs, PotentialSpecSet, PredSet, DT); + return true; +} + +PreservedAnalyses SpeculateAroundPHIsPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + + bool Changed = false; + for (auto *BB : ReversePostOrderTraversal<Function *>(&F)) { + SmallVector<PHINode *, 16> PNs; + auto BBI = BB->begin(); + while (auto *PN = dyn_cast<PHINode>(&*BBI)) { + PNs.push_back(PN); + ++BBI; + } + + if (PNs.empty()) + continue; + + Changed |= tryToSpeculatePHIs(PNs, DT, TTI); + } + + if (!Changed) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + return PA; +} diff --git a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp index 8b8d6590aa6a..ce40af1223f6 100644 --- a/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp +++ b/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -1,4 +1,4 @@ -//===-- StraightLineStrengthReduce.cpp - ------------------------*- C++ -*-===// +//===- StraightLineStrengthReduce.cpp - -----------------------------------===// // // The LLVM Compiler Infrastructure // @@ -55,26 +55,45 @@ // // - When (i' - i) is constant but i and i' are not, we could still perform // SLSR. + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" -#include "llvm/Support/raw_ostream.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/Local.h" +#include <cassert> +#include <cstdint> +#include <limits> #include <list> #include <vector> using namespace llvm; using namespace PatternMatch; -namespace { +static const unsigned UnknownAddressSpace = + std::numeric_limits<unsigned>::max(); -static const unsigned UnknownAddressSpace = ~0u; +namespace { class StraightLineStrengthReduce : public FunctionPass { public: @@ -88,20 +107,22 @@ public: GEP, // &B[..][i * S][..] }; - Candidate() - : CandidateKind(Invalid), Base(nullptr), Index(nullptr), - Stride(nullptr), Ins(nullptr), Basis(nullptr) {} + Candidate() = default; Candidate(Kind CT, const SCEV *B, ConstantInt *Idx, Value *S, Instruction *I) - : CandidateKind(CT), Base(B), Index(Idx), Stride(S), Ins(I), - Basis(nullptr) {} - Kind CandidateKind; - const SCEV *Base; + : CandidateKind(CT), Base(B), Index(Idx), Stride(S), Ins(I) {} + + Kind CandidateKind = Invalid; + + const SCEV *Base = nullptr; + // Note that Index and Stride of a GEP candidate do not necessarily have the // same integer type. In that case, during rewriting, Stride will be // sign-extended or truncated to Index's type. - ConstantInt *Index; - Value *Stride; + ConstantInt *Index = nullptr; + + Value *Stride = nullptr; + // The instruction this candidate corresponds to. 
It helps us to rewrite a // candidate with respect to its immediate basis. Note that one instruction // can correspond to multiple candidates depending on how you associate the @@ -116,16 +137,16 @@ public: // or // // <Base: b, Index: 2, Stride: a + 1> - Instruction *Ins; + Instruction *Ins = nullptr; + // Points to the immediate basis of this candidate, or nullptr if we cannot // find any basis for this candidate. - Candidate *Basis; + Candidate *Basis = nullptr; }; static char ID; - StraightLineStrengthReduce() - : FunctionPass(ID), DL(nullptr), DT(nullptr), TTI(nullptr) { + StraightLineStrengthReduce() : FunctionPass(ID) { initializeStraightLineStrengthReducePass(*PassRegistry::getPassRegistry()); } @@ -148,46 +169,58 @@ private: // Returns true if Basis is a basis for C, i.e., Basis dominates C and they // share the same base and stride. bool isBasisFor(const Candidate &Basis, const Candidate &C); + // Returns whether the candidate can be folded into an addressing mode. bool isFoldable(const Candidate &C, TargetTransformInfo *TTI, const DataLayout *DL); + // Returns true if C is already in a simplest form and not worth being // rewritten. bool isSimplestForm(const Candidate &C); + // Checks whether I is in a candidate form. If so, adds all the matching forms // to Candidates, and tries to find the immediate basis for each of them. void allocateCandidatesAndFindBasis(Instruction *I); + // Allocate candidates and find bases for Add instructions. void allocateCandidatesAndFindBasisForAdd(Instruction *I); + // Given I = LHS + RHS, factors RHS into i * S and makes (LHS + i * S) a // candidate. void allocateCandidatesAndFindBasisForAdd(Value *LHS, Value *RHS, Instruction *I); // Allocate candidates and find bases for Mul instructions. void allocateCandidatesAndFindBasisForMul(Instruction *I); + // Splits LHS into Base + Index and, if succeeds, calls // allocateCandidatesAndFindBasis. void allocateCandidatesAndFindBasisForMul(Value *LHS, Value *RHS, Instruction *I); + // Allocate candidates and find bases for GetElementPtr instructions. void allocateCandidatesAndFindBasisForGEP(GetElementPtrInst *GEP); + // A helper function that scales Idx with ElementSize before invoking // allocateCandidatesAndFindBasis. void allocateCandidatesAndFindBasisForGEP(const SCEV *B, ConstantInt *Idx, Value *S, uint64_t ElementSize, Instruction *I); + // Adds the given form <CT, B, Idx, S> to Candidates, and finds its immediate // basis. void allocateCandidatesAndFindBasis(Candidate::Kind CT, const SCEV *B, ConstantInt *Idx, Value *S, Instruction *I); + // Rewrites candidate C with respect to Basis. void rewriteCandidateWithBasis(const Candidate &C, const Candidate &Basis); + // A helper function that factors ArrayIdx to a product of a stride and a // constant index, and invokes allocateCandidatesAndFindBasis with the // factorings. void factorArrayIndex(Value *ArrayIdx, const SCEV *Base, uint64_t ElementSize, GetElementPtrInst *GEP); + // Emit code that computes the "bump" from Basis to C. 
If the candidate is a // GEP and the bump is not divisible by the element size of the GEP, this // function sets the BumpWithUglyGEP flag to notify its caller to bump the @@ -196,19 +229,22 @@ private: IRBuilder<> &Builder, const DataLayout *DL, bool &BumpWithUglyGEP); - const DataLayout *DL; - DominatorTree *DT; + const DataLayout *DL = nullptr; + DominatorTree *DT = nullptr; ScalarEvolution *SE; - TargetTransformInfo *TTI; + TargetTransformInfo *TTI = nullptr; std::list<Candidate> Candidates; + // Temporarily holds all instructions that are unlinked (but not deleted) by // rewriteCandidateWithBasis. These instructions will be actually removed // after all rewriting finishes. std::vector<Instruction *> UnlinkedInstructions; }; -} // anonymous namespace + +} // end anonymous namespace char StraightLineStrengthReduce::ID = 0; + INITIALIZE_PASS_BEGIN(StraightLineStrengthReduce, "slsr", "Straight line strength reduction", false, false) INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) @@ -650,8 +686,8 @@ void StraightLineStrengthReduce::rewriteCandidateWithBasis( else Reduced = Builder.CreateGEP(nullptr, Basis.Ins, Bump); } + break; } - break; default: llvm_unreachable("C.CandidateKind is invalid"); }; diff --git a/lib/Transforms/Scalar/StructurizeCFG.cpp b/lib/Transforms/Scalar/StructurizeCFG.cpp index 0cccb415efdb..2972e1cff9a4 100644 --- a/lib/Transforms/Scalar/StructurizeCFG.cpp +++ b/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -1,4 +1,4 @@ -//===-- StructurizeCFG.cpp ------------------------------------------------===// +//===- StructurizeCFG.cpp -------------------------------------------------===// // // The LLVM Compiler Infrastructure // @@ -7,49 +7,72 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/PostOrderIterator.h" -#include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/DivergenceAnalysis.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/RegionInfo.h" #include "llvm/Analysis/RegionIterator.h" #include "llvm/Analysis/RegionPass.h" -#include "llvm/IR/Module.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/SSAUpdater.h" +#include <algorithm> +#include <cassert> +#include <utility> using namespace llvm; using namespace llvm::PatternMatch; #define DEBUG_TYPE "structurizecfg" +// The name for newly created blocks. +static const char *const FlowBlockName = "Flow"; + namespace { // Definition of the complex types used in this pass. 
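An illustrative aside looking back at the StraightLineStrengthReduce hunks above (invented variable names, not from the patch): a candidate <Base, Index, Stride> stands for a value of the form Base + Index * Stride, and rewriting a candidate with respect to its immediate basis replaces the multiply with the "bump" between the two, e.g.:

#include <cstdio>

int main() {
  long b = 100, s = 7;
  long x1 = b + 1 * s; // candidate <Base: b, Index: 1, Stride: s>
  long x2 = b + 2 * s; // candidate <Base: b, Index: 2, Stride: s>, basis: x1
  long x2r = x1 + s;   // strength-reduced form: bump = (2 - 1) * s
  std::printf("%ld %ld\n", x2, x2r); // both print 114
  return 0;
}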
-typedef std::pair<BasicBlock *, Value *> BBValuePair; +using BBValuePair = std::pair<BasicBlock *, Value *>; -typedef SmallVector<RegionNode*, 8> RNVector; -typedef SmallVector<BasicBlock*, 8> BBVector; -typedef SmallVector<BranchInst*, 8> BranchVector; -typedef SmallVector<BBValuePair, 2> BBValueVector; +using RNVector = SmallVector<RegionNode *, 8>; +using BBVector = SmallVector<BasicBlock *, 8>; +using BranchVector = SmallVector<BranchInst *, 8>; +using BBValueVector = SmallVector<BBValuePair, 2>; -typedef SmallPtrSet<BasicBlock *, 8> BBSet; +using BBSet = SmallPtrSet<BasicBlock *, 8>; -typedef MapVector<PHINode *, BBValueVector> PhiMap; -typedef MapVector<BasicBlock *, BBVector> BB2BBVecMap; +using PhiMap = MapVector<PHINode *, BBValueVector>; +using BB2BBVecMap = MapVector<BasicBlock *, BBVector>; -typedef DenseMap<BasicBlock *, PhiMap> BBPhiMap; -typedef DenseMap<BasicBlock *, Value *> BBPredicates; -typedef DenseMap<BasicBlock *, BBPredicates> PredMap; -typedef DenseMap<BasicBlock *, BasicBlock*> BB2BBMap; - -// The name for newly created blocks. -static const char *const FlowBlockName = "Flow"; +using BBPhiMap = DenseMap<BasicBlock *, PhiMap>; +using BBPredicates = DenseMap<BasicBlock *, Value *>; +using PredMap = DenseMap<BasicBlock *, BBPredicates>; +using BB2BBMap = DenseMap<BasicBlock *, BasicBlock *>; /// Finds the nearest common dominator of a set of BasicBlocks. /// @@ -736,7 +759,6 @@ void StructurizeCFG::wireFlow(bool ExitUseAllowed, changeExit(PrevNode, Node->getEntry(), true); } PrevNode = Node; - } else { // Insert extra prefix node (or reuse last one) BasicBlock *Flow = needPrefix(false); diff --git a/lib/Transforms/Scalar/TailRecursionElimination.cpp b/lib/Transforms/Scalar/TailRecursionElimination.cpp index 90c5c243f464..2a1106b41de2 100644 --- a/lib/Transforms/Scalar/TailRecursionElimination.cpp +++ b/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -60,6 +60,7 @@ #include "llvm/Analysis/InlineCost.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/CFG.h" #include "llvm/IR/CallSite.h" @@ -78,7 +79,6 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" using namespace llvm; #define DEBUG_TYPE "tailcallelim" @@ -177,7 +177,8 @@ struct AllocaDerivedValueTracker { }; } -static bool markTails(Function &F, bool &AllCallsAreTailCalls) { +static bool markTails(Function &F, bool &AllCallsAreTailCalls, + OptimizationRemarkEmitter *ORE) { if (F.callsFunctionThatReturnsTwice()) return false; AllCallsAreTailCalls = true; @@ -228,7 +229,7 @@ static bool markTails(Function &F, bool &AllCallsAreTailCalls) { Escaped = ESCAPED; CallInst *CI = dyn_cast<CallInst>(&I); - if (!CI || CI->isTailCall()) + if (!CI || CI->isTailCall() || isa<DbgInfoIntrinsic>(&I)) continue; bool IsNoTail = CI->isNoTailCall() || CI->hasOperandBundles(); @@ -252,9 +253,11 @@ static bool markTails(Function &F, bool &AllCallsAreTailCalls) { break; } if (SafeToTail) { - emitOptimizationRemark( - F.getContext(), "tailcallelim", F, CI->getDebugLoc(), - "marked this readnone call a tail call candidate"); + using namespace ore; + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "tailcall-readnone", CI) + << "marked as tail call candidate (readnone)"; + }); CI->setTailCall(); Modified = true; continue; @@ -299,9 +302,7 @@ static bool 
markTails(Function &F, bool &AllCallsAreTailCalls) { if (Visited[CI->getParent()] != ESCAPED) { // If the escape point was part way through the block, calls after the // escape point wouldn't have been put into DeferredTails. - emitOptimizationRemark(F.getContext(), "tailcallelim", F, - CI->getDebugLoc(), - "marked this call a tail call candidate"); + DEBUG(dbgs() << "Marked as tail call candidate: " << *CI << "\n"); CI->setTailCall(); Modified = true; } else { @@ -330,7 +331,7 @@ static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) { // Writes to memory only matter if they may alias the pointer // being loaded from. const DataLayout &DL = L->getModule()->getDataLayout(); - if ((AA->getModRefInfo(CI, MemoryLocation::get(L)) & MRI_Mod) || + if (isModSet(AA->getModRefInfo(CI, MemoryLocation::get(L))) || !isSafeToLoadUnconditionally(L->getPointerOperand(), L->getAlignment(), DL, L)) return false; @@ -491,7 +492,8 @@ static bool eliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, BasicBlock *&OldEntry, bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs, - AliasAnalysis *AA) { + AliasAnalysis *AA, + OptimizationRemarkEmitter *ORE) { // If we are introducing accumulator recursion to eliminate operations after // the call instruction that are both associative and commutative, the initial // value for the accumulator is placed in this variable. If this value is set @@ -551,8 +553,11 @@ static bool eliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, BasicBlock *BB = Ret->getParent(); Function *F = BB->getParent(); - emitOptimizationRemark(F->getContext(), "tailcallelim", *F, CI->getDebugLoc(), - "transforming tail recursion to loop"); + using namespace ore; + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "tailcall-recursion", CI) + << "transforming tail recursion into loop"; + }); // OK! We can transform this tail call. If this is the first one found, // create the new entry block, allowing us to branch back to the old entry. @@ -666,13 +671,11 @@ static bool eliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, return true; } -static bool foldReturnAndProcessPred(BasicBlock *BB, ReturnInst *Ret, - BasicBlock *&OldEntry, - bool &TailCallsAreMarkedTail, - SmallVectorImpl<PHINode *> &ArgumentPHIs, - bool CannotTailCallElimCallsMarkedTail, - const TargetTransformInfo *TTI, - AliasAnalysis *AA) { +static bool foldReturnAndProcessPred( + BasicBlock *BB, ReturnInst *Ret, BasicBlock *&OldEntry, + bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs, + bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI, + AliasAnalysis *AA, OptimizationRemarkEmitter *ORE) { bool Change = false; // Make sure this block is a trivial return block. 
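An illustrative aside for readers following the new optimization remarks in this file: the rewrite that the "transforming tail recursion into loop" remark above refers to has roughly the following shape, shown here as plain C++ rather than LLVM IR (the pass itself carries the updated arguments in the ArgumentPHIs it creates; function names below are invented).

#include <cstdio>

// Before: a self-recursive call in tail position.
static long sumTo(long n, long acc) {
  if (n == 0)
    return acc;
  return sumTo(n - 1, acc + n); // tail call eligible for elimination
}

// After: the recursive call becomes a branch back to the loop header and the
// arguments are simply updated in place.
static long sumToLoop(long n, long acc) {
  for (;;) {
    if (n == 0)
      return acc;
    acc += n;
    n -= 1;
  }
}

int main() {
  std::printf("%ld %ld\n", sumTo(10, 0), sumToLoop(10, 0)); // both print 55
  return 0;
}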
@@ -708,7 +711,7 @@ static bool foldReturnAndProcessPred(BasicBlock *BB, ReturnInst *Ret, BB->eraseFromParent(); eliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, AA); + ArgumentPHIs, AA, ORE); ++NumRetDuped; Change = true; } @@ -722,23 +725,25 @@ static bool processReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry, SmallVectorImpl<PHINode *> &ArgumentPHIs, bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI, - AliasAnalysis *AA) { + AliasAnalysis *AA, + OptimizationRemarkEmitter *ORE) { CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail, TTI); if (!CI) return false; return eliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, AA); + ArgumentPHIs, AA, ORE); } static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI, - AliasAnalysis *AA) { + AliasAnalysis *AA, + OptimizationRemarkEmitter *ORE) { if (F.getFnAttribute("disable-tail-calls").getValueAsString() == "true") return false; bool MadeChange = false; bool AllCallsAreTailCalls = false; - MadeChange |= markTails(F, AllCallsAreTailCalls); + MadeChange |= markTails(F, AllCallsAreTailCalls, ORE); if (!AllCallsAreTailCalls) return MadeChange; @@ -765,13 +770,13 @@ static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI, for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; /*in loop*/) { BasicBlock *BB = &*BBI++; // foldReturnAndProcessPred may delete BB. if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) { - bool Change = - processReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail, - ArgumentPHIs, !CanTRETailMarkedCall, TTI, AA); + bool Change = processReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail, + ArgumentPHIs, !CanTRETailMarkedCall, + TTI, AA, ORE); if (!Change && BB->getFirstNonPHIOrDbg() == Ret) Change = foldReturnAndProcessPred(BB, Ret, OldEntry, TailCallsAreMarkedTail, ArgumentPHIs, - !CanTRETailMarkedCall, TTI, AA); + !CanTRETailMarkedCall, TTI, AA, ORE); MadeChange |= Change; } } @@ -802,6 +807,7 @@ struct TailCallElim : public FunctionPass { void getAnalysisUsage(AnalysisUsage &AU) const override { AU.addRequired<TargetTransformInfoWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); AU.addPreserved<GlobalsAAWrapperPass>(); } @@ -811,7 +817,8 @@ struct TailCallElim : public FunctionPass { return eliminateTailRecursion( F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F), - &getAnalysis<AAResultsWrapperPass>().getAAResults()); + &getAnalysis<AAResultsWrapperPass>().getAAResults(), + &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE()); } }; } @@ -820,6 +827,7 @@ char TailCallElim::ID = 0; INITIALIZE_PASS_BEGIN(TailCallElim, "tailcallelim", "Tail Call Elimination", false, false) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(TailCallElim, "tailcallelim", "Tail Call Elimination", false, false) @@ -833,8 +841,9 @@ PreservedAnalyses TailCallElimPass::run(Function &F, TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F); AliasAnalysis &AA = AM.getResult<AAManager>(F); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); - bool Changed = eliminateTailRecursion(F, &TTI, &AA); + bool Changed = eliminateTailRecursion(F, &TTI, &AA, &ORE); if (!Changed) return PreservedAnalyses::all(); diff --git a/lib/Transforms/Utils/ASanStackFrameLayout.cpp 
b/lib/Transforms/Utils/ASanStackFrameLayout.cpp index df9d5da9e26e..364878dc588d 100644 --- a/lib/Transforms/Utils/ASanStackFrameLayout.cpp +++ b/lib/Transforms/Utils/ASanStackFrameLayout.cpp @@ -36,9 +36,11 @@ static inline bool CompareVars(const ASanStackVariableDescription &a, // with e.g. alignment 1 and alignment 16 do not get reordered by CompareVars. static const size_t kMinAlignment = 16; +// We want to add a full redzone after every variable. // The larger the variable Size the larger is the redzone. // The resulting frame size is a multiple of Alignment. -static size_t VarAndRedzoneSize(size_t Size, size_t Alignment) { +static size_t VarAndRedzoneSize(size_t Size, size_t Granularity, + size_t Alignment) { size_t Res = 0; if (Size <= 4) Res = 16; else if (Size <= 16) Res = 32; @@ -46,7 +48,7 @@ static size_t VarAndRedzoneSize(size_t Size, size_t Alignment) { else if (Size <= 512) Res = Size + 64; else if (Size <= 4096) Res = Size + 128; else Res = Size + 256; - return alignTo(Res, Alignment); + return alignTo(std::max(Res, 2 * Granularity), Alignment); } ASanStackFrameLayout @@ -80,7 +82,8 @@ ComputeASanStackFrameLayout(SmallVectorImpl<ASanStackVariableDescription> &Vars, assert(Size > 0); size_t NextAlignment = IsLast ? Granularity : std::max(Granularity, Vars[i + 1].Alignment); - size_t SizeWithRedzone = VarAndRedzoneSize(Size, NextAlignment); + size_t SizeWithRedzone = VarAndRedzoneSize(Size, Granularity, + NextAlignment); Vars[i].Offset = Offset; Offset += SizeWithRedzone; } diff --git a/lib/Transforms/Utils/AddDiscriminators.cpp b/lib/Transforms/Utils/AddDiscriminators.cpp index 4c9746b8c691..0f0668f24db5 100644 --- a/lib/Transforms/Utils/AddDiscriminators.cpp +++ b/lib/Transforms/Utils/AddDiscriminators.cpp @@ -50,31 +50,45 @@ // // For more details about DWARF discriminators, please visit // http://wiki.dwarfstd.org/index.php?title=Path_Discriminators +// //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/AddDiscriminators.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/StringRef.h" #include "llvm/IR/BasicBlock.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassManager.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" +#include <utility> using namespace llvm; #define DEBUG_TYPE "add-discriminators" +// Command line option to disable discriminator generation even in the +// presence of debug information. This is only needed when debugging +// debug info generation issues. +static cl::opt<bool> NoDiscriminators( + "no-discriminators", cl::init(false), + cl::desc("Disable generation of discriminator information.")); + namespace { + // The legacy pass of AddDiscriminators. 
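The ASanStackFrameLayout change above threads the shadow granularity into VarAndRedzoneSize so that even a tiny variable is followed by at least one fully poisoned shadow granule. A small standalone illustration of the new clamp; the bucket table itself is unchanged and elided, and the numbers in the trailing comment are just a worked example:

#include <algorithm>
#include <cstddef>

// BucketSize is what the existing size ladder produced for the variable
// (e.g. 16 for a 4-byte variable). The new code clamps it to two shadow
// granules before aligning, guaranteeing a full redzone after the variable.
static size_t clampToFullRedzone(size_t BucketSize, size_t Granularity,
                                 size_t Alignment) {
  size_t Res = std::max(BucketSize, 2 * Granularity);
  return (Res + Alignment - 1) / Alignment * Alignment; // alignTo(Res, Alignment)
}

// Example: Size = 4 gives BucketSize = 16. With Granularity = Alignment = 64
// the old code returned alignTo(16, 64) = 64, so the 4-byte variable had no
// fully poisoned granule after it; the new code returns alignTo(128, 64) = 128,
// i.e. the variable's granule plus one full 64-byte redzone granule.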
struct AddDiscriminatorsLegacyPass : public FunctionPass { static char ID; // Pass identification, replacement for typeid + AddDiscriminatorsLegacyPass() : FunctionPass(ID) { initializeAddDiscriminatorsLegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -85,18 +99,12 @@ struct AddDiscriminatorsLegacyPass : public FunctionPass { } // end anonymous namespace char AddDiscriminatorsLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(AddDiscriminatorsLegacyPass, "add-discriminators", "Add DWARF path discriminators", false, false) INITIALIZE_PASS_END(AddDiscriminatorsLegacyPass, "add-discriminators", "Add DWARF path discriminators", false, false) -// Command line option to disable discriminator generation even in the -// presence of debug information. This is only needed when debugging -// debug info generation issues. -static cl::opt<bool> NoDiscriminators( - "no-discriminators", cl::init(false), - cl::desc("Disable generation of discriminator information.")); - // Create the legacy AddDiscriminatorsPass. FunctionPass *llvm::createAddDiscriminatorsPass() { return new AddDiscriminatorsLegacyPass(); @@ -166,11 +174,11 @@ static bool addDiscriminators(Function &F) { bool Changed = false; - typedef std::pair<StringRef, unsigned> Location; - typedef DenseSet<const BasicBlock *> BBSet; - typedef DenseMap<Location, BBSet> LocationBBMap; - typedef DenseMap<Location, unsigned> LocationDiscriminatorMap; - typedef DenseSet<Location> LocationSet; + using Location = std::pair<StringRef, unsigned>; + using BBSet = DenseSet<const BasicBlock *>; + using LocationBBMap = DenseMap<Location, BBSet>; + using LocationDiscriminatorMap = DenseMap<Location, unsigned>; + using LocationSet = DenseSet<Location>; LocationBBMap LBM; LocationDiscriminatorMap LDM; @@ -242,6 +250,7 @@ static bool addDiscriminators(Function &F) { bool AddDiscriminatorsLegacyPass::runOnFunction(Function &F) { return addDiscriminators(F); } + PreservedAnalyses AddDiscriminatorsPass::run(Function &F, FunctionAnalysisManager &AM) { if (!addDiscriminators(F)) diff --git a/lib/Transforms/Utils/BasicBlockUtils.cpp b/lib/Transforms/Utils/BasicBlockUtils.cpp index 3d5cbfc93f2e..606bd8baccaa 100644 --- a/lib/Transforms/Utils/BasicBlockUtils.cpp +++ b/lib/Transforms/Utils/BasicBlockUtils.cpp @@ -1,4 +1,4 @@ -//===-- BasicBlockUtils.cpp - BasicBlock Utilities -------------------------==// +//===- BasicBlockUtils.cpp - BasicBlock Utilities --------------------------==// // // The LLVM Compiler Infrastructure // @@ -13,22 +13,36 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/MemoryDependenceAnalysis.h" -#include "llvm/IR/Constant.h" -#include "llvm/IR/DataLayout.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" -#include "llvm/Support/ErrorHandling.h" -#include "llvm/Transforms/Scalar.h" 
+#include "llvm/Support/Casting.h" #include "llvm/Transforms/Utils/Local.h" -#include <algorithm> +#include <cassert> +#include <cstdint> +#include <string> +#include <utility> +#include <vector> + using namespace llvm; void llvm::DeleteDeadBlock(BasicBlock *BB) { @@ -130,8 +144,16 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, } // Begin by getting rid of unneeded PHIs. - if (isa<PHINode>(BB->front())) + SmallVector<Value *, 4> IncomingValues; + if (isa<PHINode>(BB->front())) { + for (auto &I : *BB) + if (PHINode *PN = dyn_cast<PHINode>(&I)) { + if (PN->getIncomingValue(0) != PN) + IncomingValues.push_back(PN->getIncomingValue(0)); + } else + break; FoldSingleEntryPHINodes(BB, MemDep); + } // Delete the unconditional branch from the predecessor... PredBB->getInstList().pop_back(); @@ -143,6 +165,21 @@ bool llvm::MergeBlockIntoPredecessor(BasicBlock *BB, DominatorTree *DT, // Move all definitions in the successor to the predecessor... PredBB->getInstList().splice(PredBB->end(), BB->getInstList()); + // Eliminate duplicate dbg.values describing the entry PHI node post-splice. + for (auto *Incoming : IncomingValues) { + if (isa<Instruction>(Incoming)) { + SmallVector<DbgValueInst *, 2> DbgValues; + SmallDenseSet<std::pair<DILocalVariable *, DIExpression *>, 2> + DbgValueSet; + llvm::findDbgValues(DbgValues, Incoming); + for (auto &DVI : DbgValues) { + auto R = DbgValueSet.insert({DVI->getVariable(), DVI->getExpression()}); + if (!R.second) + DVI->eraseFromParent(); + } + } + } + // Inherit predecessors name if it exists. if (!PredBB->hasName()) PredBB->takeName(BB); @@ -454,7 +491,7 @@ BasicBlock *llvm::SplitBlockPredecessors(BasicBlock *BB, // node becomes an incoming value for BB's phi node. However, if the Preds // list is empty, we need to insert dummy entries into the PHI nodes in BB to // account for the newly created predecessor. - if (Preds.size() == 0) { + if (Preds.empty()) { // Insert dummy values as the incoming value. 
for (BasicBlock::iterator I = BB->begin(); isa<PHINode>(I); ++I) cast<PHINode>(I)->addIncoming(UndefValue::get(I->getType()), NewBB); @@ -675,7 +712,6 @@ void llvm::SplitBlockAndInsertIfThenElse(Value *Cond, Instruction *SplitBefore, ReplaceInstWithInst(HeadOldTerm, HeadNewTerm); } - Value *llvm::GetIfCondition(BasicBlock *BB, BasicBlock *&IfTrue, BasicBlock *&IfFalse) { PHINode *SomePHI = dyn_cast<PHINode>(BB->begin()); diff --git a/lib/Transforms/Utils/BreakCriticalEdges.cpp b/lib/Transforms/Utils/BreakCriticalEdges.cpp index 175cbd2ce0df..3653c307619b 100644 --- a/lib/Transforms/Utils/BreakCriticalEdges.cpp +++ b/lib/Transforms/Utils/BreakCriticalEdges.cpp @@ -16,9 +16,11 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/BreakCriticalEdges.h" +#include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/CFG.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/IR/CFG.h" @@ -28,6 +30,8 @@ #include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/ValueMapper.h" using namespace llvm; #define DEBUG_TYPE "break-crit-edges" @@ -198,59 +202,23 @@ llvm::SplitCriticalEdge(TerminatorInst *TI, unsigned SuccNum, if (!DT && !LI) return NewBB; - // Now update analysis information. Since the only predecessor of NewBB is - // the TIBB, TIBB clearly dominates NewBB. TIBB usually doesn't dominate - // anything, as there are other successors of DestBB. However, if all other - // predecessors of DestBB are already dominated by DestBB (e.g. DestBB is a - // loop header) then NewBB dominates DestBB. - SmallVector<BasicBlock*, 8> OtherPreds; - - // If there is a PHI in the block, loop over predecessors with it, which is - // faster than iterating pred_begin/end. - if (PHINode *PN = dyn_cast<PHINode>(DestBB->begin())) { - for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) - if (PN->getIncomingBlock(i) != NewBB) - OtherPreds.push_back(PN->getIncomingBlock(i)); - } else { - for (pred_iterator I = pred_begin(DestBB), E = pred_end(DestBB); - I != E; ++I) { - BasicBlock *P = *I; - if (P != NewBB) - OtherPreds.push_back(P); - } - } - - bool NewBBDominatesDestBB = true; - - // Should we update DominatorTree information? if (DT) { - DomTreeNode *TINode = DT->getNode(TIBB); - - // The new block is not the immediate dominator for any other nodes, but - // TINode is the immediate dominator for the new node. + // Update the DominatorTree. + // ---> NewBB -----\ + // / V + // TIBB -------\\------> DestBB // - if (TINode) { // Don't break unreachable code! - DomTreeNode *NewBBNode = DT->addNewBlock(NewBB, TIBB); - DomTreeNode *DestBBNode = nullptr; - - // If NewBBDominatesDestBB hasn't been computed yet, do so with DT. - if (!OtherPreds.empty()) { - DestBBNode = DT->getNode(DestBB); - while (!OtherPreds.empty() && NewBBDominatesDestBB) { - if (DomTreeNode *OPNode = DT->getNode(OtherPreds.back())) - NewBBDominatesDestBB = DT->dominates(DestBBNode, OPNode); - OtherPreds.pop_back(); - } - OtherPreds.clear(); - } - - // If NewBBDominatesDestBB, then NewBB dominates DestBB, otherwise it - // doesn't dominate anything. 
- if (NewBBDominatesDestBB) { - if (!DestBBNode) DestBBNode = DT->getNode(DestBB); - DT->changeImmediateDominator(DestBBNode, NewBBNode); - } - } + // First, inform the DT about the new path from TIBB to DestBB via NewBB, + // then delete the old edge from TIBB to DestBB. By doing this in that order + // DestBB stays reachable in the DT the whole time and its subtree doesn't + // get disconnected. + SmallVector<DominatorTree::UpdateType, 3> Updates; + Updates.push_back({DominatorTree::Insert, TIBB, NewBB}); + Updates.push_back({DominatorTree::Insert, NewBB, DestBB}); + if (llvm::find(successors(TIBB), DestBB) == succ_end(TIBB)) + Updates.push_back({DominatorTree::Delete, TIBB, DestBB}); + + DT->applyUpdates(Updates); } // Update LoopInfo if it is around. @@ -326,3 +294,159 @@ llvm::SplitCriticalEdge(TerminatorInst *TI, unsigned SuccNum, return NewBB; } + +// Return the unique indirectbr predecessor of a block. This may return null +// even if such a predecessor exists, if it's not useful for splitting. +// If a predecessor is found, OtherPreds will contain all other (non-indirectbr) +// predecessors of BB. +static BasicBlock * +findIBRPredecessor(BasicBlock *BB, SmallVectorImpl<BasicBlock *> &OtherPreds) { + // If the block doesn't have any PHIs, we don't care about it, since there's + // no point in splitting it. + PHINode *PN = dyn_cast<PHINode>(BB->begin()); + if (!PN) + return nullptr; + + // Verify we have exactly one IBR predecessor. + // Conservatively bail out if one of the other predecessors is not a "regular" + // terminator (that is, not a switch or a br). + BasicBlock *IBB = nullptr; + for (unsigned Pred = 0, E = PN->getNumIncomingValues(); Pred != E; ++Pred) { + BasicBlock *PredBB = PN->getIncomingBlock(Pred); + TerminatorInst *PredTerm = PredBB->getTerminator(); + switch (PredTerm->getOpcode()) { + case Instruction::IndirectBr: + if (IBB) + return nullptr; + IBB = PredBB; + break; + case Instruction::Br: + case Instruction::Switch: + OtherPreds.push_back(PredBB); + continue; + default: + return nullptr; + } + } + + return IBB; +} + +bool llvm::SplitIndirectBrCriticalEdges(Function &F, + BranchProbabilityInfo *BPI, + BlockFrequencyInfo *BFI) { + // Check whether the function has any indirectbrs, and collect which blocks + // they may jump to. Since most functions don't have indirect branches, + // this lowers the common case's overhead to O(Blocks) instead of O(Edges). + SmallSetVector<BasicBlock *, 16> Targets; + for (auto &BB : F) { + auto *IBI = dyn_cast<IndirectBrInst>(BB.getTerminator()); + if (!IBI) + continue; + + for (unsigned Succ = 0, E = IBI->getNumSuccessors(); Succ != E; ++Succ) + Targets.insert(IBI->getSuccessor(Succ)); + } + + if (Targets.empty()) + return false; + + bool ShouldUpdateAnalysis = BPI && BFI; + bool Changed = false; + for (BasicBlock *Target : Targets) { + SmallVector<BasicBlock *, 16> OtherPreds; + BasicBlock *IBRPred = findIBRPredecessor(Target, OtherPreds); + // If we did not found an indirectbr, or the indirectbr is the only + // incoming edge, this isn't the kind of edge we're looking for. + if (!IBRPred || OtherPreds.empty()) + continue; + + // Don't even think about ehpads/landingpads. + Instruction *FirstNonPHI = Target->getFirstNonPHI(); + if (FirstNonPHI->isEHPad() || Target->isLandingPad()) + continue; + + BasicBlock *BodyBlock = Target->splitBasicBlock(FirstNonPHI, ".split"); + if (ShouldUpdateAnalysis) { + // Copy the BFI/BPI from Target to BodyBlock. 
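The SplitCriticalEdge rewrite above drops the manual immediate-dominator reasoning in favor of the incremental DominatorTree updater: the two new edges are inserted before the old one is deleted, so DestBB never looks unreachable while the tree is being patched. The same idiom for a generic edge split, a sketch using only the API shown in the diff:

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/CFG.h"
#include "llvm/IR/Dominators.h"

using namespace llvm;

// The CFG edge From->To has just been rerouted through a new block Mid.
// Announce the insertions before the deletion so To's subtree never gets
// disconnected from the dominator tree.
static void updateDomTreeForSplitEdge(DominatorTree &DT, BasicBlock *From,
                                      BasicBlock *Mid, BasicBlock *To) {
  SmallVector<DominatorTree::UpdateType, 3> Updates;
  Updates.push_back({DominatorTree::Insert, From, Mid});
  Updates.push_back({DominatorTree::Insert, Mid, To});
  // Only report the deletion if no other From->To edge remains in the IR.
  if (llvm::find(successors(From), To) == succ_end(From))
    Updates.push_back({DominatorTree::Delete, From, To});
  DT.applyUpdates(Updates);
}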
+ for (unsigned I = 0, E = BodyBlock->getTerminator()->getNumSuccessors(); + I < E; ++I) + BPI->setEdgeProbability(BodyBlock, I, + BPI->getEdgeProbability(Target, I)); + BFI->setBlockFreq(BodyBlock, BFI->getBlockFreq(Target).getFrequency()); + } + // It's possible Target was its own successor through an indirectbr. + // In this case, the indirectbr now comes from BodyBlock. + if (IBRPred == Target) + IBRPred = BodyBlock; + + // At this point Target only has PHIs, and BodyBlock has the rest of the + // block's body. Create a copy of Target that will be used by the "direct" + // preds. + ValueToValueMapTy VMap; + BasicBlock *DirectSucc = CloneBasicBlock(Target, VMap, ".clone", &F); + + BlockFrequency BlockFreqForDirectSucc; + for (BasicBlock *Pred : OtherPreds) { + // If the target is a loop to itself, then the terminator of the split + // block (BodyBlock) needs to be updated. + BasicBlock *Src = Pred != Target ? Pred : BodyBlock; + Src->getTerminator()->replaceUsesOfWith(Target, DirectSucc); + if (ShouldUpdateAnalysis) + BlockFreqForDirectSucc += BFI->getBlockFreq(Src) * + BPI->getEdgeProbability(Src, DirectSucc); + } + if (ShouldUpdateAnalysis) { + BFI->setBlockFreq(DirectSucc, BlockFreqForDirectSucc.getFrequency()); + BlockFrequency NewBlockFreqForTarget = + BFI->getBlockFreq(Target) - BlockFreqForDirectSucc; + BFI->setBlockFreq(Target, NewBlockFreqForTarget.getFrequency()); + BPI->eraseBlock(Target); + } + + // Ok, now fix up the PHIs. We know the two blocks only have PHIs, and that + // they are clones, so the number of PHIs are the same. + // (a) Remove the edge coming from IBRPred from the "Direct" PHI + // (b) Leave that as the only edge in the "Indirect" PHI. + // (c) Merge the two in the body block. + BasicBlock::iterator Indirect = Target->begin(), + End = Target->getFirstNonPHI()->getIterator(); + BasicBlock::iterator Direct = DirectSucc->begin(); + BasicBlock::iterator MergeInsert = BodyBlock->getFirstInsertionPt(); + + assert(&*End == Target->getTerminator() && + "Block was expected to only contain PHIs"); + + while (Indirect != End) { + PHINode *DirPHI = cast<PHINode>(Direct); + PHINode *IndPHI = cast<PHINode>(Indirect); + + // Now, clean up - the direct block shouldn't get the indirect value, + // and vice versa. + DirPHI->removeIncomingValue(IBRPred); + Direct++; + + // Advance the pointer here, to avoid invalidation issues when the old + // PHI is erased. + Indirect++; + + PHINode *NewIndPHI = PHINode::Create(IndPHI->getType(), 1, "ind", IndPHI); + NewIndPHI->addIncoming(IndPHI->getIncomingValueForBlock(IBRPred), + IBRPred); + + // Create a PHI in the body block, to merge the direct and indirect + // predecessors. 
+ PHINode *MergePHI = + PHINode::Create(IndPHI->getType(), 2, "merge", &*MergeInsert); + MergePHI->addIncoming(NewIndPHI, Target); + MergePHI->addIncoming(DirPHI, DirectSucc); + + IndPHI->replaceAllUsesWith(MergePHI); + IndPHI->eraseFromParent(); + } + + Changed = true; + } + + return Changed; +} diff --git a/lib/Transforms/Utils/BypassSlowDivision.cpp b/lib/Transforms/Utils/BypassSlowDivision.cpp index 83ec7f55d1af..f711b192f604 100644 --- a/lib/Transforms/Utils/BypassSlowDivision.cpp +++ b/lib/Transforms/Utils/BypassSlowDivision.cpp @@ -1,4 +1,4 @@ -//===-- BypassSlowDivision.cpp - Bypass slow division ---------------------===// +//===- BypassSlowDivision.cpp - Bypass slow division ----------------------===// // // The LLVM Compiler Infrastructure // @@ -17,27 +17,32 @@ #include "llvm/Transforms/Utils/BypassSlowDivision.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/KnownBits.h" #include "llvm/Transforms/Utils/Local.h" +#include <cassert> +#include <cstdint> using namespace llvm; #define DEBUG_TYPE "bypass-slow-division" namespace { - struct DivOpInfo { - bool SignedOp; - Value *Dividend; - Value *Divisor; - - DivOpInfo(bool InSignedOp, Value *InDividend, Value *InDivisor) - : SignedOp(InSignedOp), Dividend(InDividend), Divisor(InDivisor) {} - }; struct QuotRemPair { Value *Quotient; @@ -55,38 +60,11 @@ namespace { Value *Quotient = nullptr; Value *Remainder = nullptr; }; -} - -namespace llvm { - template<> - struct DenseMapInfo<DivOpInfo> { - static bool isEqual(const DivOpInfo &Val1, const DivOpInfo &Val2) { - return Val1.SignedOp == Val2.SignedOp && - Val1.Dividend == Val2.Dividend && - Val1.Divisor == Val2.Divisor; - } - - static DivOpInfo getEmptyKey() { - return DivOpInfo(false, nullptr, nullptr); - } - static DivOpInfo getTombstoneKey() { - return DivOpInfo(true, nullptr, nullptr); - } +using DivCacheTy = DenseMap<DivRemMapKey, QuotRemPair>; +using BypassWidthsTy = DenseMap<unsigned, unsigned>; +using VisitedSetTy = SmallPtrSet<Instruction *, 4>; - static unsigned getHashValue(const DivOpInfo &Val) { - return (unsigned)(reinterpret_cast<uintptr_t>(Val.Dividend) ^ - reinterpret_cast<uintptr_t>(Val.Divisor)) ^ - (unsigned)Val.SignedOp; - } - }; - - typedef DenseMap<DivOpInfo, QuotRemPair> DivCacheTy; - typedef DenseMap<unsigned, unsigned> BypassWidthsTy; - typedef SmallPtrSet<Instruction *, 4> VisitedSetTy; -} - -namespace { enum ValueRange { /// Operand definitely fits into BypassType. No runtime checks are needed. 
VALRNG_KNOWN_SHORT, @@ -116,17 +94,21 @@ class FastDivInsertionTask { return SlowDivOrRem->getOpcode() == Instruction::SDiv || SlowDivOrRem->getOpcode() == Instruction::SRem; } + bool isDivisionOp() { return SlowDivOrRem->getOpcode() == Instruction::SDiv || SlowDivOrRem->getOpcode() == Instruction::UDiv; } + Type *getSlowType() { return SlowDivOrRem->getType(); } public: FastDivInsertionTask(Instruction *I, const BypassWidthsTy &BypassWidths); + Value *getReplacement(DivCacheTy &Cache); }; -} // anonymous namespace + +} // end anonymous namespace FastDivInsertionTask::FastDivInsertionTask(Instruction *I, const BypassWidthsTy &BypassWidths) { @@ -175,7 +157,7 @@ Value *FastDivInsertionTask::getReplacement(DivCacheTy &Cache) { // Then, look for a value in Cache. Value *Dividend = SlowDivOrRem->getOperand(0); Value *Divisor = SlowDivOrRem->getOperand(1); - DivOpInfo Key(isSignedOp(), Dividend, Divisor); + DivRemMapKey Key(isSignedOp(), Dividend, Divisor); auto CacheI = Cache.find(Key); if (CacheI == Cache.end()) { @@ -225,7 +207,7 @@ bool FastDivInsertionTask::isHashLikeValue(Value *V, VisitedSetTy &Visited) { C = dyn_cast<ConstantInt>(cast<BitCastInst>(Op1)->getOperand(0)); return C && C->getValue().getMinSignedBits() > BypassType->getBitWidth(); } - case Instruction::PHI: { + case Instruction::PHI: // Stop IR traversal in case of a crazy input code. This limits recursion // depth. if (Visited.size() >= 16) @@ -241,7 +223,6 @@ bool FastDivInsertionTask::isHashLikeValue(Value *V, VisitedSetTy &Visited) { return getValueRange(V, Visited) == VALRNG_LIKELY_LONG || isa<UndefValue>(V); }); - } default: return false; } @@ -371,11 +352,6 @@ Optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() { Value *Dividend = SlowDivOrRem->getOperand(0); Value *Divisor = SlowDivOrRem->getOperand(1); - if (isa<ConstantInt>(Divisor)) { - // Keep division by a constant for DAGCombiner. - return None; - } - VisitedSetTy SetL; ValueRange DividendRange = getValueRange(Dividend, SetL); if (DividendRange == VALRNG_LIKELY_LONG) @@ -391,7 +367,9 @@ Optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() { if (DividendShort && DivisorShort) { // If both operands are known to be short then just replace the long - // division with a short one in-place. + // division with a short one in-place. Since we're not introducing control + // flow in this case, narrowing the division is always a win, even if the + // divisor is a constant (and will later get replaced by a multiplication). IRBuilder<> Builder(SlowDivOrRem); Value *TruncDividend = Builder.CreateTrunc(Dividend, BypassType); @@ -401,7 +379,16 @@ Optional<QuotRemPair> FastDivInsertionTask::insertFastDivAndRem() { Value *ExtDiv = Builder.CreateZExt(TruncDiv, getSlowType()); Value *ExtRem = Builder.CreateZExt(TruncRem, getSlowType()); return QuotRemPair(ExtDiv, ExtRem); - } else if (DividendShort && !isSignedOp()) { + } + + if (isa<ConstantInt>(Divisor)) { + // If the divisor is not a constant, DAGCombiner will convert it to a + // multiplication by a magic constant. It isn't clear if it is worth + // introducing control flow to get a narrower multiply. 
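For readers who have not seen BypassSlowDivision before: the pass replaces a wide division with a runtime check plus a narrow division when the operands fit in the bypass width. The reordering above keeps the in-place narrowing for constant divisors (it adds no control flow), and only the guarded version bails out on them. A hand-written C++ analogue of the emitted pattern for a 64-bit unsigned divide with a 32-bit bypass width; this is an illustration, not code from the pass:

#include <cstdint>

// Shape of what the pass generates for "A / B" on targets where 64-bit
// division is slow but 32-bit division is fast.
static uint64_t bypassedUDiv(uint64_t A, uint64_t B) {
  if (((A | B) >> 32) == 0) {
    // Both operands fit in 32 bits: take the cheap narrow division and
    // zero-extend the result.
    return uint64_t(uint32_t(A) / uint32_t(B));
  }
  // Otherwise fall back to the full-width (slow) division.
  return A / B;
}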
+ return None; + } + + if (DividendShort && !isSignedOp()) { // If the division is unsigned and Dividend is known to be short, then // either // 1) Divisor is less or equal to Dividend, and the result can be computed diff --git a/lib/Transforms/Utils/CMakeLists.txt b/lib/Transforms/Utils/CMakeLists.txt index 83bc05d0311c..972e47f9270a 100644 --- a/lib/Transforms/Utils/CMakeLists.txt +++ b/lib/Transforms/Utils/CMakeLists.txt @@ -5,12 +5,13 @@ add_llvm_library(LLVMTransformUtils BreakCriticalEdges.cpp BuildLibCalls.cpp BypassSlowDivision.cpp + CallPromotionUtils.cpp CloneFunction.cpp CloneModule.cpp - CmpInstAnalysis.cpp CodeExtractor.cpp CtorUtils.cpp DemoteRegToStack.cpp + EntryExitInstrumenter.cpp EscapeEnumerator.cpp Evaluator.cpp FlattenCFG.cpp diff --git a/lib/Transforms/Utils/CallPromotionUtils.cpp b/lib/Transforms/Utils/CallPromotionUtils.cpp new file mode 100644 index 000000000000..eb3139ce4293 --- /dev/null +++ b/lib/Transforms/Utils/CallPromotionUtils.cpp @@ -0,0 +1,328 @@ +//===- CallPromotionUtils.cpp - Utilities for call promotion ----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements utilities useful for promoting indirect call sites to +// direct call sites. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/CallPromotionUtils.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +#define DEBUG_TYPE "call-promotion-utils" + +/// Fix-up phi nodes in an invoke instruction's normal destination. +/// +/// After versioning an invoke instruction, values coming from the original +/// block will now either be coming from the original block or the "else" block. +static void fixupPHINodeForNormalDest(InvokeInst *Invoke, BasicBlock *OrigBlock, + BasicBlock *ElseBlock, + Instruction *NewInst) { + for (auto &I : *Invoke->getNormalDest()) { + auto *Phi = dyn_cast<PHINode>(&I); + if (!Phi) + break; + int Idx = Phi->getBasicBlockIndex(OrigBlock); + if (Idx == -1) + continue; + Value *V = Phi->getIncomingValue(Idx); + if (dyn_cast<Instruction>(V) == Invoke) { + Phi->setIncomingBlock(Idx, ElseBlock); + Phi->addIncoming(NewInst, OrigBlock); + continue; + } + Phi->addIncoming(V, ElseBlock); + } +} + +/// Fix-up phi nodes in an invoke instruction's unwind destination. +/// +/// After versioning an invoke instruction, values coming from the original +/// block will now be coming from either the "then" block or the "else" block. +static void fixupPHINodeForUnwindDest(InvokeInst *Invoke, BasicBlock *OrigBlock, + BasicBlock *ThenBlock, + BasicBlock *ElseBlock) { + for (auto &I : *Invoke->getUnwindDest()) { + auto *Phi = dyn_cast<PHINode>(&I); + if (!Phi) + break; + int Idx = Phi->getBasicBlockIndex(OrigBlock); + if (Idx == -1) + continue; + auto *V = Phi->getIncomingValue(Idx); + Phi->setIncomingBlock(Idx, ThenBlock); + Phi->addIncoming(V, ElseBlock); + } +} + +/// Get the phi node having the returned value of a call or invoke instruction +/// as it's operand. 
+static bool getRetPhiNode(Instruction *Inst, BasicBlock *Block) { + BasicBlock *FromBlock = Inst->getParent(); + for (auto &I : *Block) { + PHINode *PHI = dyn_cast<PHINode>(&I); + if (!PHI) + break; + int Idx = PHI->getBasicBlockIndex(FromBlock); + if (Idx == -1) + continue; + auto *V = PHI->getIncomingValue(Idx); + if (V == Inst) + return true; + } + return false; +} + +/// Create a phi node for the returned value of a call or invoke instruction. +/// +/// After versioning a call or invoke instruction that returns a value, we have +/// to merge the value of the original and new instructions. We do this by +/// creating a phi node and replacing uses of the original instruction with this +/// phi node. +static void createRetPHINode(Instruction *OrigInst, Instruction *NewInst) { + + if (OrigInst->getType()->isVoidTy() || OrigInst->use_empty()) + return; + + BasicBlock *RetValBB = NewInst->getParent(); + if (auto *Invoke = dyn_cast<InvokeInst>(NewInst)) + RetValBB = Invoke->getNormalDest(); + BasicBlock *PhiBB = RetValBB->getSingleSuccessor(); + + if (getRetPhiNode(OrigInst, PhiBB)) + return; + + IRBuilder<> Builder(&PhiBB->front()); + PHINode *Phi = Builder.CreatePHI(OrigInst->getType(), 0); + SmallVector<User *, 16> UsersToUpdate; + for (User *U : OrigInst->users()) + UsersToUpdate.push_back(U); + for (User *U : UsersToUpdate) + U->replaceUsesOfWith(OrigInst, Phi); + Phi->addIncoming(OrigInst, OrigInst->getParent()); + Phi->addIncoming(NewInst, RetValBB); +} + +/// Cast a call or invoke instruction to the given type. +/// +/// When promoting a call site, the return type of the call site might not match +/// that of the callee. If this is the case, we have to cast the returned value +/// to the correct type. The location of the cast depends on if we have a call +/// or invoke instruction. +Instruction *createRetBitCast(CallSite CS, Type *RetTy) { + + // Save the users of the calling instruction. These uses will be changed to + // use the bitcast after we create it. + SmallVector<User *, 16> UsersToUpdate; + for (User *U : CS.getInstruction()->users()) + UsersToUpdate.push_back(U); + + // Determine an appropriate location to create the bitcast for the return + // value. The location depends on if we have a call or invoke instruction. + Instruction *InsertBefore = nullptr; + if (auto *Invoke = dyn_cast<InvokeInst>(CS.getInstruction())) + InsertBefore = &*Invoke->getNormalDest()->getFirstInsertionPt(); + else + InsertBefore = &*std::next(CS.getInstruction()->getIterator()); + + // Bitcast the return value to the correct type. + auto *Cast = CastInst::Create(Instruction::BitCast, CS.getInstruction(), + RetTy, "", InsertBefore); + + // Replace all the original uses of the calling instruction with the bitcast. + for (User *U : UsersToUpdate) + U->replaceUsesOfWith(CS.getInstruction(), Cast); + + return Cast; +} + +/// Predicate and clone the given call site. +/// +/// This function creates an if-then-else structure at the location of the call +/// site. The "if" condition compares the call site's called value to the given +/// callee. The original call site is moved into the "else" block, and a clone +/// of the call site is placed in the "then" block. The cloned instruction is +/// returned. +static Instruction *versionCallSite(CallSite CS, Value *Callee, + MDNode *BranchWeights, + BasicBlock *&ThenBlock, + BasicBlock *&ElseBlock, + BasicBlock *&MergeBlock) { + + IRBuilder<> Builder(CS.getInstruction()); + Instruction *OrigInst = CS.getInstruction(); + + // Create the compare. 
The called value and callee must have the same type to + // be compared. + auto *LHS = + Builder.CreateBitCast(CS.getCalledValue(), Builder.getInt8PtrTy()); + auto *RHS = Builder.CreateBitCast(Callee, Builder.getInt8PtrTy()); + auto *Cond = Builder.CreateICmpEQ(LHS, RHS); + + // Create an if-then-else structure. The original instruction is moved into + // the "else" block, and a clone of the original instruction is placed in the + // "then" block. + TerminatorInst *ThenTerm = nullptr; + TerminatorInst *ElseTerm = nullptr; + SplitBlockAndInsertIfThenElse(Cond, CS.getInstruction(), &ThenTerm, &ElseTerm, + BranchWeights); + ThenBlock = ThenTerm->getParent(); + ElseBlock = ElseTerm->getParent(); + MergeBlock = OrigInst->getParent(); + + ThenBlock->setName("if.true.direct_targ"); + ElseBlock->setName("if.false.orig_indirect"); + MergeBlock->setName("if.end.icp"); + + Instruction *NewInst = OrigInst->clone(); + OrigInst->moveBefore(ElseTerm); + NewInst->insertBefore(ThenTerm); + + // If the original call site is an invoke instruction, we have extra work to + // do since invoke instructions are terminating. + if (auto *OrigInvoke = dyn_cast<InvokeInst>(OrigInst)) { + auto *NewInvoke = cast<InvokeInst>(NewInst); + + // Invoke instructions are terminating, so we don't need the terminator + // instructions that were just created. + ThenTerm->eraseFromParent(); + ElseTerm->eraseFromParent(); + + // Branch from the "merge" block to the original normal destination. + Builder.SetInsertPoint(MergeBlock); + Builder.CreateBr(OrigInvoke->getNormalDest()); + + // Now set the normal destination of new the invoke instruction to be the + // "merge" block. + NewInvoke->setNormalDest(MergeBlock); + } + + return NewInst; +} + +bool llvm::isLegalToPromote(CallSite CS, Function *Callee, + const char **FailureReason) { + assert(!CS.getCalledFunction() && "Only indirect call sites can be promoted"); + + // Check the return type. The callee's return value type must be bitcast + // compatible with the call site's type. + Type *CallRetTy = CS.getInstruction()->getType(); + Type *FuncRetTy = Callee->getReturnType(); + if (CallRetTy != FuncRetTy) + if (!CastInst::isBitCastable(FuncRetTy, CallRetTy)) { + if (FailureReason) + *FailureReason = "Return type mismatch"; + return false; + } + + // The number of formal arguments of the callee. + unsigned NumParams = Callee->getFunctionType()->getNumParams(); + + // Check the number of arguments. The callee and call site must agree on the + // number of arguments. + if (CS.arg_size() != NumParams && !Callee->isVarArg()) { + if (FailureReason) + *FailureReason = "The number of arguments mismatch"; + return false; + } + + // Check the argument types. The callee's formal argument types must be + // bitcast compatible with the corresponding actual argument types of the call + // site. + for (unsigned I = 0; I < NumParams; ++I) { + Type *FormalTy = Callee->getFunctionType()->getFunctionParamType(I); + Type *ActualTy = CS.getArgument(I)->getType(); + if (FormalTy == ActualTy) + continue; + if (!CastInst::isBitCastable(ActualTy, FormalTy)) { + if (FailureReason) + *FailureReason = "Argument type mismatch"; + return false; + } + } + + return true; +} + +static void promoteCall(CallSite CS, Function *Callee, Instruction *&Cast) { + assert(!CS.getCalledFunction() && "Only indirect call sites can be promoted"); + + // Set the called function of the call site to be the given callee. 
+ CS.setCalledFunction(Callee); + + // Since the call site will no longer be direct, we must clear metadata that + // is only appropriate for indirect calls. This includes !prof and !callees + // metadata. + CS.getInstruction()->setMetadata(LLVMContext::MD_prof, nullptr); + CS.getInstruction()->setMetadata(LLVMContext::MD_callees, nullptr); + + // If the function type of the call site matches that of the callee, no + // additional work is required. + if (CS.getFunctionType() == Callee->getFunctionType()) + return; + + // Save the return types of the call site and callee. + Type *CallSiteRetTy = CS.getInstruction()->getType(); + Type *CalleeRetTy = Callee->getReturnType(); + + // Change the function type of the call site the match that of the callee. + CS.mutateFunctionType(Callee->getFunctionType()); + + // Inspect the arguments of the call site. If an argument's type doesn't + // match the corresponding formal argument's type in the callee, bitcast it + // to the correct type. + for (Use &U : CS.args()) { + unsigned ArgNo = CS.getArgumentNo(&U); + Type *FormalTy = Callee->getFunctionType()->getParamType(ArgNo); + Type *ActualTy = U.get()->getType(); + if (FormalTy != ActualTy) { + auto *Cast = CastInst::Create(Instruction::BitCast, U.get(), FormalTy, "", + CS.getInstruction()); + CS.setArgument(ArgNo, Cast); + } + } + + // If the return type of the call site doesn't match that of the callee, cast + // the returned value to the appropriate type. + if (!CallSiteRetTy->isVoidTy() && CallSiteRetTy != CalleeRetTy) + Cast = createRetBitCast(CS, CallSiteRetTy); +} + +Instruction *llvm::promoteCallWithIfThenElse(CallSite CS, Function *Callee, + MDNode *BranchWeights) { + + // Version the indirect call site. If the called value is equal to the given + // callee, 'NewInst' will be executed, otherwise the original call site will + // be executed. + BasicBlock *ThenBlock, *ElseBlock, *MergeBlock; + Instruction *NewInst = versionCallSite(CS, Callee, BranchWeights, ThenBlock, + ElseBlock, MergeBlock); + + // Promote 'NewInst' so that it directly calls the desired function. + Instruction *Cast = NewInst; + promoteCall(CallSite(NewInst), Callee, Cast); + + // If the original call site is an invoke instruction, we have to fix-up phi + // nodes in the invoke's normal and unwind destinations. + if (auto *OrigInvoke = dyn_cast<InvokeInst>(CS.getInstruction())) { + fixupPHINodeForNormalDest(OrigInvoke, MergeBlock, ElseBlock, Cast); + fixupPHINodeForUnwindDest(OrigInvoke, MergeBlock, ThenBlock, ElseBlock); + } + + // Create a phi node for the returned value of the call site. + createRetPHINode(CS.getInstruction(), Cast ? Cast : NewInst); + + // Return the new direct call. 
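The public entry points of the new CallPromotionUtils.cpp are isLegalToPromote() and promoteCallWithIfThenElse(), both defined above. A minimal sketch of a caller, roughly what a profile-guided indirect-call promoter would do; the MDBuilder branch weights and the Taken/Total counts are illustrative assumptions, not part of this patch:

#include "llvm/IR/CallSite.h"
#include "llvm/IR/MDBuilder.h"
#include "llvm/Transforms/Utils/CallPromotionUtils.h"
#include <cstdint>

using namespace llvm;

// Try to turn the indirect call site CS into a guarded direct call to Callee.
// Returns the promoted (direct) call, or nullptr if promotion is illegal.
static Instruction *tryPromote(CallSite CS, Function *Callee, uint32_t Taken,
                               uint32_t Total) {
  const char *Reason = nullptr;
  if (!isLegalToPromote(CS, Callee, &Reason))
    return nullptr; // Reason points at e.g. "Argument type mismatch".
  MDBuilder MDB(CS.getInstruction()->getContext());
  MDNode *Weights = MDB.createBranchWeights(Taken, Total - Taken);
  return promoteCallWithIfThenElse(CS, Callee, Weights);
}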
+ return NewInst; +} + +#undef DEBUG_TYPE diff --git a/lib/Transforms/Utils/CloneFunction.cpp b/lib/Transforms/Utils/CloneFunction.cpp index 9c4e13903ed7..3b19ba1b50f2 100644 --- a/lib/Transforms/Utils/CloneFunction.cpp +++ b/lib/Transforms/Utils/CloneFunction.cpp @@ -747,7 +747,7 @@ Loop *llvm::cloneLoopWithPreheader(BasicBlock *Before, BasicBlock *LoopDomBB, Function *F = OrigLoop->getHeader()->getParent(); Loop *ParentLoop = OrigLoop->getParentLoop(); - Loop *NewLoop = new Loop(); + Loop *NewLoop = LI->AllocateLoop(); if (ParentLoop) ParentLoop->addChildLoop(NewLoop); else diff --git a/lib/Transforms/Utils/CloneModule.cpp b/lib/Transforms/Utils/CloneModule.cpp index e5392b53050d..8fee10854229 100644 --- a/lib/Transforms/Utils/CloneModule.cpp +++ b/lib/Transforms/Utils/CloneModule.cpp @@ -12,7 +12,6 @@ // //===----------------------------------------------------------------------===// -#include "llvm-c/Core.h" #include "llvm/IR/Constant.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Module.h" diff --git a/lib/Transforms/Utils/CmpInstAnalysis.cpp b/lib/Transforms/Utils/CmpInstAnalysis.cpp deleted file mode 100644 index d9294c499309..000000000000 --- a/lib/Transforms/Utils/CmpInstAnalysis.cpp +++ /dev/null @@ -1,108 +0,0 @@ -//===- CmpInstAnalysis.cpp - Utils to help fold compares ---------------===// -// -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. -// -//===----------------------------------------------------------------------===// -// -// This file holds routines to help analyse compare instructions -// and fold them into constants or other compare instructions -// -//===----------------------------------------------------------------------===// - -#include "llvm/Transforms/Utils/CmpInstAnalysis.h" -#include "llvm/IR/Constants.h" -#include "llvm/IR/Instructions.h" - -using namespace llvm; - -unsigned llvm::getICmpCode(const ICmpInst *ICI, bool InvertPred) { - ICmpInst::Predicate Pred = InvertPred ? ICI->getInversePredicate() - : ICI->getPredicate(); - switch (Pred) { - // False -> 0 - case ICmpInst::ICMP_UGT: return 1; // 001 - case ICmpInst::ICMP_SGT: return 1; // 001 - case ICmpInst::ICMP_EQ: return 2; // 010 - case ICmpInst::ICMP_UGE: return 3; // 011 - case ICmpInst::ICMP_SGE: return 3; // 011 - case ICmpInst::ICMP_ULT: return 4; // 100 - case ICmpInst::ICMP_SLT: return 4; // 100 - case ICmpInst::ICMP_NE: return 5; // 101 - case ICmpInst::ICMP_ULE: return 6; // 110 - case ICmpInst::ICMP_SLE: return 6; // 110 - // True -> 7 - default: - llvm_unreachable("Invalid ICmp predicate!"); - } -} - -Value *llvm::getICmpValue(bool Sign, unsigned Code, Value *LHS, Value *RHS, - CmpInst::Predicate &NewICmpPred) { - switch (Code) { - default: llvm_unreachable("Illegal ICmp code!"); - case 0: // False. - return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 0); - case 1: NewICmpPred = Sign ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT; break; - case 2: NewICmpPred = ICmpInst::ICMP_EQ; break; - case 3: NewICmpPred = Sign ? ICmpInst::ICMP_SGE : ICmpInst::ICMP_UGE; break; - case 4: NewICmpPred = Sign ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; break; - case 5: NewICmpPred = ICmpInst::ICMP_NE; break; - case 6: NewICmpPred = Sign ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE; break; - case 7: // True. 
- return ConstantInt::get(CmpInst::makeCmpResultType(LHS->getType()), 1); - } - return nullptr; -} - -bool llvm::PredicatesFoldable(ICmpInst::Predicate p1, ICmpInst::Predicate p2) { - return (CmpInst::isSigned(p1) == CmpInst::isSigned(p2)) || - (CmpInst::isSigned(p1) && ICmpInst::isEquality(p2)) || - (CmpInst::isSigned(p2) && ICmpInst::isEquality(p1)); -} - -bool llvm::decomposeBitTestICmp(const ICmpInst *I, CmpInst::Predicate &Pred, - Value *&X, Value *&Y, Value *&Z) { - ConstantInt *C = dyn_cast<ConstantInt>(I->getOperand(1)); - if (!C) - return false; - - switch (I->getPredicate()) { - default: - return false; - case ICmpInst::ICMP_SLT: - // X < 0 is equivalent to (X & SignMask) != 0. - if (!C->isZero()) - return false; - Y = ConstantInt::get(I->getContext(), APInt::getSignMask(C->getBitWidth())); - Pred = ICmpInst::ICMP_NE; - break; - case ICmpInst::ICMP_SGT: - // X > -1 is equivalent to (X & SignMask) == 0. - if (!C->isMinusOne()) - return false; - Y = ConstantInt::get(I->getContext(), APInt::getSignMask(C->getBitWidth())); - Pred = ICmpInst::ICMP_EQ; - break; - case ICmpInst::ICMP_ULT: - // X <u 2^n is equivalent to (X & ~(2^n-1)) == 0. - if (!C->getValue().isPowerOf2()) - return false; - Y = ConstantInt::get(I->getContext(), -C->getValue()); - Pred = ICmpInst::ICMP_EQ; - break; - case ICmpInst::ICMP_UGT: - // X >u 2^n-1 is equivalent to (X & ~(2^n-1)) != 0. - if (!(C->getValue() + 1).isPowerOf2()) - return false; - Y = ConstantInt::get(I->getContext(), ~C->getValue()); - Pred = ICmpInst::ICMP_NE; - break; - } - - X = I->getOperand(0); - Z = ConstantInt::getNullValue(C->getType()); - return true; -} diff --git a/lib/Transforms/Utils/CodeExtractor.cpp b/lib/Transforms/Utils/CodeExtractor.cpp index 1189714dfab1..7a404241cb14 100644 --- a/lib/Transforms/Utils/CodeExtractor.cpp +++ b/lib/Transforms/Utils/CodeExtractor.cpp @@ -14,34 +14,57 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/CodeExtractor.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" -#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/BlockFrequencyInfoImpl.h" #include "llvm/Analysis/BranchProbabilityInfo.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/RegionInfo.h" -#include "llvm/Analysis/RegionIterator.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/IR/Verifier.h" #include "llvm/Pass.h" #include "llvm/Support/BlockFrequency.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" 
#include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include <algorithm> +#include <cassert> +#include <cstdint> +#include <iterator> +#include <map> #include <set> +#include <utility> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "code-extractor" @@ -55,7 +78,8 @@ AggregateArgsOpt("aggregate-extracted-args", cl::Hidden, cl::desc("Aggregate arguments to code-extracted functions")); /// \brief Test whether a block is valid for extraction. -bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) { +bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB, + bool AllowVarArgs) { // Landing pads must be in the function where they were inserted for cleanup. if (BB.isEHPad()) return false; @@ -87,14 +111,19 @@ bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) { } } - // Don't hoist code containing allocas, invokes, or vastarts. + // Don't hoist code containing allocas or invokes. If explicitly requested, + // allow vastart. for (BasicBlock::const_iterator I = BB.begin(), E = BB.end(); I != E; ++I) { if (isa<AllocaInst>(I) || isa<InvokeInst>(I)) return false; if (const CallInst *CI = dyn_cast<CallInst>(I)) if (const Function *F = CI->getCalledFunction()) - if (F->getIntrinsicID() == Intrinsic::vastart) - return false; + if (F->getIntrinsicID() == Intrinsic::vastart) { + if (AllowVarArgs) + continue; + else + return false; + } } return true; @@ -102,21 +131,21 @@ bool CodeExtractor::isBlockValidForExtraction(const BasicBlock &BB) { /// \brief Build a set of blocks to extract if the input blocks are viable. static SetVector<BasicBlock *> -buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT) { +buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, + bool AllowVarArgs) { assert(!BBs.empty() && "The set of blocks to extract must be non-empty"); SetVector<BasicBlock *> Result; // Loop over the blocks, adding them to our set-vector, and aborting with an // empty set if we encounter invalid blocks. for (BasicBlock *BB : BBs) { - // If this block is dead, don't process it. if (DT && !DT->isReachableFromEntry(BB)) continue; if (!Result.insert(BB)) llvm_unreachable("Repeated basic blocks in extraction input"); - if (!CodeExtractor::isBlockValidForExtraction(*BB)) { + if (!CodeExtractor::isBlockValidForExtraction(*BB, AllowVarArgs)) { Result.clear(); return Result; } @@ -138,16 +167,18 @@ buildExtractionBlockSet(ArrayRef<BasicBlock *> BBs, DominatorTree *DT) { CodeExtractor::CodeExtractor(ArrayRef<BasicBlock *> BBs, DominatorTree *DT, bool AggregateArgs, BlockFrequencyInfo *BFI, - BranchProbabilityInfo *BPI) + BranchProbabilityInfo *BPI, bool AllowVarArgs) : DT(DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), Blocks(buildExtractionBlockSet(BBs, DT)), NumExitBlocks(~0U) {} + BPI(BPI), AllowVarArgs(AllowVarArgs), + Blocks(buildExtractionBlockSet(BBs, DT, AllowVarArgs)) {} CodeExtractor::CodeExtractor(DominatorTree &DT, Loop &L, bool AggregateArgs, BlockFrequencyInfo *BFI, BranchProbabilityInfo *BPI) : DT(&DT), AggregateArgs(AggregateArgs || AggregateArgsOpt), BFI(BFI), - BPI(BPI), Blocks(buildExtractionBlockSet(L.getBlocks(), &DT)), - NumExitBlocks(~0U) {} + BPI(BPI), AllowVarArgs(false), + Blocks(buildExtractionBlockSet(L.getBlocks(), &DT, + /* AllowVarArgs */ false)) {} /// definedInRegion - Return true if the specified value is defined in the /// extracted region. 
@@ -202,7 +233,6 @@ bool CodeExtractor::isLegalToShrinkwrapLifetimeMarkers( if (Blocks.count(&BB)) continue; for (Instruction &II : BB) { - if (isa<DbgInfoIntrinsic>(II)) continue; @@ -287,7 +317,9 @@ CodeExtractor::findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock) { BasicBlock *NewExitBlock = CommonExitBlock->splitBasicBlock( CommonExitBlock->getFirstNonPHI()->getIterator()); - for (auto *Pred : predecessors(CommonExitBlock)) { + for (auto PI = pred_begin(CommonExitBlock), PE = pred_end(CommonExitBlock); + PI != PE;) { + BasicBlock *Pred = *PI++; if (Blocks.count(Pred)) continue; Pred->getTerminator()->replaceUsesOfWith(CommonExitBlock, NewExitBlock); @@ -373,7 +405,6 @@ void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands, // Follow the bitcast. Instruction *MarkerAddr = nullptr; for (User *U : AI->users()) { - if (U->stripInBoundsConstantOffsets() == AI) { SinkLifeStart = false; HoistLifeEnd = false; @@ -407,7 +438,6 @@ void CodeExtractor::findAllocas(ValueSet &SinkCands, ValueSet &HoistCands, void CodeExtractor::findInputsOutputs(ValueSet &Inputs, ValueSet &Outputs, const ValueSet &SinkCands) const { - for (BasicBlock *BB : Blocks) { // If a used value is defined outside the region, it's an input. If an // instruction is used outside the region, it's an output. @@ -457,7 +487,7 @@ void CodeExtractor::severSplitPHINodes(BasicBlock *&Header) { // containing PHI nodes merging values from outside of the region, and a // second that contains all of the code for the block and merges back any // incoming values from inside of the region. - BasicBlock *NewBB = llvm::SplitBlock(Header, Header->getFirstNonPHI(), DT); + BasicBlock *NewBB = SplitBlock(Header, Header->getFirstNonPHI(), DT); // We only want to code extract the second block now, and it becomes the new // header of the region. @@ -525,7 +555,6 @@ void CodeExtractor::splitReturnBlocks() { /// constructFunction - make a function based on inputs and outputs, as follows: /// f(in0, ..., inN, out0, ..., outN) -/// Function *CodeExtractor::constructFunction(const ValueSet &inputs, const ValueSet &outputs, BasicBlock *header, @@ -544,7 +573,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, default: RetTy = Type::getInt16Ty(header->getContext()); break; } - std::vector<Type*> paramTy; + std::vector<Type *> paramTy; // Add the types of the input values to the function's argument list for (Value *value : inputs) { @@ -575,7 +604,8 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, paramTy.push_back(PointerType::getUnqual(StructTy)); } FunctionType *funcType = - FunctionType::get(RetTy, paramTy, false); + FunctionType::get(RetTy, paramTy, + AllowVarArgs && oldFunction->isVarArg()); // Create the new function Function *newFunction = Function::Create(funcType, @@ -620,7 +650,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, } else RewriteVal = &*AI++; - std::vector<User*> Users(inputs[i]->user_begin(), inputs[i]->user_end()); + std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end()); for (User *use : Users) if (Instruction *inst = dyn_cast<Instruction>(use)) if (Blocks.count(inst->getParent())) @@ -639,7 +669,7 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, // Rewrite branches to basic blocks outside of the loop to new dummy blocks // within the new function. This must be done before we lose track of which // blocks were originally in the code region. 
- std::vector<User*> Users(header->user_begin(), header->user_end()); + std::vector<User *> Users(header->user_begin(), header->user_end()); for (unsigned i = 0, e = Users.size(); i != e; ++i) // The BasicBlock which contains the branch is not in the region // modify the branch target to a new block @@ -651,19 +681,6 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs, return newFunction; } -/// FindPhiPredForUseInBlock - Given a value and a basic block, find a PHI -/// that uses the value within the basic block, and return the predecessor -/// block associated with that use, or return 0 if none is found. -static BasicBlock* FindPhiPredForUseInBlock(Value* Used, BasicBlock* BB) { - for (Use &U : Used->uses()) { - PHINode *P = dyn_cast<PHINode>(U.getUser()); - if (P && P->getParent() == BB) - return P->getIncomingBlock(U); - } - - return nullptr; -} - /// emitCallAndSwitchStatement - This method sets up the caller side by adding /// the call instruction, splitting any PHI nodes in the header block as /// necessary. @@ -672,7 +689,7 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, ValueSet &inputs, ValueSet &outputs) { // Emit a call to the new function, passing in: *pointer to struct (if // aggregating parameters), or plan inputs and allocated memory for outputs - std::vector<Value*> params, StructValues, ReloadOutputs, Reloads; + std::vector<Value *> params, StructValues, ReloadOutputs, Reloads; Module *M = newFunction->getParent(); LLVMContext &Context = M->getContext(); @@ -702,7 +719,7 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, StructType *StructArgTy = nullptr; AllocaInst *Struct = nullptr; if (AggregateArgs && (inputs.size() + outputs.size() > 0)) { - std::vector<Type*> ArgTypes; + std::vector<Type *> ArgTypes; for (ValueSet::iterator v = StructValues.begin(), ve = StructValues.end(); v != ve; ++v) ArgTypes.push_back((*v)->getType()); @@ -729,6 +746,14 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, // Emit the call to the function CallInst *call = CallInst::Create(newFunction, params, NumExitBlocks > 1 ? "targetBlock" : ""); + // Add debug location to the new call, if the original function has debug + // info. In that case, the terminator of the entry block of the extracted + // function contains the first debug location of the extracted function, + // set in extractCodeRegion. + if (codeReplacer->getParent()->getSubprogram()) { + if (auto DL = newFunction->getEntryBlock().getTerminator()->getDebugLoc()) + call->setDebugLoc(DL); + } codeReplacer->getInstList().push_back(call); Function::arg_iterator OutputArgBegin = newFunction->arg_begin(); @@ -736,7 +761,8 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, if (!AggregateArgs) std::advance(OutputArgBegin, inputs.size()); - // Reload the outputs passed in by reference + // Reload the outputs passed in by reference. 
+ Function::arg_iterator OAI = OutputArgBegin; for (unsigned i = 0, e = outputs.size(); i != e; ++i) { Value *Output = nullptr; if (AggregateArgs) { @@ -753,12 +779,40 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, LoadInst *load = new LoadInst(Output, outputs[i]->getName()+".reload"); Reloads.push_back(load); codeReplacer->getInstList().push_back(load); - std::vector<User*> Users(outputs[i]->user_begin(), outputs[i]->user_end()); + std::vector<User *> Users(outputs[i]->user_begin(), outputs[i]->user_end()); for (unsigned u = 0, e = Users.size(); u != e; ++u) { Instruction *inst = cast<Instruction>(Users[u]); if (!Blocks.count(inst->getParent())) inst->replaceUsesOfWith(outputs[i], load); } + + // Store to argument right after the definition of output value. + auto *OutI = dyn_cast<Instruction>(outputs[i]); + if (!OutI) + continue; + // Find proper insertion point. + Instruction *InsertPt = OutI->getNextNode(); + // Let's assume that there is no other guy interleave non-PHI in PHIs. + if (isa<PHINode>(InsertPt)) + InsertPt = InsertPt->getParent()->getFirstNonPHI(); + + assert(OAI != newFunction->arg_end() && + "Number of output arguments should match " + "the amount of defined values"); + if (AggregateArgs) { + Value *Idx[2]; + Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); + Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i); + GetElementPtrInst *GEP = GetElementPtrInst::Create( + StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(), InsertPt); + new StoreInst(outputs[i], GEP, InsertPt); + // Since there should be only one struct argument aggregating + // all the output values, we shouldn't increment OAI, which always + // points to the struct argument, in this case. + } else { + new StoreInst(outputs[i], &*OAI, InsertPt); + ++OAI; + } } // Now we can emit a switch statement using the call as a value. @@ -771,7 +825,7 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, // over all of the blocks in the extracted region, updating any terminator // instructions in the to-be-extracted region that branch to blocks that are // not in the region to be extracted. - std::map<BasicBlock*, BasicBlock*> ExitBlockMap; + std::map<BasicBlock *, BasicBlock *> ExitBlockMap; unsigned switchVal = 0; for (BasicBlock *Block : Blocks) { @@ -801,75 +855,12 @@ emitCallAndSwitchStatement(Function *newFunction, BasicBlock *codeReplacer, break; } - ReturnInst *NTRet = ReturnInst::Create(Context, brVal, NewTarget); + ReturnInst::Create(Context, brVal, NewTarget); // Update the switch instruction. TheSwitch->addCase(ConstantInt::get(Type::getInt16Ty(Context), SuccNum), OldTarget); - - // Restore values just before we exit - Function::arg_iterator OAI = OutputArgBegin; - for (unsigned out = 0, e = outputs.size(); out != e; ++out) { - // For an invoke, the normal destination is the only one that is - // dominated by the result of the invocation - BasicBlock *DefBlock = cast<Instruction>(outputs[out])->getParent(); - - bool DominatesDef = true; - - BasicBlock *NormalDest = nullptr; - if (auto *Invoke = dyn_cast<InvokeInst>(outputs[out])) - NormalDest = Invoke->getNormalDest(); - - if (NormalDest) { - DefBlock = NormalDest; - - // Make sure we are looking at the original successor block, not - // at a newly inserted exit block, which won't be in the dominator - // info. 
- for (const auto &I : ExitBlockMap) - if (DefBlock == I.second) { - DefBlock = I.first; - break; - } - - // In the extract block case, if the block we are extracting ends - // with an invoke instruction, make sure that we don't emit a - // store of the invoke value for the unwind block. - if (!DT && DefBlock != OldTarget) - DominatesDef = false; - } - - if (DT) { - DominatesDef = DT->dominates(DefBlock, OldTarget); - - // If the output value is used by a phi in the target block, - // then we need to test for dominance of the phi's predecessor - // instead. Unfortunately, this a little complicated since we - // have already rewritten uses of the value to uses of the reload. - BasicBlock* pred = FindPhiPredForUseInBlock(Reloads[out], - OldTarget); - if (pred && DT && DT->dominates(DefBlock, pred)) - DominatesDef = true; - } - - if (DominatesDef) { - if (AggregateArgs) { - Value *Idx[2]; - Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context)); - Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), - FirstOut+out); - GetElementPtrInst *GEP = GetElementPtrInst::Create( - StructArgTy, &*OAI, Idx, "gep_" + outputs[out]->getName(), - NTRet); - new StoreInst(outputs[out], GEP, NTRet); - } else { - new StoreInst(outputs[out], &*OAI, NTRet); - } - } - // Advance output iterator even if we don't emit a store - if (!AggregateArgs) ++OAI; - } } // rewrite the original branch instruction with this new target @@ -940,8 +931,8 @@ void CodeExtractor::calculateNewCallTerminatorWeights( BasicBlock *CodeReplacer, DenseMap<BasicBlock *, BlockFrequency> &ExitWeights, BranchProbabilityInfo *BPI) { - typedef BlockFrequencyInfoImplBase::Distribution Distribution; - typedef BlockFrequencyInfoImplBase::BlockNode BlockNode; + using Distribution = BlockFrequencyInfoImplBase::Distribution; + using BlockNode = BlockFrequencyInfoImplBase::BlockNode; // Update the branch weights for the exit block. TerminatorInst *TI = CodeReplacer->getTerminator(); @@ -985,12 +976,31 @@ Function *CodeExtractor::extractCodeRegion() { if (!isEligible()) return nullptr; - ValueSet inputs, outputs, SinkingCands, HoistingCands; - BasicBlock *CommonExit = nullptr; - // Assumption: this is a single-entry code region, and the header is the first // block in the region. BasicBlock *header = *Blocks.begin(); + Function *oldFunction = header->getParent(); + + // For functions with varargs, check that varargs handling is only done in the + // outlined function, i.e vastart and vaend are only used in outlined blocks. + if (AllowVarArgs && oldFunction->getFunctionType()->isVarArg()) { + auto containsVarArgIntrinsic = [](Instruction &I) { + if (const CallInst *CI = dyn_cast<CallInst>(&I)) + if (const Function *F = CI->getCalledFunction()) + return F->getIntrinsicID() == Intrinsic::vastart || + F->getIntrinsicID() == Intrinsic::vaend; + return false; + }; + + for (auto &BB : *oldFunction) { + if (Blocks.count(&BB)) + continue; + if (llvm::any_of(BB, containsVarArgIntrinsic)) + return nullptr; + } + } + ValueSet inputs, outputs, SinkingCands, HoistingCands; + BasicBlock *CommonExit = nullptr; // Calculate the entry frequency of the new function before we change the root // block. @@ -1012,8 +1022,6 @@ Function *CodeExtractor::extractCodeRegion() { // that the return is not in the region. 
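// Illustration (not from the patch): why the varargs check above walks the blocks
// that stay behind. Outlining part of a variadic function is only allowed when all
// va_start/va_end handling moves into the extracted region; if any block left in
// the original function still touches the va_list, extractCodeRegion() refuses the
// extraction. The function below (hypothetical) shows the kind of va_start/va_end
// pair the llvm.va_start/llvm.va_end scan is looking for.
#include <cstdarg>
static int sumVarArgsSketch(int Count, ...) {
  va_list Args;
  va_start(Args, Count);            // lowers to llvm.va_start
  int Sum = 0;
  for (int I = 0; I < Count; ++I)
    Sum += va_arg(Args, int);
  va_end(Args);                     // lowers to llvm.va_end
  return Sum;
}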
splitReturnBlocks(); - Function *oldFunction = header->getParent(); - // This takes place of the original loop BasicBlock *codeReplacer = BasicBlock::Create(header->getContext(), "codeRepl", oldFunction, @@ -1023,7 +1031,22 @@ Function *CodeExtractor::extractCodeRegion() { // head of the region, but the entry node of a function cannot have preds. BasicBlock *newFuncRoot = BasicBlock::Create(header->getContext(), "newFuncRoot"); - newFuncRoot->getInstList().push_back(BranchInst::Create(header)); + auto *BranchI = BranchInst::Create(header); + // If the original function has debug info, we have to add a debug location + // to the new branch instruction from the artificial entry block. + // We use the debug location of the first instruction in the extracted + // blocks, as there is no other equivalent line in the source code. + if (oldFunction->getSubprogram()) { + any_of(Blocks, [&BranchI](const BasicBlock *BB) { + return any_of(*BB, [&BranchI](const Instruction &I) { + if (!I.getDebugLoc()) + return false; + BranchI->setDebugLoc(I.getDebugLoc()); + return true; + }); + }); + } + newFuncRoot->getInstList().push_back(BranchI); findAllocas(SinkingCands, HoistingCands, CommonExit); assert(HoistingCands.empty() || CommonExit); @@ -1044,7 +1067,7 @@ Function *CodeExtractor::extractCodeRegion() { } // Calculate the exit blocks for the extracted region and the total exit - // weights for each of those blocks. + // weights for each of those blocks. DenseMap<BasicBlock *, BlockFrequency> ExitWeights; SmallPtrSet<BasicBlock *, 1> ExitBlocks; for (BasicBlock *Block : Blocks) { @@ -1097,8 +1120,8 @@ Function *CodeExtractor::extractCodeRegion() { // Look at all successors of the codeReplacer block. If any of these blocks // had PHI nodes in them, we need to update the "from" block to be the code // replacer, not the original block in the extracted region. - std::vector<BasicBlock*> Succs(succ_begin(codeReplacer), - succ_end(codeReplacer)); + std::vector<BasicBlock *> Succs(succ_begin(codeReplacer), + succ_end(codeReplacer)); for (unsigned i = 0, e = Succs.size(); i != e; ++i) for (BasicBlock::iterator I = Succs[i]->begin(); isa<PHINode>(I); ++I) { PHINode *PN = cast<PHINode>(I); diff --git a/lib/Transforms/Utils/CtorUtils.cpp b/lib/Transforms/Utils/CtorUtils.cpp index 6642a97a29c2..82b67c293102 100644 --- a/lib/Transforms/Utils/CtorUtils.cpp +++ b/lib/Transforms/Utils/CtorUtils.cpp @@ -16,7 +16,6 @@ #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalVariable.h" -#include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" diff --git a/lib/Transforms/Utils/EntryExitInstrumenter.cpp b/lib/Transforms/Utils/EntryExitInstrumenter.cpp new file mode 100644 index 000000000000..421663f82565 --- /dev/null +++ b/lib/Transforms/Utils/EntryExitInstrumenter.cpp @@ -0,0 +1,163 @@ +//===- EntryExitInstrumenter.cpp - Function Entry/Exit Instrumentation ----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Utils/EntryExitInstrumenter.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/CodeGen/Passes.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" +using namespace llvm; + +static void insertCall(Function &CurFn, StringRef Func, + Instruction *InsertionPt, DebugLoc DL) { + Module &M = *InsertionPt->getParent()->getParent()->getParent(); + LLVMContext &C = InsertionPt->getParent()->getContext(); + + if (Func == "mcount" || + Func == ".mcount" || + Func == "\01__gnu_mcount_nc" || + Func == "\01_mcount" || + Func == "\01mcount" || + Func == "__mcount" || + Func == "_mcount" || + Func == "__cyg_profile_func_enter_bare") { + Constant *Fn = M.getOrInsertFunction(Func, Type::getVoidTy(C)); + CallInst *Call = CallInst::Create(Fn, "", InsertionPt); + Call->setDebugLoc(DL); + return; + } + + if (Func == "__cyg_profile_func_enter" || Func == "__cyg_profile_func_exit") { + Type *ArgTypes[] = {Type::getInt8PtrTy(C), Type::getInt8PtrTy(C)}; + + Constant *Fn = M.getOrInsertFunction( + Func, FunctionType::get(Type::getVoidTy(C), ArgTypes, false)); + + Instruction *RetAddr = CallInst::Create( + Intrinsic::getDeclaration(&M, Intrinsic::returnaddress), + ArrayRef<Value *>(ConstantInt::get(Type::getInt32Ty(C), 0)), "", + InsertionPt); + RetAddr->setDebugLoc(DL); + + Value *Args[] = {ConstantExpr::getBitCast(&CurFn, Type::getInt8PtrTy(C)), + RetAddr}; + + CallInst *Call = + CallInst::Create(Fn, ArrayRef<Value *>(Args), "", InsertionPt); + Call->setDebugLoc(DL); + return; + } + + // We only know how to call a fixed set of instrumentation functions, because + // they all expect different arguments, etc. + report_fatal_error(Twine("Unknown instrumentation function: '") + Func + "'"); +} + +static bool runOnFunction(Function &F, bool PostInlining) { + StringRef EntryAttr = PostInlining ? "instrument-function-entry-inlined" + : "instrument-function-entry"; + + StringRef ExitAttr = PostInlining ? "instrument-function-exit-inlined" + : "instrument-function-exit"; + + StringRef EntryFunc = F.getFnAttribute(EntryAttr).getValueAsString(); + StringRef ExitFunc = F.getFnAttribute(ExitAttr).getValueAsString(); + + bool Changed = false; + + // If the attribute is specified, insert instrumentation and then "consume" + // the attribute so that it's not inserted again if the pass should happen to + // run later for some reason. 
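// Context sketch (not part of the patch): insertCall() above emits calls that
// follow the usual -finstrument-functions runtime contract, passing the
// instrumented function's address and the return address. A minimal host-side
// implementation of those hooks could look like this; the attribute keeps the
// hooks themselves from being instrumented. Names and output are illustrative.
#include <cstdio>
extern "C" {
__attribute__((no_instrument_function))
void __cyg_profile_func_enter(void *Fn, void *CallSite) {
  std::fprintf(stderr, "enter %p (called from %p)\n", Fn, CallSite);
}
__attribute__((no_instrument_function))
void __cyg_profile_func_exit(void *Fn, void *CallSite) {
  std::fprintf(stderr, "exit  %p (called from %p)\n", Fn, CallSite);
}
} // extern "C"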
+ + if (!EntryFunc.empty()) { + DebugLoc DL; + if (auto SP = F.getSubprogram()) + DL = DebugLoc::get(SP->getScopeLine(), 0, SP); + + insertCall(F, EntryFunc, &*F.begin()->getFirstInsertionPt(), DL); + Changed = true; + F.removeAttribute(AttributeList::FunctionIndex, EntryAttr); + } + + if (!ExitFunc.empty()) { + for (BasicBlock &BB : F) { + TerminatorInst *T = BB.getTerminator(); + DebugLoc DL; + if (DebugLoc TerminatorDL = T->getDebugLoc()) + DL = TerminatorDL; + else if (auto SP = F.getSubprogram()) + DL = DebugLoc::get(0, 0, SP); + + if (isa<ReturnInst>(T)) { + insertCall(F, ExitFunc, T, DL); + Changed = true; + } + } + F.removeAttribute(AttributeList::FunctionIndex, ExitAttr); + } + + return Changed; +} + +namespace { +struct EntryExitInstrumenter : public FunctionPass { + static char ID; + EntryExitInstrumenter() : FunctionPass(ID) { + initializeEntryExitInstrumenterPass(*PassRegistry::getPassRegistry()); + } + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addPreserved<GlobalsAAWrapperPass>(); + } + bool runOnFunction(Function &F) override { return ::runOnFunction(F, false); } +}; +char EntryExitInstrumenter::ID = 0; + +struct PostInlineEntryExitInstrumenter : public FunctionPass { + static char ID; + PostInlineEntryExitInstrumenter() : FunctionPass(ID) { + initializePostInlineEntryExitInstrumenterPass( + *PassRegistry::getPassRegistry()); + } + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addPreserved<GlobalsAAWrapperPass>(); + } + bool runOnFunction(Function &F) override { return ::runOnFunction(F, true); } +}; +char PostInlineEntryExitInstrumenter::ID = 0; +} + +INITIALIZE_PASS( + EntryExitInstrumenter, "ee-instrument", + "Instrument function entry/exit with calls to e.g. mcount() (pre inlining)", + false, false) +INITIALIZE_PASS(PostInlineEntryExitInstrumenter, "post-inline-ee-instrument", + "Instrument function entry/exit with calls to e.g. 
mcount() " + "(post inlining)", + false, false) + +FunctionPass *llvm::createEntryExitInstrumenterPass() { + return new EntryExitInstrumenter(); +} + +FunctionPass *llvm::createPostInlineEntryExitInstrumenterPass() { + return new PostInlineEntryExitInstrumenter(); +} + +PreservedAnalyses +llvm::EntryExitInstrumenterPass::run(Function &F, FunctionAnalysisManager &AM) { + runOnFunction(F, PostInlining); + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; +} diff --git a/lib/Transforms/Utils/Evaluator.cpp b/lib/Transforms/Utils/Evaluator.cpp index 1328f2f3ec01..3c5e299fae98 100644 --- a/lib/Transforms/Utils/Evaluator.cpp +++ b/lib/Transforms/Utils/Evaluator.cpp @@ -12,19 +12,33 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/Evaluator.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/DerivedTypes.h" -#include "llvm/IR/DiagnosticPrinter.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include <iterator> #define DEBUG_TYPE "evaluator" @@ -193,7 +207,7 @@ Constant *Evaluator::ComputeLoadResult(Constant *P) { bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, BasicBlock *&NextBB) { // This is the main evaluation loop. - while (1) { + while (true) { Constant *InstResult = nullptr; DEBUG(dbgs() << "Evaluating Instruction: " << *CurInst << "\n"); @@ -318,7 +332,6 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, DEBUG(dbgs() << "Found a GEP! Simplifying: " << *InstResult << "\n"); } else if (LoadInst *LI = dyn_cast<LoadInst>(CurInst)) { - if (!LI->isSimple()) { DEBUG(dbgs() << "Found a Load! Not a simple load, can not evaluate.\n"); return false; // no volatile/atomic accesses. @@ -344,9 +357,9 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, return false; // Cannot handle array allocs. } Type *Ty = AI->getAllocatedType(); - AllocaTmps.push_back( - make_unique<GlobalVariable>(Ty, false, GlobalValue::InternalLinkage, - UndefValue::get(Ty), AI->getName())); + AllocaTmps.push_back(llvm::make_unique<GlobalVariable>( + Ty, false, GlobalValue::InternalLinkage, UndefValue::get(Ty), + AI->getName())); InstResult = AllocaTmps.back().get(); DEBUG(dbgs() << "Found an alloca. Result: " << *InstResult << "\n"); } else if (isa<CallInst>(CurInst) || isa<InvokeInst>(CurInst)) { @@ -420,6 +433,10 @@ bool Evaluator::EvaluateBlock(BasicBlock::iterator CurInst, DEBUG(dbgs() << "Skipping assume intrinsic.\n"); ++CurInst; continue; + } else if (II->getIntrinsicID() == Intrinsic::sideeffect) { + DEBUG(dbgs() << "Skipping sideeffect intrinsic.\n"); + ++CurInst; + continue; } DEBUG(dbgs() << "Unknown intrinsic. 
Can not evaluate.\n"); @@ -559,7 +576,7 @@ bool Evaluator::EvaluateFunction(Function *F, Constant *&RetVal, BasicBlock::iterator CurInst = CurBB->begin(); - while (1) { + while (true) { BasicBlock *NextBB = nullptr; // Initialized to avoid compiler warnings. DEBUG(dbgs() << "Trying to evaluate BB: " << *CurBB << "\n"); @@ -594,4 +611,3 @@ bool Evaluator::EvaluateFunction(Function *F, Constant *&RetVal, CurBB = NextBB; } } - diff --git a/lib/Transforms/Utils/FlattenCFG.cpp b/lib/Transforms/Utils/FlattenCFG.cpp index 435eff3bef47..5fdcc6d1d727 100644 --- a/lib/Transforms/Utils/FlattenCFG.cpp +++ b/lib/Transforms/Utils/FlattenCFG.cpp @@ -1,4 +1,4 @@ -//===- FlatternCFG.cpp - Code to perform CFG flattening ---------------===// +//===- FlatternCFG.cpp - Code to perform CFG flattening -------------------===// // // The LLVM Compiler Infrastructure // @@ -14,25 +14,37 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Local.h" +#include <cassert> + using namespace llvm; #define DEBUG_TYPE "flattencfg" namespace { + class FlattenCFGOpt { AliasAnalysis *AA; + /// \brief Use parallel-and or parallel-or to generate conditions for /// conditional branches. bool FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder); + /// \brief If \param BB is the merge block of an if-region, attempt to merge /// the if-region with an adjacent if-region upstream if two if-regions /// contain identical instructions. bool MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder); + /// \brief Compare a pair of blocks: \p Block1 and \p Block2, which /// are from two if-regions whose entry blocks are \p Head1 and \p /// Head2. \returns true if \p Block1 and \p Block2 contain identical @@ -43,9 +55,11 @@ class FlattenCFGOpt { public: FlattenCFGOpt(AliasAnalysis *AA) : AA(AA) {} + bool run(BasicBlock *BB); }; -} + +} // end anonymous namespace /// If \param [in] BB has more than one predecessor that is a conditional /// branch, attempt to use parallel and/or for the branch condition. \returns @@ -120,7 +134,6 @@ public: /// In Case 1, \param BB (BB4) has an unconditional branch (BB3) as /// its predecessor. In Case 2, \param BB (BB3) only has conditional branches /// as its predecessors. -/// bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) { PHINode *PHI = dyn_cast<PHINode>(BB->begin()); if (PHI) @@ -237,8 +250,8 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) { // Do branch inversion. BasicBlock *CurrBlock = LastCondBlock; bool EverChanged = false; - for (;CurrBlock != FirstCondBlock; - CurrBlock = CurrBlock->getSinglePredecessor()) { + for (; CurrBlock != FirstCondBlock; + CurrBlock = CurrBlock->getSinglePredecessor()) { BranchInst *BI = dyn_cast<BranchInst>(CurrBlock->getTerminator()); CmpInst *CI = dyn_cast<CmpInst>(BI->getCondition()); if (!CI) @@ -309,7 +322,6 @@ bool FlattenCFGOpt::FlattenParallelAndOr(BasicBlock *BB, IRBuilder<> &Builder) { // in the 2nd if-region to compare. 
\returns true if \param Block1 and \param /// Block2 have identical instructions and do not have memory reference alias /// with \param Head2. -/// bool FlattenCFGOpt::CompareIfRegionBlock(BasicBlock *Head1, BasicBlock *Head2, BasicBlock *Block1, BasicBlock *Block2) { @@ -330,7 +342,7 @@ bool FlattenCFGOpt::CompareIfRegionBlock(BasicBlock *Head1, BasicBlock *Head2, BasicBlock::iterator iter2 = Block2->begin(); BasicBlock::iterator end2 = Block2->getTerminator()->getIterator(); - while (1) { + while (true) { if (iter1 == end1) { if (iter2 != end2) return false; @@ -384,7 +396,6 @@ bool FlattenCFGOpt::CompareIfRegionBlock(BasicBlock *Head1, BasicBlock *Head2, /// To: /// if (a || b) /// statement; -/// bool FlattenCFGOpt::MergeIfRegion(BasicBlock *BB, IRBuilder<> &Builder) { BasicBlock *IfTrue2, *IfFalse2; Value *IfCond2 = GetIfCondition(BB, IfTrue2, IfFalse2); @@ -475,8 +486,7 @@ bool FlattenCFGOpt::run(BasicBlock *BB) { /// FlattenCFG - This function is used to flatten a CFG. For /// example, it uses parallel-and and parallel-or mode to collapse -// if-conditions and merge if-regions with identical statements. -/// +/// if-conditions and merge if-regions with identical statements. bool llvm::FlattenCFG(BasicBlock *BB, AliasAnalysis *AA) { return FlattenCFGOpt(AA).run(BB); } diff --git a/lib/Transforms/Utils/FunctionComparator.cpp b/lib/Transforms/Utils/FunctionComparator.cpp index 4a2be3a53176..bddcbd86e914 100644 --- a/lib/Transforms/Utils/FunctionComparator.cpp +++ b/lib/Transforms/Utils/FunctionComparator.cpp @@ -13,13 +13,41 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/FunctionComparator.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" #include "llvm/IR/InlineAsm.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <utility> using namespace llvm; @@ -160,7 +188,6 @@ int FunctionComparator::cmpOperandBundlesSchema(const Instruction *L, /// For more details see declaration comments. 
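// Illustration (not from the patch): FunctionComparator builds a total order by
// chaining three-way comparisons and returning as soon as one sub-comparison is
// decisive, as the hunks below show for constants, types and GEPs. A reduced,
// standalone version of that pattern (all names invented):
#include <cstdint>
#include <utility>
static int cmpNumbersSketch(uint64_t L, uint64_t R) {
  if (L < R) return -1;
  return L == R ? 0 : 1;
}
static int cmpPairSketch(const std::pair<uint64_t, uint64_t> &L,
                         const std::pair<uint64_t, uint64_t> &R) {
  if (int Res = cmpNumbersSketch(L.first, R.first))
    return Res;                      // the first component already differs
  return cmpNumbersSketch(L.second, R.second);
}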
int FunctionComparator::cmpConstants(const Constant *L, const Constant *R) const { - Type *TyL = L->getType(); Type *TyR = R->getType(); @@ -226,8 +253,8 @@ int FunctionComparator::cmpConstants(const Constant *L, if (!L->isNullValue() && R->isNullValue()) return -1; - auto GlobalValueL = const_cast<GlobalValue*>(dyn_cast<GlobalValue>(L)); - auto GlobalValueR = const_cast<GlobalValue*>(dyn_cast<GlobalValue>(R)); + auto GlobalValueL = const_cast<GlobalValue *>(dyn_cast<GlobalValue>(L)); + auto GlobalValueR = const_cast<GlobalValue *>(dyn_cast<GlobalValue>(R)); if (GlobalValueL && GlobalValueR) { return cmpGlobalValues(GlobalValueL, GlobalValueR); } @@ -401,10 +428,9 @@ int FunctionComparator::cmpTypes(Type *TyL, Type *TyR) const { case Type::TokenTyID: return 0; - case Type::PointerTyID: { + case Type::PointerTyID: assert(PTyL && PTyR && "Both types must be pointers here."); return cmpNumbers(PTyL->getAddressSpace(), PTyR->getAddressSpace()); - } case Type::StructTyID: { StructType *STyL = cast<StructType>(TyL); @@ -637,7 +663,6 @@ int FunctionComparator::cmpOperations(const Instruction *L, // Read method declaration comments for more details. int FunctionComparator::cmpGEPs(const GEPOperator *GEPL, const GEPOperator *GEPR) const { - unsigned int ASL = GEPL->getPointerAddressSpace(); unsigned int ASR = GEPR->getPointerAddressSpace(); @@ -869,15 +894,19 @@ namespace { // buffer. class HashAccumulator64 { uint64_t Hash; + public: // Initialize to random constant, so the state isn't zero. HashAccumulator64() { Hash = 0x6acaa36bef8325c5ULL; } + void add(uint64_t V) { - Hash = llvm::hashing::detail::hash_16_bytes(Hash, V); + Hash = hashing::detail::hash_16_bytes(Hash, V); } + // No finishing is required, because the entire hash value is used. uint64_t getHash() { return Hash; } }; + } // end anonymous namespace // A function hash is calculated by considering only the number of arguments and @@ -919,5 +948,3 @@ FunctionComparator::FunctionHash FunctionComparator::functionHash(Function &F) { } return H.getHash(); } - - diff --git a/lib/Transforms/Utils/FunctionImportUtils.cpp b/lib/Transforms/Utils/FunctionImportUtils.cpp index a98d07237b47..6b5f593073b4 100644 --- a/lib/Transforms/Utils/FunctionImportUtils.cpp +++ b/lib/Transforms/Utils/FunctionImportUtils.cpp @@ -13,9 +13,7 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/FunctionImportUtils.h" -#include "llvm/Analysis/ModuleSummaryAnalysis.h" #include "llvm/IR/InstIterator.h" -#include "llvm/IR/Instructions.h" using namespace llvm; /// Checks if we should import SGV as a definition, otherwise import as a @@ -23,21 +21,15 @@ using namespace llvm; bool FunctionImportGlobalProcessing::doImportAsDefinition( const GlobalValue *SGV, SetVector<GlobalValue *> *GlobalsToImport) { - // For alias, we tie the definition to the base object. Extract it and recurse - if (auto *GA = dyn_cast<GlobalAlias>(SGV)) { - if (GA->isInterposable()) - return false; - const GlobalObject *GO = GA->getBaseObject(); - if (!GO->hasLinkOnceODRLinkage()) - return false; - return FunctionImportGlobalProcessing::doImportAsDefinition( - GO, GlobalsToImport); - } // Only import the globals requested for importing. - if (GlobalsToImport->count(const_cast<GlobalValue *>(SGV))) - return true; - // Otherwise no. - return false; + if (!GlobalsToImport->count(const_cast<GlobalValue *>(SGV))) + return false; + + assert(!isa<GlobalAlias>(SGV) && + "Unexpected global alias in the import list."); + + // Otherwise yes. 
+ return true; } bool FunctionImportGlobalProcessing::doImportAsDefinition( @@ -132,8 +124,10 @@ FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV, return SGV->getLinkage(); switch (SGV->getLinkage()) { + case GlobalValue::LinkOnceAnyLinkage: + case GlobalValue::LinkOnceODRLinkage: case GlobalValue::ExternalLinkage: - // External defnitions are converted to available_externally + // External and linkonce definitions are converted to available_externally // definitions upon import, so that they are available for inlining // and/or optimization, but are turned into declarations later // during the EliminateAvailableExternally pass. @@ -150,12 +144,6 @@ FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV, // An imported available_externally declaration stays that way. return SGV->getLinkage(); - case GlobalValue::LinkOnceAnyLinkage: - case GlobalValue::LinkOnceODRLinkage: - // These both stay the same when importing the definition. - // The ThinLTO pass will eventually force-import their definitions. - return SGV->getLinkage(); - case GlobalValue::WeakAnyLinkage: // Can't import weak_any definitions correctly, or we might change the // program semantics, since the linker will pick the first weak_any @@ -213,6 +201,23 @@ FunctionImportGlobalProcessing::getLinkage(const GlobalValue *SGV, } void FunctionImportGlobalProcessing::processGlobalForThinLTO(GlobalValue &GV) { + + // Check the summaries to see if the symbol gets resolved to a known local + // definition. + if (GV.hasName()) { + ValueInfo VI = ImportIndex.getValueInfo(GV.getGUID()); + if (VI) { + // Need to check all summaries are local in case of hash collisions. + bool IsLocal = VI.getSummaryList().size() && + llvm::all_of(VI.getSummaryList(), + [](const std::unique_ptr<GlobalValueSummary> &Summary) { + return Summary->isDSOLocal(); + }); + if (IsLocal) + GV.setDSOLocal(true); + } + } + bool DoPromote = false; if (GV.hasLocalLinkage() && ((DoPromote = shouldPromoteLocalToGlobal(&GV)) || isPerformingImport())) { diff --git a/lib/Transforms/Utils/InlineFunction.cpp b/lib/Transforms/Utils/InlineFunction.cpp index 2a18c140c788..fedf6e100d6c 100644 --- a/lib/Transforms/Utils/InlineFunction.cpp +++ b/lib/Transforms/Utils/InlineFunction.cpp @@ -12,11 +12,15 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" -#include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/BlockFrequencyInfo.h" @@ -26,25 +30,46 @@ #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/ProfileSummaryInfo.h" #include "llvm/Analysis/ValueTracking.h" -#include "llvm/IR/Attributes.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DIBuilder.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" 
#include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> +#include <cassert> +#include <cstdint> +#include <iterator> +#include <limits> +#include <string> +#include <utility> +#include <vector> using namespace llvm; @@ -62,28 +87,37 @@ bool llvm::InlineFunction(CallInst *CI, InlineFunctionInfo &IFI, AAResults *CalleeAAR, bool InsertLifetime) { return InlineFunction(CallSite(CI), IFI, CalleeAAR, InsertLifetime); } + bool llvm::InlineFunction(InvokeInst *II, InlineFunctionInfo &IFI, AAResults *CalleeAAR, bool InsertLifetime) { return InlineFunction(CallSite(II), IFI, CalleeAAR, InsertLifetime); } namespace { + /// A class for recording information about inlining a landing pad. class LandingPadInliningInfo { - BasicBlock *OuterResumeDest; ///< Destination of the invoke's unwind. - BasicBlock *InnerResumeDest; ///< Destination for the callee's resume. - LandingPadInst *CallerLPad; ///< LandingPadInst associated with the invoke. - PHINode *InnerEHValuesPHI; ///< PHI for EH values from landingpad insts. + /// Destination of the invoke's unwind. + BasicBlock *OuterResumeDest; + + /// Destination for the callee's resume. + BasicBlock *InnerResumeDest = nullptr; + + /// LandingPadInst associated with the invoke. + LandingPadInst *CallerLPad = nullptr; + + /// PHI for EH values from landingpad insts. + PHINode *InnerEHValuesPHI = nullptr; + SmallVector<Value*, 8> UnwindDestPHIValues; public: LandingPadInliningInfo(InvokeInst *II) - : OuterResumeDest(II->getUnwindDest()), InnerResumeDest(nullptr), - CallerLPad(nullptr), InnerEHValuesPHI(nullptr) { + : OuterResumeDest(II->getUnwindDest()) { // If there are PHI nodes in the unwind destination block, we need to keep // track of which values came into them from the invoke before removing // the edge from this block. - llvm::BasicBlock *InvokeBB = II->getParent(); + BasicBlock *InvokeBB = II->getParent(); BasicBlock::iterator I = OuterResumeDest->begin(); for (; isa<PHINode>(I); ++I) { // Save the value to use for this edge. @@ -126,7 +160,8 @@ namespace { } } }; -} // anonymous namespace + +} // end anonymous namespace /// Get or create a target for the branch from ResumeInsts. BasicBlock *LandingPadInliningInfo::getInnerResumeDest() { @@ -189,7 +224,7 @@ static Value *getParentPad(Value *EHPad) { return cast<CatchSwitchInst>(EHPad)->getParentPad(); } -typedef DenseMap<Instruction *, Value *> UnwindDestMemoTy; +using UnwindDestMemoTy = DenseMap<Instruction *, Value *>; /// Helper for getUnwindDestToken that does the descendant-ward part of /// the search. @@ -617,7 +652,7 @@ static void HandleInlinedEHPad(InvokeInst *II, BasicBlock *FirstNewBlock, // track of which values came into them from the invoke before removing the // edge from this block. SmallVector<Value *, 8> UnwindDestPHIValues; - llvm::BasicBlock *InvokeBB = II->getParent(); + BasicBlock *InvokeBB = II->getParent(); for (Instruction &I : *UnwindDest) { // Save the value to use for this edge. 
PHINode *PHI = dyn_cast<PHINode>(&I); @@ -1359,6 +1394,7 @@ static void fixupLineNumbers(Function *Fn, Function::iterator FI, } } } + /// Update the block frequencies of the caller after a callee has been inlined. /// /// Each block cloned into the caller has its block frequency scaled by the @@ -1454,7 +1490,8 @@ static void updateCalleeCount(BlockFrequencyInfo *CallerBFI, BasicBlock *CallBB, /// exists in the instruction stream. Similarly this will inline a recursive /// function by one level. bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, - AAResults *CalleeAAR, bool InsertLifetime) { + AAResults *CalleeAAR, bool InsertLifetime, + Function *ForwardVarArgsTo) { Instruction *TheCall = CS.getInstruction(); assert(TheCall->getParent() && TheCall->getFunction() && "Instruction not in function!"); @@ -1464,8 +1501,9 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, Function *CalledFunc = CS.getCalledFunction(); if (!CalledFunc || // Can't inline external function or indirect - CalledFunc->isDeclaration() || // call, or call to a vararg function! - CalledFunc->getFunctionType()->isVarArg()) return false; + CalledFunc->isDeclaration() || + (!ForwardVarArgsTo && CalledFunc->isVarArg())) // call, or call to a vararg function! + return false; // The inliner does not know how to inline through calls with operand bundles // in general ... @@ -1592,8 +1630,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, auto &DL = Caller->getParent()->getDataLayout(); - assert(CalledFunc->arg_size() == CS.arg_size() && - "No varargs calls can be inlined!"); + assert((CalledFunc->arg_size() == CS.arg_size() || ForwardVarArgsTo) && + "Varargs calls can only be inlined if the Varargs are forwarded!"); // Calculate the vector of arguments to pass into the function cloner, which // matches up the formal to the actual argument values. @@ -1772,9 +1810,15 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Move any dbg.declares describing the allocas into the entry basic block. DIBuilder DIB(*Caller->getParent()); for (auto &AI : IFI.StaticAllocas) - replaceDbgDeclareForAlloca(AI, AI, DIB, /*Deref=*/false); + replaceDbgDeclareForAlloca(AI, AI, DIB, DIExpression::NoDeref, 0, + DIExpression::NoDeref); } + SmallVector<Value*,4> VarArgsToForward; + for (unsigned i = CalledFunc->getFunctionType()->getNumParams(); + i < CS.getNumArgOperands(); i++) + VarArgsToForward.push_back(CS.getArgOperand(i)); + bool InlinedMustTailCalls = false, InlinedDeoptimizeCalls = false; if (InlinedFunctionInfo.ContainsCalls) { CallInst::TailCallKind CallSiteTailKind = CallInst::TCK_None; @@ -1783,7 +1827,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, for (Function::iterator BB = FirstNewBlock, E = Caller->end(); BB != E; ++BB) { - for (Instruction &I : *BB) { + for (auto II = BB->begin(); II != BB->end();) { + Instruction &I = *II++; CallInst *CI = dyn_cast<CallInst>(&I); if (!CI) continue; @@ -1806,7 +1851,8 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // f -> g -> musttail f ==> f -> f // f -> g -> tail f ==> f -> f CallInst::TailCallKind ChildTCK = CI->getTailCallKind(); - ChildTCK = std::min(CallSiteTailKind, ChildTCK); + if (ChildTCK != CallInst::TCK_NoTail) + ChildTCK = std::min(CallSiteTailKind, ChildTCK); CI->setTailCallKind(ChildTCK); InlinedMustTailCalls |= CI->isMustTailCall(); @@ -1814,6 +1860,16 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // 'nounwind'. 
if (MarkNoUnwind) CI->setDoesNotThrow(); + + if (ForwardVarArgsTo && !VarArgsToForward.empty() && + CI->getCalledFunction() == ForwardVarArgsTo) { + SmallVector<Value*, 6> Params(CI->arg_operands()); + Params.append(VarArgsToForward.begin(), VarArgsToForward.end()); + CallInst *Call = CallInst::Create(CI->getCalledFunction(), Params, "", CI); + Call->setDebugLoc(CI->getDebugLoc()); + CI->replaceAllUsesWith(Call); + CI->eraseFromParent(); + } } } } @@ -1848,8 +1904,9 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // Check that array size doesn't saturate uint64_t and doesn't // overflow when it's multiplied by type size. - if (AllocaArraySize != ~0ULL && - UINT64_MAX / AllocaArraySize >= AllocaTypeSize) { + if (AllocaArraySize != std::numeric_limits<uint64_t>::max() && + std::numeric_limits<uint64_t>::max() / AllocaArraySize >= + AllocaTypeSize) { AllocaSize = ConstantInt::get(Type::getInt64Ty(AI->getContext()), AllocaArraySize * AllocaTypeSize); } @@ -1980,7 +2037,7 @@ bool llvm::InlineFunction(CallSite CS, InlineFunctionInfo &IFI, // match the callee's return type, we also need to change the return type of // the intrinsic. if (Caller->getReturnType() == TheCall->getType()) { - auto NewEnd = remove_if(Returns, [](ReturnInst *RI) { + auto NewEnd = llvm::remove_if(Returns, [](ReturnInst *RI) { return RI->getParent()->getTerminatingDeoptimizeCall() != nullptr; }); Returns.erase(NewEnd, Returns.end()); diff --git a/lib/Transforms/Utils/LCSSA.cpp b/lib/Transforms/Utils/LCSSA.cpp index 089f2b5f3b18..ae0e2bb6c280 100644 --- a/lib/Transforms/Utils/LCSSA.cpp +++ b/lib/Transforms/Utils/LCSSA.cpp @@ -56,9 +56,10 @@ static bool VerifyLoopLCSSA = true; #else static bool VerifyLoopLCSSA = false; #endif -static cl::opt<bool,true> -VerifyLoopLCSSAFlag("verify-loop-lcssa", cl::location(VerifyLoopLCSSA), - cl::desc("Verify loop lcssa form (time consuming)")); +static cl::opt<bool, true> + VerifyLoopLCSSAFlag("verify-loop-lcssa", cl::location(VerifyLoopLCSSA), + cl::Hidden, + cl::desc("Verify loop lcssa form (time consuming)")); /// Return true if the specified block is in the list. 
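// Sketch (not from the patch): the ArraySize * TypeSize guard rewritten in the
// inliner hunk above uses the "divide the maximum" idiom to make sure the product
// fits in 64 bits, and keeps UINT64_MAX rejected as a sentinel value, as that code
// does. A standalone version, with an explicit zero check added so the helper is
// well-defined on its own:
#include <cstdint>
#include <limits>
static bool productFitsInU64(uint64_t ArraySize, uint64_t TypeSize) {
  if (ArraySize == 0)
    return true;                                              // 0 * anything fits
  return ArraySize != std::numeric_limits<uint64_t>::max() && // sentinel value
         std::numeric_limits<uint64_t>::max() / ArraySize >= TypeSize;
}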
static bool isExitBlock(BasicBlock *BB, diff --git a/lib/Transforms/Utils/Local.cpp b/lib/Transforms/Utils/Local.cpp index 74610613001c..a1961eecb391 100644 --- a/lib/Transforms/Utils/Local.cpp +++ b/lib/Transforms/Utils/Local.cpp @@ -1,4 +1,4 @@ -//===-- Local.cpp - Functions to perform local transformations ------------===// +//===- Local.cpp - Functions to perform local transformations -------------===// // // The LLVM Compiler Infrastructure // @@ -13,42 +13,74 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/Local.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/Hashing.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/TinyPtrVector.h" +#include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LazyValueInfo.h" #include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/BinaryFormat/Dwarf.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DIBuilder.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" #include "llvm/IR/GetElementPtrTypeIterator.h" -#include "llvm/IR/GlobalAlias.h" -#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/GlobalObject.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/KnownBits.h" -#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" +#include <algorithm> +#include <cassert> +#include <climits> +#include <cstdint> +#include <iterator> +#include <map> +#include <utility> + using namespace llvm; using namespace llvm::PatternMatch; @@ -282,7 +314,6 @@ bool llvm::ConstantFoldTerminator(BasicBlock *BB, bool DeleteDeadConditions, return false; } - //===----------------------------------------------------------------------===// // Local dead code elimination. // @@ -541,7 +572,6 @@ bool llvm::SimplifyInstructionsInBlock(BasicBlock *BB, // Control Flow Graph Restructuring. // - /// RemovePredecessorAndSimplify - Like BasicBlock::removePredecessor, this /// method is called when we're about to delete Pred as a predecessor of BB. If /// BB contains any PHI nodes, this drops the entries in the PHI nodes for Pred. 
@@ -578,12 +608,10 @@ void llvm::RemovePredecessorAndSimplify(BasicBlock *BB, BasicBlock *Pred) { } } - /// MergeBasicBlockIntoOnlyPred - DestBB is a block with one predecessor and its /// predecessor is known to have one successor (DestBB!). Eliminate the edge /// between them, moving the instructions in the predecessor into DestBB and /// deleting the predecessor block. -/// void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, DominatorTree *DT) { // If BB has single-entry PHI nodes, fold them. while (PHINode *PN = dyn_cast<PHINode>(DestBB->begin())) { @@ -602,7 +630,7 @@ void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, DominatorTree *DT) { if (DestBB->hasAddressTaken()) { BlockAddress *BA = BlockAddress::get(DestBB); Constant *Replacement = - ConstantInt::get(llvm::Type::getInt32Ty(BA->getContext()), 1); + ConstantInt::get(Type::getInt32Ty(BA->getContext()), 1); BA->replaceAllUsesWith(ConstantExpr::getIntToPtr(Replacement, BA->getType())); BA->destroyConstant(); @@ -621,9 +649,13 @@ void llvm::MergeBasicBlockIntoOnlyPred(BasicBlock *DestBB, DominatorTree *DT) { DestBB->moveAfter(PredBB); if (DT) { - BasicBlock *PredBBIDom = DT->getNode(PredBB)->getIDom()->getBlock(); - DT->changeImmediateDominator(DestBB, PredBBIDom); - DT->eraseNode(PredBB); + // For some irreducible CFG we end up having forward-unreachable blocks + // so check if getNode returns a valid node before updating the domtree. + if (DomTreeNode *DTN = DT->getNode(PredBB)) { + BasicBlock *PredBBIDom = DTN->getIDom()->getBlock(); + DT->changeImmediateDominator(DestBB, PredBBIDom); + DT->eraseNode(PredBB); + } } // Nuke BB. PredBB->eraseFromParent(); @@ -640,7 +672,6 @@ static bool CanMergeValues(Value *First, Value *Second) { /// almost-empty BB ending in an unconditional branch to Succ, into Succ. /// /// Assumption: Succ is the single successor for BB. -/// static bool CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ) { assert(*succ_begin(BB) == Succ && "Succ is not successor of BB!"); @@ -696,8 +727,8 @@ static bool CanPropagatePredecessorsForPHIs(BasicBlock *BB, BasicBlock *Succ) { return true; } -typedef SmallVector<BasicBlock *, 16> PredBlockVector; -typedef DenseMap<BasicBlock *, Value *> IncomingValueMap; +using PredBlockVector = SmallVector<BasicBlock *, 16>; +using IncomingValueMap = DenseMap<BasicBlock *, Value *>; /// \brief Determines the value to use as the phi node input for a block. /// @@ -927,7 +958,6 @@ bool llvm::TryToSimplifyUncondBranchFromEmptyBlock(BasicBlock *BB) { /// nodes in this block. This doesn't try to be clever about PHI nodes /// which differ only in the order of the incoming values, but instcombine /// orders them so it usually won't matter. -/// bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB) { // This implementation doesn't currently consider undef operands // specially. Theoretically, two phis which are identical except for @@ -937,9 +967,11 @@ bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB) { static PHINode *getEmptyKey() { return DenseMapInfo<PHINode *>::getEmptyKey(); } + static PHINode *getTombstoneKey() { return DenseMapInfo<PHINode *>::getTombstoneKey(); } + static unsigned getHashValue(PHINode *PN) { // Compute a hash value on the operands. 
Instcombine will likely have // sorted them, which helps expose duplicates, but we have to check all @@ -948,6 +980,7 @@ bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB) { hash_combine_range(PN->value_op_begin(), PN->value_op_end()), hash_combine_range(PN->block_begin(), PN->block_end()))); } + static bool isEqual(PHINode *LHS, PHINode *RHS) { if (LHS == getEmptyKey() || LHS == getTombstoneKey() || RHS == getEmptyKey() || RHS == getTombstoneKey()) @@ -984,7 +1017,6 @@ bool llvm::EliminateDuplicatePHINodes(BasicBlock *BB) { /// often possible though. If alignment is important, a more reliable approach /// is to simply align all global variables and allocation instructions to /// their preferred alignment from the beginning. -/// static unsigned enforceKnownAlignment(Value *V, unsigned Align, unsigned PrefAlign, const DataLayout &DL) { @@ -1068,12 +1100,11 @@ static bool LdStHasDebugValue(DILocalVariable *DIVar, DIExpression *DIExpr, // Since we can't guarantee that the original dbg.declare instrinsic // is removed by LowerDbgDeclare(), we need to make sure that we are // not inserting the same dbg.value intrinsic over and over. - llvm::BasicBlock::InstListType::iterator PrevI(I); + BasicBlock::InstListType::iterator PrevI(I); if (PrevI != I->getParent()->getInstList().begin()) { --PrevI; if (DbgValueInst *DVI = dyn_cast<DbgValueInst>(PrevI)) if (DVI->getValue() == I->getOperand(0) && - DVI->getOffset() == 0 && DVI->getVariable() == DIVar && DVI->getExpression() == DIExpr) return true; @@ -1092,7 +1123,6 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar, findDbgValues(DbgValues, APN); for (auto *DVI : DbgValues) { assert(DVI->getValue() == APN); - assert(DVI->getOffset() == 0); if ((DVI->getVariable() == DIVar) && (DVI->getExpression() == DIExpr)) return true; } @@ -1100,12 +1130,13 @@ static bool PhiHasDebugValue(DILocalVariable *DIVar, } /// Inserts a llvm.dbg.value intrinsic before a store to an alloca'd value -/// that has an associated llvm.dbg.decl intrinsic. -void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, +/// that has an associated llvm.dbg.declare or llvm.dbg.addr intrinsic. +void llvm::ConvertDebugDeclareToDebugValue(DbgInfoIntrinsic *DII, StoreInst *SI, DIBuilder &Builder) { - auto *DIVar = DDI->getVariable(); + assert(DII->isAddressOfVariable()); + auto *DIVar = DII->getVariable(); assert(DIVar && "Missing variable"); - auto *DIExpr = DDI->getExpression(); + auto *DIExpr = DII->getExpression(); Value *DV = SI->getOperand(0); // If an argument is zero extended then use argument directly. The ZExt @@ -1116,7 +1147,7 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, if (SExtInst *SExt = dyn_cast<SExtInst>(SI->getOperand(0))) ExtendedArg = dyn_cast<Argument>(SExt->getOperand(0)); if (ExtendedArg) { - // If this DDI was already describing only a fragment of a variable, ensure + // If this DII was already describing only a fragment of a variable, ensure // that fragment is appropriately narrowed here. 
// But if a fragment wasn't used, describe the value as the original // argument (rather than the zext or sext) so that it remains described even @@ -1129,23 +1160,23 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, DIExpr->elements_end() - 3); Ops.push_back(dwarf::DW_OP_LLVM_fragment); Ops.push_back(FragmentOffset); - const DataLayout &DL = DDI->getModule()->getDataLayout(); + const DataLayout &DL = DII->getModule()->getDataLayout(); Ops.push_back(DL.getTypeSizeInBits(ExtendedArg->getType())); DIExpr = Builder.createExpression(Ops); } DV = ExtendedArg; } if (!LdStHasDebugValue(DIVar, DIExpr, SI)) - Builder.insertDbgValueIntrinsic(DV, 0, DIVar, DIExpr, DDI->getDebugLoc(), + Builder.insertDbgValueIntrinsic(DV, DIVar, DIExpr, DII->getDebugLoc(), SI); } /// Inserts a llvm.dbg.value intrinsic before a load of an alloca'd value -/// that has an associated llvm.dbg.decl intrinsic. -void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, +/// that has an associated llvm.dbg.declare or llvm.dbg.addr intrinsic. +void llvm::ConvertDebugDeclareToDebugValue(DbgInfoIntrinsic *DII, LoadInst *LI, DIBuilder &Builder) { - auto *DIVar = DDI->getVariable(); - auto *DIExpr = DDI->getExpression(); + auto *DIVar = DII->getVariable(); + auto *DIExpr = DII->getExpression(); assert(DIVar && "Missing variable"); if (LdStHasDebugValue(DIVar, DIExpr, LI)) @@ -1156,16 +1187,16 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, // preferable to keep tracking both the loaded value and the original // address in case the alloca can not be elided. Instruction *DbgValue = Builder.insertDbgValueIntrinsic( - LI, 0, DIVar, DIExpr, DDI->getDebugLoc(), (Instruction *)nullptr); + LI, DIVar, DIExpr, DII->getDebugLoc(), (Instruction *)nullptr); DbgValue->insertAfter(LI); } -/// Inserts a llvm.dbg.value intrinsic after a phi -/// that has an associated llvm.dbg.decl intrinsic. -void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, +/// Inserts a llvm.dbg.value intrinsic after a phi that has an associated +/// llvm.dbg.declare or llvm.dbg.addr intrinsic. +void llvm::ConvertDebugDeclareToDebugValue(DbgInfoIntrinsic *DII, PHINode *APN, DIBuilder &Builder) { - auto *DIVar = DDI->getVariable(); - auto *DIExpr = DDI->getExpression(); + auto *DIVar = DII->getVariable(); + auto *DIExpr = DII->getExpression(); assert(DIVar && "Missing variable"); if (PhiHasDebugValue(DIVar, DIExpr, APN)) @@ -1178,7 +1209,7 @@ void llvm::ConvertDebugDeclareToDebugValue(DbgDeclareInst *DDI, // insertion point. // FIXME: Insert dbg.value markers in the successors when appropriate. if (InsertionPt != BB->end()) - Builder.insertDbgValueIntrinsic(APN, 0, DIVar, DIExpr, DDI->getDebugLoc(), + Builder.insertDbgValueIntrinsic(APN, DIVar, DIExpr, DII->getDebugLoc(), &*InsertionPt); } @@ -1222,7 +1253,7 @@ bool llvm::LowerDbgDeclare(Function &F) { // This is a call by-value or some other instruction that // takes a pointer to the variable. Insert a *value* // intrinsic that describes the alloca. - DIB.insertDbgValueIntrinsic(AI, 0, DDI->getVariable(), + DIB.insertDbgValueIntrinsic(AI, DDI->getVariable(), DDI->getExpression(), DDI->getDebugLoc(), CI); } @@ -1233,16 +1264,25 @@ bool llvm::LowerDbgDeclare(Function &F) { return true; } -/// FindAllocaDbgDeclare - Finds the llvm.dbg.declare intrinsic describing the -/// alloca 'V', if any. 
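// Illustration (not from the patch): what ConvertDebugDeclareToDebugValue is for
// at the source level. A local such as the one below starts out as an alloca
// described by llvm.dbg.declare (or, after this change, llvm.dbg.addr); once the
// store, load or phi it flows through carries the only remaining definition, the
// variable is described by llvm.dbg.value of the SSA value instead. Hypothetical
// example, the function name is invented:
static int trackedLocalSketch(int In) {
  int Local = In + 1;   // dbg.declare/dbg.addr while Local lives in memory
  return Local * 2;     // dbg.value once the alloca is promoted away
}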
-DbgDeclareInst *llvm::FindAllocaDbgDeclare(Value *V) { - if (auto *L = LocalAsMetadata::getIfExists(V)) - if (auto *MDV = MetadataAsValue::getIfExists(V->getContext(), L)) - for (User *U : MDV->users()) - if (DbgDeclareInst *DDI = dyn_cast<DbgDeclareInst>(U)) - return DDI; +/// Finds all intrinsics declaring local variables as living in the memory that +/// 'V' points to. This may include a mix of dbg.declare and +/// dbg.addr intrinsics. +TinyPtrVector<DbgInfoIntrinsic *> llvm::FindDbgAddrUses(Value *V) { + auto *L = LocalAsMetadata::getIfExists(V); + if (!L) + return {}; + auto *MDV = MetadataAsValue::getIfExists(V->getContext(), L); + if (!MDV) + return {}; + + TinyPtrVector<DbgInfoIntrinsic *> Declares; + for (User *U : MDV->users()) { + if (auto *DII = dyn_cast<DbgInfoIntrinsic>(U)) + if (DII->isAddressOfVariable()) + Declares.push_back(DII); + } - return nullptr; + return Declares; } void llvm::findDbgValues(SmallVectorImpl<DbgValueInst *> &DbgValues, Value *V) { @@ -1253,29 +1293,40 @@ void llvm::findDbgValues(SmallVectorImpl<DbgValueInst *> &DbgValues, Value *V) { DbgValues.push_back(DVI); } +static void findDbgUsers(SmallVectorImpl<DbgInfoIntrinsic *> &DbgUsers, + Value *V) { + if (auto *L = LocalAsMetadata::getIfExists(V)) + if (auto *MDV = MetadataAsValue::getIfExists(V->getContext(), L)) + for (User *U : MDV->users()) + if (DbgInfoIntrinsic *DII = dyn_cast<DbgInfoIntrinsic>(U)) + DbgUsers.push_back(DII); +} bool llvm::replaceDbgDeclare(Value *Address, Value *NewAddress, Instruction *InsertBefore, DIBuilder &Builder, - bool Deref, int Offset) { - DbgDeclareInst *DDI = FindAllocaDbgDeclare(Address); - if (!DDI) - return false; - DebugLoc Loc = DDI->getDebugLoc(); - auto *DIVar = DDI->getVariable(); - auto *DIExpr = DDI->getExpression(); - assert(DIVar && "Missing variable"); - DIExpr = DIExpression::prepend(DIExpr, Deref, Offset); - // Insert llvm.dbg.declare immediately after the original alloca, and remove - // old llvm.dbg.declare. - Builder.insertDeclare(NewAddress, DIVar, DIExpr, Loc, InsertBefore); - DDI->eraseFromParent(); - return true; + bool DerefBefore, int Offset, bool DerefAfter) { + auto DbgAddrs = FindDbgAddrUses(Address); + for (DbgInfoIntrinsic *DII : DbgAddrs) { + DebugLoc Loc = DII->getDebugLoc(); + auto *DIVar = DII->getVariable(); + auto *DIExpr = DII->getExpression(); + assert(DIVar && "Missing variable"); + DIExpr = DIExpression::prepend(DIExpr, DerefBefore, Offset, DerefAfter); + // Insert llvm.dbg.declare immediately after InsertBefore, and remove old + // llvm.dbg.declare. 
+ Builder.insertDeclare(NewAddress, DIVar, DIExpr, Loc, InsertBefore); + if (DII == InsertBefore) + InsertBefore = &*std::next(InsertBefore->getIterator()); + DII->eraseFromParent(); + } + return !DbgAddrs.empty(); } bool llvm::replaceDbgDeclareForAlloca(AllocaInst *AI, Value *NewAllocaAddress, - DIBuilder &Builder, bool Deref, int Offset) { + DIBuilder &Builder, bool DerefBefore, + int Offset, bool DerefAfter) { return replaceDbgDeclare(AI, NewAllocaAddress, AI->getNextNode(), Builder, - Deref, Offset); + DerefBefore, Offset, DerefAfter); } static void replaceOneDbgValueForAlloca(DbgValueInst *DVI, Value *NewAddress, @@ -1302,8 +1353,7 @@ static void replaceOneDbgValueForAlloca(DbgValueInst *DVI, Value *NewAddress, DIExpr = Builder.createExpression(Ops); } - Builder.insertDbgValueIntrinsic(NewAddress, DVI->getOffset(), DIVar, DIExpr, - Loc, DVI); + Builder.insertDbgValueIntrinsic(NewAddress, DIVar, DIExpr, Loc, DVI); DVI->eraseFromParent(); } @@ -1322,17 +1372,28 @@ void llvm::salvageDebugInfo(Instruction &I) { SmallVector<DbgValueInst *, 1> DbgValues; auto &M = *I.getModule(); - auto MDWrap = [&](Value *V) { + auto wrapMD = [&](Value *V) { return MetadataAsValue::get(I.getContext(), ValueAsMetadata::get(V)); }; - if (isa<BitCastInst>(&I)) { - findDbgValues(DbgValues, &I); - for (auto *DVI : DbgValues) { - // Bitcasts are entirely irrelevant for debug info. Rewrite the dbg.value - // to use the cast's source. - DVI->setOperand(0, MDWrap(I.getOperand(0))); - DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n'); + auto applyOffset = [&](DbgValueInst *DVI, uint64_t Offset) { + auto *DIExpr = DVI->getExpression(); + DIExpr = DIExpression::prepend(DIExpr, DIExpression::NoDeref, Offset, + DIExpression::NoDeref, + DIExpression::WithStackValue); + DVI->setOperand(0, wrapMD(I.getOperand(0))); + DVI->setOperand(2, MetadataAsValue::get(I.getContext(), DIExpr)); + DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n'); + }; + + if (isa<BitCastInst>(&I) || isa<IntToPtrInst>(&I)) { + // Bitcasts are entirely irrelevant for debug info. Rewrite dbg.value, + // dbg.addr, and dbg.declare to use the cast's source. + SmallVector<DbgInfoIntrinsic *, 1> DbgUsers; + findDbgUsers(DbgUsers, &I); + for (auto *DII : DbgUsers) { + DII->setOperand(0, wrapMD(I.getOperand(0))); + DEBUG(dbgs() << "SALVAGE: " << *DII << '\n'); } } else if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) { findDbgValues(DbgValues, &I); @@ -1343,27 +1404,27 @@ void llvm::salvageDebugInfo(Instruction &I) { // Rewrite a constant GEP into a DIExpression. Since we are performing // arithmetic to compute the variable's *value* in the DIExpression, we // need to mark the expression with a DW_OP_stack_value. - if (GEP->accumulateConstantOffset(M.getDataLayout(), Offset)) { - auto *DIExpr = DVI->getExpression(); - DIBuilder DIB(M, /*AllowUnresolved*/ false); + if (GEP->accumulateConstantOffset(M.getDataLayout(), Offset)) // GEP offsets are i32 and thus always fit into an int64_t. 
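// Sketch (an assumption-laden reduction, mirroring the applyOffset helper above):
// when an address computation such as "p + Offset" is deleted, the debug intrinsic
// that referred to it is pointed back at "p" and its DIExpression grows the same
// offset plus DW_OP_stack_value, so the expression now computes the variable's
// value instead of describing a memory location. The helper below only uses the
// DIExpression::prepend call already visible in this hunk:
#include "llvm/IR/DebugInfoMetadata.h"
#include <cstdint>
static llvm::DIExpression *addSalvagedOffsetSketch(llvm::DIExpression *Expr,
                                                   int64_t Offset) {
  using llvm::DIExpression;
  return DIExpression::prepend(Expr, DIExpression::NoDeref, Offset,
                               DIExpression::NoDeref,
                               DIExpression::WithStackValue);
}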
- DIExpr = DIExpression::prepend(DIExpr, DIExpression::NoDeref, - Offset.getSExtValue(), - DIExpression::WithStackValue); - DVI->setOperand(0, MDWrap(I.getOperand(0))); - DVI->setOperand(3, MetadataAsValue::get(I.getContext(), DIExpr)); - DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n'); - } + applyOffset(DVI, Offset.getSExtValue()); } + } else if (auto *BI = dyn_cast<BinaryOperator>(&I)) { + if (BI->getOpcode() == Instruction::Add) + if (auto *ConstInt = dyn_cast<ConstantInt>(I.getOperand(1))) + if (ConstInt->getBitWidth() <= 64) { + APInt Offset = ConstInt->getValue(); + findDbgValues(DbgValues, &I); + for (auto *DVI : DbgValues) + applyOffset(DVI, Offset.getSExtValue()); + } } else if (isa<LoadInst>(&I)) { findDbgValues(DbgValues, &I); for (auto *DVI : DbgValues) { // Rewrite the load into DW_OP_deref. auto *DIExpr = DVI->getExpression(); - DIBuilder DIB(M, /*AllowUnresolved*/ false); DIExpr = DIExpression::prepend(DIExpr, DIExpression::WithDeref); - DVI->setOperand(0, MDWrap(I.getOperand(0))); - DVI->setOperand(3, MetadataAsValue::get(I.getContext(), DIExpr)); + DVI->setOperand(0, wrapMD(I.getOperand(0))); + DVI->setOperand(2, MetadataAsValue::get(I.getContext(), DIExpr)); DEBUG(dbgs() << "SALVAGE: " << *DVI << '\n'); } } @@ -1480,7 +1541,6 @@ BasicBlock *llvm::changeToInvokeAndSplitBasicBlock(CallInst *CI, static bool markAliveBlocks(Function &F, SmallPtrSetImpl<BasicBlock*> &Reachable) { - SmallVector<BasicBlock*, 128> Worklist; BasicBlock *BB = &F.front(); Worklist.push_back(BB); @@ -1586,13 +1646,16 @@ static bool markAliveBlocks(Function &F, static CatchPadInst *getEmptyKey() { return DenseMapInfo<CatchPadInst *>::getEmptyKey(); } + static CatchPadInst *getTombstoneKey() { return DenseMapInfo<CatchPadInst *>::getTombstoneKey(); } + static unsigned getHashValue(CatchPadInst *CatchPad) { return static_cast<unsigned>(hash_combine_range( CatchPad->value_op_begin(), CatchPad->value_op_end())); } + static bool isEqual(CatchPadInst *LHS, CatchPadInst *RHS) { if (LHS == getEmptyKey() || LHS == getTombstoneKey() || RHS == getEmptyKey() || RHS == getTombstoneKey()) @@ -1832,7 +1895,8 @@ unsigned llvm::replaceDominatedUsesWith(Value *From, Value *To, return ::replaceDominatedUsesWith(From, To, BB, ProperlyDominates); } -bool llvm::callsGCLeafFunction(ImmutableCallSite CS) { +bool llvm::callsGCLeafFunction(ImmutableCallSite CS, + const TargetLibraryInfo &TLI) { // Check if the function is specifically marked as a gc leaf function. if (CS.hasFnAttr("gc-leaf-function")) return true; @@ -1846,6 +1910,14 @@ bool llvm::callsGCLeafFunction(ImmutableCallSite CS) { IID != Intrinsic::experimental_deoptimize; } + // Lib calls can be materialized by some passes, and won't be + // marked as 'gc-leaf-function.' All available Libcalls are + // GC-leaf. + LibFunc LF; + if (TLI.getLibFunc(CS, LF)) { + return TLI.has(LF); + } + return false; } @@ -1893,6 +1965,7 @@ void llvm::copyRangeMetadata(const DataLayout &DL, const LoadInst &OldLI, } namespace { + /// A potential constituent of a bitreverse or bswap expression. See /// collectBitParts for a fuller explanation. struct BitPart { @@ -1902,12 +1975,14 @@ struct BitPart { /// The Value that this is a bitreverse/bswap of. Value *Provider; + /// The "provenance" of each bit. Provenance[A] = B means that bit A /// in Provider becomes bit B in the result of this expression. SmallVector<int8_t, 32> Provenance; // int8_t means max size is i128. 
enum { Unset = -1 }; }; + } // end anonymous namespace /// Analyze the specified subexpression and see if it is capable of providing @@ -1933,7 +2008,6 @@ struct BitPart { /// /// Because we pass around references into \c BPS, we must use a container that /// does not invalidate internal references (std::map instead of DenseMap). -/// static const Optional<BitPart> & collectBitParts(Value *V, bool MatchBSwaps, bool MatchBitReversals, std::map<Value *, Optional<BitPart>> &BPS) { @@ -2069,8 +2143,6 @@ static bool bitTransformIsCorrectForBitReverse(unsigned From, unsigned To, return From == BitWidth - To - 1; } -/// Given an OR instruction, check to see if this is a bitreverse -/// idiom. If so, insert the new intrinsic and return true. bool llvm::recognizeBSwapOrBitReverseIdiom( Instruction *I, bool MatchBSwaps, bool MatchBitReversals, SmallVectorImpl<Instruction *> &InsertedInsts) { diff --git a/lib/Transforms/Utils/LoopSimplify.cpp b/lib/Transforms/Utils/LoopSimplify.cpp index e21e34df8ded..f43af9772771 100644 --- a/lib/Transforms/Utils/LoopSimplify.cpp +++ b/lib/Transforms/Utils/LoopSimplify.cpp @@ -258,7 +258,7 @@ static Loop *separateNestedLoop(Loop *L, BasicBlock *Preheader, placeSplitBlockCarefully(NewBB, OuterLoopPreds, L); // Create the new outer loop. - Loop *NewOuter = new Loop(); + Loop *NewOuter = LI->AllocateLoop(); // Change the parent loop to use the outer loop as its child now. if (Loop *Parent = L->getParentLoop()) diff --git a/lib/Transforms/Utils/LoopUnroll.cpp b/lib/Transforms/Utils/LoopUnroll.cpp index f2527f89e83e..dc98a39adcc5 100644 --- a/lib/Transforms/Utils/LoopUnroll.cpp +++ b/lib/Transforms/Utils/LoopUnroll.cpp @@ -21,8 +21,7 @@ #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/LoopIterator.h" -#include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/DataLayout.h" @@ -68,9 +67,23 @@ static inline void remapInstruction(Instruction *I, ValueToValueMapTy &VMap) { for (unsigned op = 0, E = I->getNumOperands(); op != E; ++op) { Value *Op = I->getOperand(op); + + // Unwrap arguments of dbg.value intrinsics. + bool Wrapped = false; + if (auto *V = dyn_cast<MetadataAsValue>(Op)) + if (auto *Unwrapped = dyn_cast<ValueAsMetadata>(V->getMetadata())) { + Op = Unwrapped->getValue(); + Wrapped = true; + } + + auto wrap = [&](Value *V) { + auto &C = I->getContext(); + return Wrapped ? MetadataAsValue::get(C, ValueAsMetadata::get(V)) : V; + }; + ValueToValueMapTy::iterator It = VMap.find(Op); if (It != VMap.end()) - I->setOperand(op, It->second); + I->setOperand(op, wrap(It->second)); } if (PHINode *PN = dyn_cast<PHINode>(I)) { @@ -200,7 +213,7 @@ const Loop* llvm::addClonedBlockToLoopInfo(BasicBlock *OriginalBB, assert(OriginalBB == OldLoop->getHeader() && "Header should be first in RPO"); - NewLoop = new Loop(); + NewLoop = LI->AllocateLoop(); Loop *NewLoopParent = NewLoops.lookup(OldLoop->getParentLoop()); if (NewLoopParent) @@ -255,8 +268,7 @@ static bool isEpilogProfitable(Loop *L) { return false; } -/// Unroll the given loop by Count. The loop must be in LCSSA form. Returns true -/// if unrolling was successful, or false if the loop was unmodified. Unrolling +/// Unroll the given loop by Count. The loop must be in LCSSA form. 
Unrolling /// can only fail when the loop's latch block is not terminated by a conditional /// branch instruction. However, if the trip count (and multiple) are not known, /// loop unrolling will mostly produce more code that is no faster. @@ -285,37 +297,36 @@ static bool isEpilogProfitable(Loop *L) { /// runtime-unroll the loop if computing RuntimeTripCount will be expensive and /// AllowExpensiveTripCount is false. /// -/// If we want to perform PGO-based loop peeling, PeelCount is set to the +/// If we want to perform PGO-based loop peeling, PeelCount is set to the /// number of iterations we want to peel off. /// /// The LoopInfo Analysis that is passed will be kept consistent. /// /// This utility preserves LoopInfo. It will also preserve ScalarEvolution and /// DominatorTree if they are non-null. -bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, - bool AllowRuntime, bool AllowExpensiveTripCount, - bool PreserveCondBr, bool PreserveOnlyFirst, - unsigned TripMultiple, unsigned PeelCount, LoopInfo *LI, - ScalarEvolution *SE, DominatorTree *DT, - AssumptionCache *AC, OptimizationRemarkEmitter *ORE, - bool PreserveLCSSA) { +LoopUnrollResult llvm::UnrollLoop( + Loop *L, unsigned Count, unsigned TripCount, bool Force, bool AllowRuntime, + bool AllowExpensiveTripCount, bool PreserveCondBr, bool PreserveOnlyFirst, + unsigned TripMultiple, unsigned PeelCount, bool UnrollRemainder, + LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, AssumptionCache *AC, + OptimizationRemarkEmitter *ORE, bool PreserveLCSSA) { BasicBlock *Preheader = L->getLoopPreheader(); if (!Preheader) { DEBUG(dbgs() << " Can't unroll; loop preheader-insertion failed.\n"); - return false; + return LoopUnrollResult::Unmodified; } BasicBlock *LatchBlock = L->getLoopLatch(); if (!LatchBlock) { DEBUG(dbgs() << " Can't unroll; loop exit-block-insertion failed.\n"); - return false; + return LoopUnrollResult::Unmodified; } // Loops with indirectbr cannot be cloned. if (!L->isSafeToClone()) { DEBUG(dbgs() << " Can't unroll; Loop body cannot be cloned.\n"); - return false; + return LoopUnrollResult::Unmodified; } // The current loop unroll pass can only unroll loops with a single latch @@ -329,7 +340,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, // The loop-rotate pass can be helpful to avoid this in many cases. DEBUG(dbgs() << " Can't unroll; loop not terminated by a conditional branch.\n"); - return false; + return LoopUnrollResult::Unmodified; } auto CheckSuccessors = [&](unsigned S1, unsigned S2) { @@ -339,14 +350,14 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, if (!CheckSuccessors(0, 1) && !CheckSuccessors(1, 0)) { DEBUG(dbgs() << "Can't unroll; only loops with one conditional latch" " exiting the loop can be unrolled\n"); - return false; + return LoopUnrollResult::Unmodified; } if (Header->hasAddressTaken()) { // The loop-rotate pass can be helpful to avoid this in many cases. DEBUG(dbgs() << " Won't unroll loop: address of header block is taken.\n"); - return false; + return LoopUnrollResult::Unmodified; } if (TripCount != 0) @@ -362,7 +373,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, // Don't enter the unroll code if there is nothing to do. 
if (TripCount == 0 && Count < 2 && PeelCount == 0) { DEBUG(dbgs() << "Won't unroll; almost nothing to do\n"); - return false; + return LoopUnrollResult::Unmodified; } assert(Count > 0); @@ -395,8 +406,19 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, "Did not expect runtime trip-count unrolling " "and peeling for the same loop"); - if (PeelCount) - peelLoop(L, PeelCount, LI, SE, DT, AC, PreserveLCSSA); + if (PeelCount) { + bool Peeled = peelLoop(L, PeelCount, LI, SE, DT, AC, PreserveLCSSA); + + // Successful peeling may result in a change in the loop preheader/trip + // counts. If we later unroll the loop, we want these to be updated. + if (Peeled) { + BasicBlock *ExitingBlock = L->getExitingBlock(); + assert(ExitingBlock && "Loop without exiting block?"); + Preheader = L->getLoopPreheader(); + TripCount = SE->getSmallConstantTripCount(L, ExitingBlock); + TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock); + } + } // Loops containing convergent instructions must have a count that divides // their TripMultiple. @@ -418,15 +440,15 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, if (RuntimeTripCount && TripMultiple % Count != 0 && !UnrollRuntimeLoopRemainder(L, Count, AllowExpensiveTripCount, - EpilogProfitability, LI, SE, DT, - PreserveLCSSA)) { + EpilogProfitability, UnrollRemainder, LI, SE, + DT, AC, PreserveLCSSA)) { if (Force) RuntimeTripCount = false; else { DEBUG( dbgs() << "Wont unroll; remainder loop could not be generated" "when assuming runtime trip count\n"); - return false; + return LoopUnrollResult::Unmodified; } } @@ -450,36 +472,53 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, // Report the unrolling decision. if (CompletelyUnroll) { DEBUG(dbgs() << "COMPLETELY UNROLLING loop %" << Header->getName() - << " with trip count " << TripCount << "!\n"); - ORE->emit(OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(), - L->getHeader()) - << "completely unrolled loop with " - << NV("UnrollCount", TripCount) << " iterations"); + << " with trip count " << TripCount << "!\n"); + if (ORE) + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "FullyUnrolled", L->getStartLoc(), + L->getHeader()) + << "completely unrolled loop with " + << NV("UnrollCount", TripCount) << " iterations"; + }); } else if (PeelCount) { DEBUG(dbgs() << "PEELING loop %" << Header->getName() << " with iteration count " << PeelCount << "!\n"); - ORE->emit(OptimizationRemark(DEBUG_TYPE, "Peeled", L->getStartLoc(), - L->getHeader()) - << " peeled loop by " << NV("PeelCount", PeelCount) - << " iterations"); + if (ORE) + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "Peeled", L->getStartLoc(), + L->getHeader()) + << " peeled loop by " << NV("PeelCount", PeelCount) + << " iterations"; + }); } else { - OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(), - L->getHeader()); - Diag << "unrolled loop by a factor of " << NV("UnrollCount", Count); + auto DiagBuilder = [&]() { + OptimizationRemark Diag(DEBUG_TYPE, "PartialUnrolled", L->getStartLoc(), + L->getHeader()); + return Diag << "unrolled loop by a factor of " + << NV("UnrollCount", Count); + }; DEBUG(dbgs() << "UNROLLING loop %" << Header->getName() << " by " << Count); if (TripMultiple == 0 || BreakoutTrip != TripMultiple) { DEBUG(dbgs() << " with a breakout at trip " << BreakoutTrip); - ORE->emit(Diag << " with a breakout at trip " - << NV("BreakoutTrip", BreakoutTrip)); + if (ORE) + ORE->emit([&]() { 
+ return DiagBuilder() << " with a breakout at trip " + << NV("BreakoutTrip", BreakoutTrip); + }); } else if (TripMultiple != 1) { DEBUG(dbgs() << " with " << TripMultiple << " trips per branch"); - ORE->emit(Diag << " with " << NV("TripMultiple", TripMultiple) - << " trips per branch"); + if (ORE) + ORE->emit([&]() { + return DiagBuilder() << " with " << NV("TripMultiple", TripMultiple) + << " trips per branch"; + }); } else if (RuntimeTripCount) { DEBUG(dbgs() << " with run-time trip count"); - ORE->emit(Diag << " with run-time trip count"); + if (ORE) + ORE->emit( + [&]() { return DiagBuilder() << " with run-time trip count"; }); } DEBUG(dbgs() << "!\n"); } @@ -523,8 +562,9 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, if (Header->getParent()->isDebugInfoForProfiling()) for (BasicBlock *BB : L->getBlocks()) for (Instruction &I : *BB) - if (const DILocation *DIL = I.getDebugLoc()) - I.setDebugLoc(DIL->cloneWithDuplicationFactor(Count)); + if (!isa<DbgInfoIntrinsic>(&I)) + if (const DILocation *DIL = I.getDebugLoc()) + I.setDebugLoc(DIL->cloneWithDuplicationFactor(Count)); for (unsigned It = 1; It != Count; ++It) { std::vector<BasicBlock*> NewBlocks; @@ -796,7 +836,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, Loop *OuterL = L->getParentLoop(); // Update LoopInfo if the loop is completely removed. if (CompletelyUnroll) - LI->markAsRemoved(L); + LI->erase(L); // After complete unrolling most of the blocks should be contained in OuterL. // However, some of them might happen to be out of OuterL (e.g. if they @@ -821,7 +861,7 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, if (NeedToFixLCSSA) { // LCSSA must be performed on the outermost affected loop. The unrolled // loop's last loop latch is guaranteed to be in the outermost loop - // after LoopInfo's been updated by markAsRemoved. + // after LoopInfo's been updated by LoopInfo::erase. Loop *LatchLoop = LI->getLoopFor(Latches.back()); Loop *FixLCSSALoop = OuterL; if (!FixLCSSALoop->contains(LatchLoop)) @@ -844,7 +884,8 @@ bool llvm::UnrollLoop(Loop *L, unsigned Count, unsigned TripCount, bool Force, } } - return true; + return CompletelyUnroll ? 
LoopUnrollResult::FullyUnrolled + : LoopUnrollResult::PartiallyUnrolled; } /// Given an llvm.loop loop id metadata node, returns the loop hint metadata diff --git a/lib/Transforms/Utils/LoopUnrollPeel.cpp b/lib/Transforms/Utils/LoopUnrollPeel.cpp index 5c21490793e7..4273ce0b6200 100644 --- a/lib/Transforms/Utils/LoopUnrollPeel.cpp +++ b/lib/Transforms/Utils/LoopUnrollPeel.cpp @@ -1,4 +1,4 @@ -//===-- UnrollLoopPeel.cpp - Loop peeling utilities -----------------------===// +//===- UnrollLoopPeel.cpp - Loop peeling utilities ------------------------===// // // The LLVM Compiler Infrastructure // @@ -13,29 +13,42 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" -#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Metadata.h" -#include "llvm/IR/Module.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/UnrollLoop.h" +#include "llvm/Transforms/Utils/ValueMapper.h" #include <algorithm> +#include <cassert> +#include <cstdint> +#include <limits> using namespace llvm; #define DEBUG_TYPE "loop-unroll" + STATISTIC(NumPeeled, "Number of loops peeled"); static cl::opt<unsigned> UnrollPeelMaxCount( @@ -49,7 +62,8 @@ static cl::opt<unsigned> UnrollForcePeelCount( // Designates that a Phi is estimated to become invariant after an "infinite" // number of loop iterations (i.e. only may become an invariant if the loop is // fully unrolled). -static const unsigned InfiniteIterationsToInvariance = UINT_MAX; +static const unsigned InfiniteIterationsToInvariance = + std::numeric_limits<unsigned>::max(); // Check whether we are capable of peeling this loop. static bool canPeel(Loop *L) { @@ -210,8 +224,6 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, DEBUG(dbgs() << "Max peel cost: " << UP.Threshold << "\n"); } } - - return; } /// \brief Update the branch weights of the latch of a peeled-off loop @@ -236,7 +248,6 @@ void llvm::computePeelCount(Loop *L, unsigned LoopSize, static void updateBranchWeights(BasicBlock *Header, BranchInst *LatchBR, unsigned IterNumber, unsigned AvgIters, uint64_t &PeeledHeaderWeight) { - // FIXME: Pick a more realistic distribution. 
// Currently the proportion of weight we assign to the fall-through // side of the branch drops linearly with the iteration number, and we use @@ -272,7 +283,6 @@ static void cloneLoopBlocks(Loop *L, unsigned IterNumber, BasicBlock *InsertTop, LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, ValueToValueMapTy &LVMap, DominatorTree *DT, LoopInfo *LI) { - BasicBlock *Header = L->getHeader(); BasicBlock *Latch = L->getLoopLatch(); BasicBlock *PreHeader = L->getLoopPreheader(); diff --git a/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/lib/Transforms/Utils/LoopUnrollRuntime.cpp index d43ce7abb7cd..efff06f79cb7 100644 --- a/lib/Transforms/Utils/LoopUnrollRuntime.cpp +++ b/lib/Transforms/Utils/LoopUnrollRuntime.cpp @@ -25,7 +25,6 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/LoopIterator.h" -#include "llvm/Analysis/LoopPass.h" #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/IR/BasicBlock.h" @@ -294,7 +293,8 @@ static void ConnectEpilog(Loop *L, Value *ModVal, BasicBlock *NewExit, /// Return the new cloned loop that is created when CreateRemainderLoop is true. static Loop * CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, - const bool UseEpilogRemainder, BasicBlock *InsertTop, + const bool UseEpilogRemainder, const bool UnrollRemainder, + BasicBlock *InsertTop, BasicBlock *InsertBot, BasicBlock *Preheader, std::vector<BasicBlock *> &NewBlocks, LoopBlocksDFS &LoopBlocks, ValueToValueMapTy &VMap, DominatorTree *DT, LoopInfo *LI) { @@ -393,35 +393,14 @@ CloneLoopBlocks(Loop *L, Value *NewIter, const bool CreateRemainderLoop, if (CreateRemainderLoop) { Loop *NewLoop = NewLoops[L]; assert(NewLoop && "L should have been cloned"); - // Add unroll disable metadata to disable future unrolling for this loop. - SmallVector<Metadata *, 4> MDs; - // Reserve first location for self reference to the LoopID metadata node. - MDs.push_back(nullptr); - MDNode *LoopID = NewLoop->getLoopID(); - if (LoopID) { - // First remove any existing loop unrolling metadata. - for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) { - bool IsUnrollMetadata = false; - MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i)); - if (MD) { - const MDString *S = dyn_cast<MDString>(MD->getOperand(0)); - IsUnrollMetadata = S && S->getString().startswith("llvm.loop.unroll."); - } - if (!IsUnrollMetadata) - MDs.push_back(LoopID->getOperand(i)); - } - } - LLVMContext &Context = NewLoop->getHeader()->getContext(); - SmallVector<Metadata *, 1> DisableOperands; - DisableOperands.push_back(MDString::get(Context, "llvm.loop.unroll.disable")); - MDNode *DisableNode = MDNode::get(Context, DisableOperands); - MDs.push_back(DisableNode); + // Only add loop metadata if the loop is not going to be completely + // unrolled. + if (UnrollRemainder) + return NewLoop; - MDNode *NewLoopID = MDNode::get(Context, MDs); - // Set operand 0 to refer to the loop id itself. - NewLoopID->replaceOperandWith(0, NewLoopID); - NewLoop->setLoopID(NewLoopID); + // Add unroll disable metadata to disable future unrolling for this loop. + NewLoop->setLoopAlreadyUnrolled(); return NewLoop; } else @@ -435,12 +414,9 @@ canSafelyUnrollMultiExitLoop(Loop *L, SmallVectorImpl<BasicBlock *> &OtherExits, BasicBlock *LatchExit, bool PreserveLCSSA, bool UseEpilogRemainder) { - // Support runtime unrolling for multiple exit blocks and multiple exiting - // blocks. 
- if (!UnrollRuntimeMultiExit) - return false; - // Even if runtime multi exit is enabled, we currently have some correctness - // constrains in unrolling a multi-exit loop. + // We currently have some correctness constrains in unrolling a multi-exit + // loop. Check for these below. + // We rely on LCSSA form being preserved when the exit blocks are transformed. if (!PreserveLCSSA) return false; @@ -470,7 +446,54 @@ canSafelyUnrollMultiExitLoop(Loop *L, SmallVectorImpl<BasicBlock *> &OtherExits, return true; } +/// Returns true if we can profitably unroll the multi-exit loop L. Currently, +/// we return true only if UnrollRuntimeMultiExit is set to true. +static bool canProfitablyUnrollMultiExitLoop( + Loop *L, SmallVectorImpl<BasicBlock *> &OtherExits, BasicBlock *LatchExit, + bool PreserveLCSSA, bool UseEpilogRemainder) { + +#if !defined(NDEBUG) + SmallVector<BasicBlock *, 8> OtherExitsDummyCheck; + assert(canSafelyUnrollMultiExitLoop(L, OtherExitsDummyCheck, LatchExit, + PreserveLCSSA, UseEpilogRemainder) && + "Should be safe to unroll before checking profitability!"); +#endif + + // Priority goes to UnrollRuntimeMultiExit if it's supplied. + if (UnrollRuntimeMultiExit.getNumOccurrences()) + return UnrollRuntimeMultiExit; + + // The main pain point with multi-exit loop unrolling is that once unrolled, + // we will not be able to merge all blocks into a straight line code. + // There are branches within the unrolled loop that go to the OtherExits. + // The second point is the increase in code size, but this is true + // irrespective of multiple exits. + + // Note: Both the heuristics below are coarse grained. We are essentially + // enabling unrolling of loops that have a single side exit other than the + // normal LatchExit (i.e. exiting into a deoptimize block). + // The heuristics considered are: + // 1. low number of branches in the unrolled version. + // 2. high predictability of these extra branches. + // We avoid unrolling loops that have more than two exiting blocks. This + // limits the total number of branches in the unrolled loop to be atmost + // the unroll factor (since one of the exiting blocks is the latch block). + SmallVector<BasicBlock*, 4> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + if (ExitingBlocks.size() > 2) + return false; + // The second heuristic is that L has one exit other than the latchexit and + // that exit is a deoptimize block. We know that deoptimize blocks are rarely + // taken, which also implies the branch leading to the deoptimize block is + // highly predictable. + return (OtherExits.size() == 1 && + OtherExits[0]->getTerminatingDeoptimizeCall()); + // TODO: These can be fine-tuned further to consider code size or deopt states + // that are captured by the deoptimize exit block. + // Also, we can extend this to support more cases, if we actually + // know of kinds of multiexit loops that would benefit from unrolling. +} /// Insert code in the prolog/epilog code when unrolling a loop with a /// run-time trip-count. @@ -513,10 +536,14 @@ canSafelyUnrollMultiExitLoop(Loop *L, SmallVectorImpl<BasicBlock *> &OtherExits, bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, bool AllowExpensiveTripCount, bool UseEpilogRemainder, + bool UnrollRemainder, LoopInfo *LI, ScalarEvolution *SE, - DominatorTree *DT, bool PreserveLCSSA) { + DominatorTree *DT, AssumptionCache *AC, + bool PreserveLCSSA) { DEBUG(dbgs() << "Trying runtime unrolling on Loop: \n"); DEBUG(L->dump()); + DEBUG(UseEpilogRemainder ? 
dbgs() << "Using epilog remainder.\n" : + dbgs() << "Using prolog remainder.\n"); // Make sure the loop is in canonical form. if (!L->isLoopSimplifyForm()) { @@ -538,8 +565,11 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, "one of the loop latch successors should be the exit block!"); // These are exit blocks other than the target of the latch exiting block. SmallVector<BasicBlock *, 4> OtherExits; - bool isMultiExitUnrollingEnabled = canSafelyUnrollMultiExitLoop( - L, OtherExits, LatchExit, PreserveLCSSA, UseEpilogRemainder); + bool isMultiExitUnrollingEnabled = + canSafelyUnrollMultiExitLoop(L, OtherExits, LatchExit, PreserveLCSSA, + UseEpilogRemainder) && + canProfitablyUnrollMultiExitLoop(L, OtherExits, LatchExit, PreserveLCSSA, + UseEpilogRemainder); // Support only single exit and exiting block unless multi-exit loop unrolling is enabled. if (!isMultiExitUnrollingEnabled && (!L->getExitingBlock() || OtherExits.size())) { @@ -724,7 +754,8 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, BasicBlock *InsertBot = UseEpilogRemainder ? LatchExit : PrologExit; BasicBlock *InsertTop = UseEpilogRemainder ? EpilogPreHeader : PrologPreHeader; Loop *remainderLoop = CloneLoopBlocks( - L, ModVal, CreateRemainderLoop, UseEpilogRemainder, InsertTop, InsertBot, + L, ModVal, CreateRemainderLoop, UseEpilogRemainder, UnrollRemainder, + InsertTop, InsertBot, NewPreHeader, NewBlocks, LoopBlocks, VMap, DT, LI); // Insert the cloned blocks into the function. @@ -753,11 +784,15 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, // Add the incoming values from the remainder code to the end of the phi // node. for (unsigned i =0; i < oldNumOperands; i++){ - Value *newVal = VMap[Phi->getIncomingValue(i)]; + Value *newVal = VMap.lookup(Phi->getIncomingValue(i)); // newVal can be a constant or derived from values outside the loop, and - // hence need not have a VMap value. - if (!newVal) + // hence need not have a VMap value. Also, since lookup already generated + // a default "null" VMap entry for this value, we need to populate that + // VMap entry correctly, with the mapped entry being itself. + if (!newVal) { newVal = Phi->getIncomingValue(i); + VMap[Phi->getIncomingValue(i)] = Phi->getIncomingValue(i); + } Phi->addIncoming(newVal, cast<BasicBlock>(VMap[Phi->getIncomingBlock(i)])); } @@ -868,6 +903,16 @@ bool llvm::UnrollRuntimeLoopRemainder(Loop *L, unsigned Count, formDedicatedExitBlocks(remainderLoop, DT, LI, PreserveLCSSA); } + if (remainderLoop && UnrollRemainder) { + DEBUG(dbgs() << "Unrolling remainder loop\n"); + UnrollLoop(remainderLoop, /*Count*/ Count - 1, /*TripCount*/ Count - 1, + /*Force*/ false, /*AllowRuntime*/ false, + /*AllowExpensiveTripCount*/ false, /*PreserveCondBr*/ true, + /*PreserveOnlyFirst*/ false, /*TripMultiple*/ 1, + /*PeelCount*/ 0, /*UnrollRemainder*/ false, LI, SE, DT, AC, + /*ORE*/ nullptr, PreserveLCSSA); + } + NumRuntimeUnrolled++; return true; } diff --git a/lib/Transforms/Utils/LoopUtils.cpp b/lib/Transforms/Utils/LoopUtils.cpp index 3c522786641a..c3fa05a11a24 100644 --- a/lib/Transforms/Utils/LoopUtils.cpp +++ b/lib/Transforms/Utils/LoopUtils.cpp @@ -432,7 +432,7 @@ RecurrenceDescriptor::isRecurrenceInstr(Instruction *I, RecurrenceKind Kind, InstDesc &Prev, bool HasFunNoNaNAttr) { bool FP = I->getType()->isFloatingPointTy(); Instruction *UAI = Prev.getUnsafeAlgebraInst(); - if (!UAI && FP && !I->hasUnsafeAlgebra()) + if (!UAI && FP && !I->isFast()) UAI = I; // Found an unsafe (unvectorizable) algebra instruction. 
switch (I->getOpcode()) { @@ -565,7 +565,8 @@ bool RecurrenceDescriptor::isFirstOrderRecurrence( auto *I = Phi->user_back(); if (I->isCast() && (I->getParent() == Phi->getParent()) && I->hasOneUse() && DT->dominates(Previous, I->user_back())) { - SinkAfter[I] = Previous; + if (!DT->dominates(Previous, I)) // Otherwise we're good w/o sinking. + SinkAfter[I] = Previous; return true; } } @@ -659,11 +660,11 @@ Value *RecurrenceDescriptor::createMinMaxOp(IRBuilder<> &Builder, break; } - // We only match FP sequences with unsafe algebra, so we can unconditionally + // We only match FP sequences that are 'fast', so we can unconditionally // set it on any generated instructions. IRBuilder<>::FastMathFlagGuard FMFG(Builder); FastMathFlags FMF; - FMF.setUnsafeAlgebra(); + FMF.setFast(); Builder.setFastMathFlags(FMF); Value *Cmp; @@ -677,7 +678,8 @@ Value *RecurrenceDescriptor::createMinMaxOp(IRBuilder<> &Builder, } InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K, - const SCEV *Step, BinaryOperator *BOp) + const SCEV *Step, BinaryOperator *BOp, + SmallVectorImpl<Instruction *> *Casts) : StartValue(Start), IK(K), Step(Step), InductionBinOp(BOp) { assert(IK != IK_NoInduction && "Not an induction"); @@ -704,6 +706,12 @@ InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K, (InductionBinOp->getOpcode() == Instruction::FAdd || InductionBinOp->getOpcode() == Instruction::FSub))) && "Binary opcode should be specified for FP induction"); + + if (Casts) { + for (auto &Inst : *Casts) { + RedundantCasts.push_back(Inst); + } + } } int InductionDescriptor::getConsecutiveDirection() const { @@ -767,7 +775,7 @@ Value *InductionDescriptor::transform(IRBuilder<> &B, Value *Index, // Floating point operations had to be 'fast' to enable the induction. FastMathFlags Flags; - Flags.setUnsafeAlgebra(); + Flags.setFast(); Value *MulExp = B.CreateFMul(StepValue, Index); if (isa<Instruction>(MulExp)) @@ -807,7 +815,7 @@ bool InductionDescriptor::isFPInductionPHI(PHINode *Phi, const Loop *TheLoop, StartValue = Phi->getIncomingValue(1); } else { assert(TheLoop->contains(Phi->getIncomingBlock(1)) && - "Unexpected Phi node in the loop"); + "Unexpected Phi node in the loop"); BEValue = Phi->getIncomingValue(1); StartValue = Phi->getIncomingValue(0); } @@ -840,6 +848,110 @@ bool InductionDescriptor::isFPInductionPHI(PHINode *Phi, const Loop *TheLoop, return true; } +/// This function is called when we suspect that the update-chain of a phi node +/// (whose symbolic SCEV expression sin \p PhiScev) contains redundant casts, +/// that can be ignored. (This can happen when the PSCEV rewriter adds a runtime +/// predicate P under which the SCEV expression for the phi can be the +/// AddRecurrence \p AR; See createAddRecFromPHIWithCast). We want to find the +/// cast instructions that are involved in the update-chain of this induction. +/// A caller that adds the required runtime predicate can be free to drop these +/// cast instructions, and compute the phi using \p AR (instead of some scev +/// expression with casts). 
+/// +/// For example, without a predicate the scev expression can take the following +/// form: +/// (Ext ix (Trunc iy ( Start + i*Step ) to ix) to iy) +/// +/// It corresponds to the following IR sequence: +/// %for.body: +/// %x = phi i64 [ 0, %ph ], [ %add, %for.body ] +/// %casted_phi = "ExtTrunc i64 %x" +/// %add = add i64 %casted_phi, %step +/// +/// where %x is given in \p PN, +/// PSE.getSCEV(%x) is equal to PSE.getSCEV(%casted_phi) under a predicate, +/// and the IR sequence that "ExtTrunc i64 %x" represents can take one of +/// several forms, for example, such as: +/// ExtTrunc1: %casted_phi = and %x, 2^n-1 +/// or: +/// ExtTrunc2: %t = shl %x, m +/// %casted_phi = ashr %t, m +/// +/// If we are able to find such sequence, we return the instructions +/// we found, namely %casted_phi and the instructions on its use-def chain up +/// to the phi (not including the phi). +bool getCastsForInductionPHI( + PredicatedScalarEvolution &PSE, const SCEVUnknown *PhiScev, + const SCEVAddRecExpr *AR, SmallVectorImpl<Instruction *> &CastInsts) { + + assert(CastInsts.empty() && "CastInsts is expected to be empty."); + auto *PN = cast<PHINode>(PhiScev->getValue()); + assert(PSE.getSCEV(PN) == AR && "Unexpected phi node SCEV expression"); + const Loop *L = AR->getLoop(); + + // Find any cast instructions that participate in the def-use chain of + // PhiScev in the loop. + // FORNOW/TODO: We currently expect the def-use chain to include only + // two-operand instructions, where one of the operands is an invariant. + // createAddRecFromPHIWithCasts() currently does not support anything more + // involved than that, so we keep the search simple. This can be + // extended/generalized as needed. + + auto getDef = [&](const Value *Val) -> Value * { + const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Val); + if (!BinOp) + return nullptr; + Value *Op0 = BinOp->getOperand(0); + Value *Op1 = BinOp->getOperand(1); + Value *Def = nullptr; + if (L->isLoopInvariant(Op0)) + Def = Op1; + else if (L->isLoopInvariant(Op1)) + Def = Op0; + return Def; + }; + + // Look for the instruction that defines the induction via the + // loop backedge. + BasicBlock *Latch = L->getLoopLatch(); + if (!Latch) + return false; + Value *Val = PN->getIncomingValueForBlock(Latch); + if (!Val) + return false; + + // Follow the def-use chain until the induction phi is reached. + // If on the way we encounter a Value that has the same SCEV Expr as the + // phi node, we can consider the instructions we visit from that point + // as part of the cast-sequence that can be ignored. + bool InCastSequence = false; + auto *Inst = dyn_cast<Instruction>(Val); + while (Val != PN) { + // If we encountered a phi node other than PN, or if we left the loop, + // we bail out. + if (!Inst || !L->contains(Inst)) { + return false; + } + auto *AddRec = dyn_cast<SCEVAddRecExpr>(PSE.getSCEV(Val)); + if (AddRec && PSE.areAddRecsEqualWithPreds(AddRec, AR)) + InCastSequence = true; + if (InCastSequence) { + // Only the last instruction in the cast sequence is expected to have + // uses outside the induction def-use chain. 
+ if (!CastInsts.empty()) + if (!Inst->hasOneUse()) + return false; + CastInsts.push_back(Inst); + } + Val = getDef(Val); + if (!Val) + return false; + Inst = dyn_cast<Instruction>(Val); + } + + return InCastSequence; +} + bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop, PredicatedScalarEvolution &PSE, InductionDescriptor &D, @@ -869,13 +981,26 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop, return false; } + // Record any Cast instructions that participate in the induction update + const auto *SymbolicPhi = dyn_cast<SCEVUnknown>(PhiScev); + // If we started from an UnknownSCEV, and managed to build an addRecurrence + // only after enabling Assume with PSCEV, this means we may have encountered + // cast instructions that required adding a runtime check in order to + // guarantee the correctness of the AddRecurence respresentation of the + // induction. + if (PhiScev != AR && SymbolicPhi) { + SmallVector<Instruction *, 2> Casts; + if (getCastsForInductionPHI(PSE, SymbolicPhi, AR, Casts)) + return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR, &Casts); + } + return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR); } -bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop, - ScalarEvolution *SE, - InductionDescriptor &D, - const SCEV *Expr) { +bool InductionDescriptor::isInductionPHI( + PHINode *Phi, const Loop *TheLoop, ScalarEvolution *SE, + InductionDescriptor &D, const SCEV *Expr, + SmallVectorImpl<Instruction *> *CastsToIgnore) { Type *PhiTy = Phi->getType(); // We only handle integer and pointer inductions variables. if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy()) @@ -894,7 +1019,7 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop, // FIXME: We should treat this as a uniform. Unfortunately, we // don't currently know how to handled uniform PHIs. DEBUG(dbgs() << "LV: PHI is a recurrence with respect to an outer loop.\n"); - return false; + return false; } Value *StartValue = @@ -907,7 +1032,8 @@ bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop, return false; if (PhiTy->isIntegerTy()) { - D = InductionDescriptor(StartValue, IK_IntInduction, Step); + D = InductionDescriptor(StartValue, IK_IntInduction, Step, /*BOp=*/ nullptr, + CastsToIgnore); return true; } @@ -1115,6 +1241,149 @@ Optional<const MDOperand *> llvm::findStringMetadataForLoop(Loop *TheLoop, return None; } +/// Does a BFS from a given node to all of its children inside a given loop. +/// The returned vector of nodes includes the starting point. +SmallVector<DomTreeNode *, 16> +llvm::collectChildrenInLoop(DomTreeNode *N, const Loop *CurLoop) { + SmallVector<DomTreeNode *, 16> Worklist; + auto AddRegionToWorklist = [&](DomTreeNode *DTN) { + // Only include subregions in the top level loop. + BasicBlock *BB = DTN->getBlock(); + if (CurLoop->contains(BB)) + Worklist.push_back(DTN); + }; + + AddRegionToWorklist(N); + + for (size_t I = 0; I < Worklist.size(); I++) + for (DomTreeNode *Child : Worklist[I]->getChildren()) + AddRegionToWorklist(Child); + + return Worklist; +} + +void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT = nullptr, + ScalarEvolution *SE = nullptr, + LoopInfo *LI = nullptr) { + assert((!DT || L->isLCSSAForm(*DT)) && "Expected LCSSA!"); + auto *Preheader = L->getLoopPreheader(); + assert(Preheader && "Preheader should exist!"); + + // Now that we know the removal is safe, remove the loop by changing the + // branch from the preheader to go to the single exit block. 
+ // + // Because we're deleting a large chunk of code at once, the sequence in which + // we remove things is very important to avoid invalidation issues. + + // Tell ScalarEvolution that the loop is deleted. Do this before + // deleting the loop so that ScalarEvolution can look at the loop + // to determine what it needs to clean up. + if (SE) + SE->forgetLoop(L); + + auto *ExitBlock = L->getUniqueExitBlock(); + assert(ExitBlock && "Should have a unique exit block!"); + assert(L->hasDedicatedExits() && "Loop should have dedicated exits!"); + + auto *OldBr = dyn_cast<BranchInst>(Preheader->getTerminator()); + assert(OldBr && "Preheader must end with a branch"); + assert(OldBr->isUnconditional() && "Preheader must have a single successor"); + // Connect the preheader to the exit block. Keep the old edge to the header + // around to perform the dominator tree update in two separate steps + // -- #1 insertion of the edge preheader -> exit and #2 deletion of the edge + // preheader -> header. + // + // + // 0. Preheader 1. Preheader 2. Preheader + // | | | | + // V | V | + // Header <--\ | Header <--\ | Header <--\ + // | | | | | | | | | | | + // | V | | | V | | | V | + // | Body --/ | | Body --/ | | Body --/ + // V V V V V + // Exit Exit Exit + // + // By doing this is two separate steps we can perform the dominator tree + // update without using the batch update API. + // + // Even when the loop is never executed, we cannot remove the edge from the + // source block to the exit block. Consider the case where the unexecuted loop + // branches back to an outer loop. If we deleted the loop and removed the edge + // coming to this inner loop, this will break the outer loop structure (by + // deleting the backedge of the outer loop). If the outer loop is indeed a + // non-loop, it will be deleted in a future iteration of loop deletion pass. + IRBuilder<> Builder(OldBr); + Builder.CreateCondBr(Builder.getFalse(), L->getHeader(), ExitBlock); + // Remove the old branch. The conditional branch becomes a new terminator. + OldBr->eraseFromParent(); + + // Rewrite phis in the exit block to get their inputs from the Preheader + // instead of the exiting block. + BasicBlock::iterator BI = ExitBlock->begin(); + while (PHINode *P = dyn_cast<PHINode>(BI)) { + // Set the zero'th element of Phi to be from the preheader and remove all + // other incoming values. Given the loop has dedicated exits, all other + // incoming values must be from the exiting blocks. + int PredIndex = 0; + P->setIncomingBlock(PredIndex, Preheader); + // Removes all incoming values from all other exiting blocks (including + // duplicate values from an exiting block). + // Nuke all entries except the zero'th entry which is the preheader entry. + // NOTE! We need to remove Incoming Values in the reverse order as done + // below, to keep the indices valid for deletion (removeIncomingValues + // updates getNumIncomingValues and shifts all values down into the operand + // being deleted). + for (unsigned i = 0, e = P->getNumIncomingValues() - 1; i != e; ++i) + P->removeIncomingValue(e - i, false); + + assert((P->getNumIncomingValues() == 1 && + P->getIncomingBlock(PredIndex) == Preheader) && + "Should have exactly one value and that's from the preheader!"); + ++BI; + } + + // Disconnect the loop body by branching directly to its exit. + Builder.SetInsertPoint(Preheader->getTerminator()); + Builder.CreateBr(ExitBlock); + // Remove the old branch. 
+ Preheader->getTerminator()->eraseFromParent(); + + if (DT) { + // Update the dominator tree by informing it about the new edge from the + // preheader to the exit. + DT->insertEdge(Preheader, ExitBlock); + // Inform the dominator tree about the removed edge. + DT->deleteEdge(Preheader, L->getHeader()); + } + + // Remove the block from the reference counting scheme, so that we can + // delete it freely later. + for (auto *Block : L->blocks()) + Block->dropAllReferences(); + + if (LI) { + // Erase the instructions and the blocks without having to worry + // about ordering because we already dropped the references. + // NOTE: This iteration is safe because erasing the block does not remove + // its entry from the loop's block list. We do that in the next section. + for (Loop::block_iterator LpI = L->block_begin(), LpE = L->block_end(); + LpI != LpE; ++LpI) + (*LpI)->eraseFromParent(); + + // Finally, the blocks from loopinfo. This has to happen late because + // otherwise our loop iterators won't work. + + SmallPtrSet<BasicBlock *, 8> blocks; + blocks.insert(L->block_begin(), L->block_end()); + for (BasicBlock *BB : blocks) + LI->removeBlock(BB); + + // The last step is to update LoopInfo now that we've eliminated this loop. + LI->erase(L); + } +} + /// Returns true if the instruction in a loop is guaranteed to execute at least /// once. bool llvm::isGuaranteedToExecute(const Instruction &Inst, @@ -1194,7 +1463,7 @@ Optional<unsigned> llvm::getLoopEstimatedTripCount(Loop *L) { static Value *addFastMathFlag(Value *V) { if (isa<FPMathOperator>(V)) { FastMathFlags Flags; - Flags.setUnsafeAlgebra(); + Flags.setFast(); cast<Instruction>(V)->setFastMathFlags(Flags); } return V; @@ -1256,8 +1525,8 @@ Value *llvm::createSimpleTargetReduction( using RD = RecurrenceDescriptor; RD::MinMaxRecurrenceKind MinMaxKind = RD::MRK_Invalid; // TODO: Support creating ordered reductions. - FastMathFlags FMFUnsafe; - FMFUnsafe.setUnsafeAlgebra(); + FastMathFlags FMFFast; + FMFFast.setFast(); switch (Opcode) { case Instruction::Add: @@ -1278,14 +1547,14 @@ Value *llvm::createSimpleTargetReduction( case Instruction::FAdd: BuildFunc = [&]() { auto Rdx = Builder.CreateFAddReduce(ScalarUdf, Src); - cast<CallInst>(Rdx)->setFastMathFlags(FMFUnsafe); + cast<CallInst>(Rdx)->setFastMathFlags(FMFFast); return Rdx; }; break; case Instruction::FMul: BuildFunc = [&]() { auto Rdx = Builder.CreateFMulReduce(ScalarUdf, Src); - cast<CallInst>(Rdx)->setFastMathFlags(FMFUnsafe); + cast<CallInst>(Rdx)->setFastMathFlags(FMFFast); return Rdx; }; break; @@ -1321,55 +1590,39 @@ Value *llvm::createSimpleTargetReduction( } /// Create a vector reduction using a given recurrence descriptor. -Value *llvm::createTargetReduction(IRBuilder<> &Builder, +Value *llvm::createTargetReduction(IRBuilder<> &B, const TargetTransformInfo *TTI, RecurrenceDescriptor &Desc, Value *Src, bool NoNaN) { // TODO: Support in-order reductions based on the recurrence descriptor. 
- RecurrenceDescriptor::RecurrenceKind RecKind = Desc.getRecurrenceKind(); + using RD = RecurrenceDescriptor; + RD::RecurrenceKind RecKind = Desc.getRecurrenceKind(); TargetTransformInfo::ReductionFlags Flags; Flags.NoNaN = NoNaN; - auto getSimpleRdx = [&](unsigned Opc) { - return createSimpleTargetReduction(Builder, TTI, Opc, Src, Flags); - }; switch (RecKind) { - case RecurrenceDescriptor::RK_FloatAdd: - return getSimpleRdx(Instruction::FAdd); - case RecurrenceDescriptor::RK_FloatMult: - return getSimpleRdx(Instruction::FMul); - case RecurrenceDescriptor::RK_IntegerAdd: - return getSimpleRdx(Instruction::Add); - case RecurrenceDescriptor::RK_IntegerMult: - return getSimpleRdx(Instruction::Mul); - case RecurrenceDescriptor::RK_IntegerAnd: - return getSimpleRdx(Instruction::And); - case RecurrenceDescriptor::RK_IntegerOr: - return getSimpleRdx(Instruction::Or); - case RecurrenceDescriptor::RK_IntegerXor: - return getSimpleRdx(Instruction::Xor); - case RecurrenceDescriptor::RK_IntegerMinMax: { - switch (Desc.getMinMaxRecurrenceKind()) { - case RecurrenceDescriptor::MRK_SIntMax: - Flags.IsSigned = true; - Flags.IsMaxOp = true; - break; - case RecurrenceDescriptor::MRK_UIntMax: - Flags.IsMaxOp = true; - break; - case RecurrenceDescriptor::MRK_SIntMin: - Flags.IsSigned = true; - break; - case RecurrenceDescriptor::MRK_UIntMin: - break; - default: - llvm_unreachable("Unhandled MRK"); - } - return getSimpleRdx(Instruction::ICmp); + case RD::RK_FloatAdd: + return createSimpleTargetReduction(B, TTI, Instruction::FAdd, Src, Flags); + case RD::RK_FloatMult: + return createSimpleTargetReduction(B, TTI, Instruction::FMul, Src, Flags); + case RD::RK_IntegerAdd: + return createSimpleTargetReduction(B, TTI, Instruction::Add, Src, Flags); + case RD::RK_IntegerMult: + return createSimpleTargetReduction(B, TTI, Instruction::Mul, Src, Flags); + case RD::RK_IntegerAnd: + return createSimpleTargetReduction(B, TTI, Instruction::And, Src, Flags); + case RD::RK_IntegerOr: + return createSimpleTargetReduction(B, TTI, Instruction::Or, Src, Flags); + case RD::RK_IntegerXor: + return createSimpleTargetReduction(B, TTI, Instruction::Xor, Src, Flags); + case RD::RK_IntegerMinMax: { + RD::MinMaxRecurrenceKind MMKind = Desc.getMinMaxRecurrenceKind(); + Flags.IsMaxOp = (MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_UIntMax); + Flags.IsSigned = (MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_SIntMin); + return createSimpleTargetReduction(B, TTI, Instruction::ICmp, Src, Flags); } - case RecurrenceDescriptor::RK_FloatMinMax: { - Flags.IsMaxOp = - Desc.getMinMaxRecurrenceKind() == RecurrenceDescriptor::MRK_FloatMax; - return getSimpleRdx(Instruction::FCmp); + case RD::RK_FloatMinMax: { + Flags.IsMaxOp = Desc.getMinMaxRecurrenceKind() == RD::MRK_FloatMax; + return createSimpleTargetReduction(B, TTI, Instruction::FCmp, Src, Flags); } default: llvm_unreachable("Unhandled RecKind"); diff --git a/lib/Transforms/Utils/LowerMemIntrinsics.cpp b/lib/Transforms/Utils/LowerMemIntrinsics.cpp index 900450b40061..57dc225e9dab 100644 --- a/lib/Transforms/Utils/LowerMemIntrinsics.cpp +++ b/lib/Transforms/Utils/LowerMemIntrinsics.cpp @@ -168,13 +168,14 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore, IntegerType *ILengthType = dyn_cast<IntegerType>(CopyLenType); assert(ILengthType && "expected size argument to memcpy to be an integer type!"); + Type *Int8Type = Type::getInt8Ty(Ctx); + bool LoopOpIsInt8 = LoopOpType == Int8Type; ConstantInt *CILoopOpSize = ConstantInt::get(ILengthType, LoopOpSize); - Value 
*RuntimeLoopCount = PLBuilder.CreateUDiv(CopyLen, CILoopOpSize); - Value *RuntimeResidual = PLBuilder.CreateURem(CopyLen, CILoopOpSize); - Value *RuntimeBytesCopied = PLBuilder.CreateSub(CopyLen, RuntimeResidual); - + Value *RuntimeLoopCount = LoopOpIsInt8 ? + CopyLen : + PLBuilder.CreateUDiv(CopyLen, CILoopOpSize); BasicBlock *LoopBB = - BasicBlock::Create(Ctx, "loop-memcpy-expansion", ParentFunc, nullptr); + BasicBlock::Create(Ctx, "loop-memcpy-expansion", ParentFunc, PostLoopBB); IRBuilder<> LoopBuilder(LoopBB); PHINode *LoopIndex = LoopBuilder.CreatePHI(CopyLenType, 2, "loop-index"); @@ -189,11 +190,15 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore, LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(CopyLenType, 1U)); LoopIndex->addIncoming(NewIndex, LoopBB); - Type *Int8Type = Type::getInt8Ty(Ctx); - if (LoopOpType != Int8Type) { + if (!LoopOpIsInt8) { + // Add in the + Value *RuntimeResidual = PLBuilder.CreateURem(CopyLen, CILoopOpSize); + Value *RuntimeBytesCopied = PLBuilder.CreateSub(CopyLen, RuntimeResidual); + // Loop body for the residual copy. BasicBlock *ResLoopBB = BasicBlock::Create(Ctx, "loop-memcpy-residual", - PreLoopBB->getParent(), nullptr); + PreLoopBB->getParent(), + PostLoopBB); // Residual loop header. BasicBlock *ResHeaderBB = BasicBlock::Create( Ctx, "loop-memcpy-residual-header", PreLoopBB->getParent(), nullptr); @@ -258,61 +263,6 @@ void llvm::createMemCpyLoopUnknownSize(Instruction *InsertBefore, } } -void llvm::createMemCpyLoop(Instruction *InsertBefore, - Value *SrcAddr, Value *DstAddr, Value *CopyLen, - unsigned SrcAlign, unsigned DestAlign, - bool SrcIsVolatile, bool DstIsVolatile) { - Type *TypeOfCopyLen = CopyLen->getType(); - - BasicBlock *OrigBB = InsertBefore->getParent(); - Function *F = OrigBB->getParent(); - BasicBlock *NewBB = - InsertBefore->getParent()->splitBasicBlock(InsertBefore, "split"); - BasicBlock *LoopBB = BasicBlock::Create(F->getContext(), "loadstoreloop", - F, NewBB); - - IRBuilder<> Builder(OrigBB->getTerminator()); - - // SrcAddr and DstAddr are expected to be pointer types, - // so no check is made here. - unsigned SrcAS = cast<PointerType>(SrcAddr->getType())->getAddressSpace(); - unsigned DstAS = cast<PointerType>(DstAddr->getType())->getAddressSpace(); - - // Cast pointers to (char *) - SrcAddr = Builder.CreateBitCast(SrcAddr, Builder.getInt8PtrTy(SrcAS)); - DstAddr = Builder.CreateBitCast(DstAddr, Builder.getInt8PtrTy(DstAS)); - - Builder.CreateCondBr( - Builder.CreateICmpEQ(ConstantInt::get(TypeOfCopyLen, 0), CopyLen), NewBB, - LoopBB); - OrigBB->getTerminator()->eraseFromParent(); - - IRBuilder<> LoopBuilder(LoopBB); - PHINode *LoopIndex = LoopBuilder.CreatePHI(TypeOfCopyLen, 0); - LoopIndex->addIncoming(ConstantInt::get(TypeOfCopyLen, 0), OrigBB); - - // load from SrcAddr+LoopIndex - // TODO: we can leverage the align parameter of llvm.memcpy for more efficient - // word-sized loads and stores. 
- Value *Element = - LoopBuilder.CreateLoad(LoopBuilder.CreateInBoundsGEP( - LoopBuilder.getInt8Ty(), SrcAddr, LoopIndex), - SrcIsVolatile); - // store at DstAddr+LoopIndex - LoopBuilder.CreateStore(Element, - LoopBuilder.CreateInBoundsGEP(LoopBuilder.getInt8Ty(), - DstAddr, LoopIndex), - DstIsVolatile); - - // The value for LoopIndex coming from backedge is (LoopIndex + 1) - Value *NewIndex = - LoopBuilder.CreateAdd(LoopIndex, ConstantInt::get(TypeOfCopyLen, 1)); - LoopIndex->addIncoming(NewIndex, LoopBB); - - LoopBuilder.CreateCondBr(LoopBuilder.CreateICmpULT(NewIndex, CopyLen), LoopBB, - NewBB); -} - // Lower memmove to IR. memmove is required to correctly copy overlapping memory // regions; therefore, it has to check the relative positions of the source and // destination pointers and choose the copy direction accordingly. @@ -454,38 +404,26 @@ static void createMemSetLoop(Instruction *InsertBefore, void llvm::expandMemCpyAsLoop(MemCpyInst *Memcpy, const TargetTransformInfo &TTI) { - // Original implementation - if (!TTI.useWideIRMemcpyLoopLowering()) { - createMemCpyLoop(/* InsertBefore */ Memcpy, - /* SrcAddr */ Memcpy->getRawSource(), - /* DstAddr */ Memcpy->getRawDest(), - /* CopyLen */ Memcpy->getLength(), - /* SrcAlign */ Memcpy->getAlignment(), - /* DestAlign */ Memcpy->getAlignment(), - /* SrcIsVolatile */ Memcpy->isVolatile(), - /* DstIsVolatile */ Memcpy->isVolatile()); + if (ConstantInt *CI = dyn_cast<ConstantInt>(Memcpy->getLength())) { + createMemCpyLoopKnownSize(/* InsertBefore */ Memcpy, + /* SrcAddr */ Memcpy->getRawSource(), + /* DstAddr */ Memcpy->getRawDest(), + /* CopyLen */ CI, + /* SrcAlign */ Memcpy->getAlignment(), + /* DestAlign */ Memcpy->getAlignment(), + /* SrcIsVolatile */ Memcpy->isVolatile(), + /* DstIsVolatile */ Memcpy->isVolatile(), + /* TargetTransformInfo */ TTI); } else { - if (ConstantInt *CI = dyn_cast<ConstantInt>(Memcpy->getLength())) { - createMemCpyLoopKnownSize(/* InsertBefore */ Memcpy, + createMemCpyLoopUnknownSize(/* InsertBefore */ Memcpy, /* SrcAddr */ Memcpy->getRawSource(), /* DstAddr */ Memcpy->getRawDest(), - /* CopyLen */ CI, + /* CopyLen */ Memcpy->getLength(), /* SrcAlign */ Memcpy->getAlignment(), /* DestAlign */ Memcpy->getAlignment(), /* SrcIsVolatile */ Memcpy->isVolatile(), /* DstIsVolatile */ Memcpy->isVolatile(), - /* TargetTransformInfo */ TTI); - } else { - createMemCpyLoopUnknownSize(/* InsertBefore */ Memcpy, - /* SrcAddr */ Memcpy->getRawSource(), - /* DstAddr */ Memcpy->getRawDest(), - /* CopyLen */ Memcpy->getLength(), - /* SrcAlign */ Memcpy->getAlignment(), - /* DestAlign */ Memcpy->getAlignment(), - /* SrcIsVolatile */ Memcpy->isVolatile(), - /* DstIsVolatile */ Memcpy->isVolatile(), - /* TargetTransfomrInfo */ TTI); - } + /* TargetTransfomrInfo */ TTI); } } diff --git a/lib/Transforms/Utils/LowerSwitch.cpp b/lib/Transforms/Utils/LowerSwitch.cpp index 890afbc46e63..344cb35df986 100644 --- a/lib/Transforms/Utils/LowerSwitch.cpp +++ b/lib/Transforms/Utils/LowerSwitch.cpp @@ -13,46 +13,65 @@ // //===----------------------------------------------------------------------===// +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include 
"llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" #include <algorithm> +#include <cassert> +#include <cstdint> +#include <iterator> +#include <limits> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "lower-switch" namespace { + struct IntRange { int64_t Low, High; }; - // Return true iff R is covered by Ranges. - static bool IsInRanges(const IntRange &R, - const std::vector<IntRange> &Ranges) { - // Note: Ranges must be sorted, non-overlapping and non-adjacent. - - // Find the first range whose High field is >= R.High, - // then check if the Low field is <= R.Low. If so, we - // have a Range that covers R. - auto I = std::lower_bound( - Ranges.begin(), Ranges.end(), R, - [](const IntRange &A, const IntRange &B) { return A.High < B.High; }); - return I != Ranges.end() && I->Low <= R.Low; - } + +} // end anonymous namespace + +// Return true iff R is covered by Ranges. +static bool IsInRanges(const IntRange &R, + const std::vector<IntRange> &Ranges) { + // Note: Ranges must be sorted, non-overlapping and non-adjacent. + + // Find the first range whose High field is >= R.High, + // then check if the Low field is <= R.Low. If so, we + // have a Range that covers R. + auto I = std::lower_bound( + Ranges.begin(), Ranges.end(), R, + [](const IntRange &A, const IntRange &B) { return A.High < B.High; }); + return I != Ranges.end() && I->Low <= R.Low; +} + +namespace { /// Replace all SwitchInst instructions with chained branch instructions. class LowerSwitch : public FunctionPass { public: - static char ID; // Pass identification, replacement for typeid + // Pass identification, replacement for typeid + static char ID; + LowerSwitch() : FunctionPass(ID) { initializeLowerSwitchPass(*PassRegistry::getPassRegistry()); } @@ -68,8 +87,9 @@ namespace { : Low(low), High(high), BB(bb) {} }; - typedef std::vector<CaseRange> CaseVector; - typedef std::vector<CaseRange>::iterator CaseItr; + using CaseVector = std::vector<CaseRange>; + using CaseItr = std::vector<CaseRange>::iterator; + private: void processSwitchInst(SwitchInst *SI, SmallPtrSetImpl<BasicBlock*> &DeleteList); @@ -86,22 +106,24 @@ namespace { /// The comparison function for sorting the switch case values in the vector. /// WARNING: Case ranges should be disjoint! struct CaseCmp { - bool operator () (const LowerSwitch::CaseRange& C1, - const LowerSwitch::CaseRange& C2) { - + bool operator()(const LowerSwitch::CaseRange& C1, + const LowerSwitch::CaseRange& C2) { const ConstantInt* CI1 = cast<const ConstantInt>(C1.Low); const ConstantInt* CI2 = cast<const ConstantInt>(C2.High); return CI1->getValue().slt(CI2->getValue()); } }; -} + +} // end anonymous namespace char LowerSwitch::ID = 0; -INITIALIZE_PASS(LowerSwitch, "lowerswitch", - "Lower SwitchInst's to branches", false, false) // Publicly exposed interface to pass... char &llvm::LowerSwitchID = LowerSwitch::ID; + +INITIALIZE_PASS(LowerSwitch, "lowerswitch", + "Lower SwitchInst's to branches", false, false) + // createLowerSwitchPass - Interface to this file... 
FunctionPass *llvm::createLowerSwitchPass() { return new LowerSwitch(); @@ -136,6 +158,7 @@ bool LowerSwitch::runOnFunction(Function &F) { static raw_ostream& operator<<(raw_ostream &O, const LowerSwitch::CaseVector &C) LLVM_ATTRIBUTE_USED; + static raw_ostream& operator<<(raw_ostream &O, const LowerSwitch::CaseVector &C) { O << "["; @@ -186,7 +209,7 @@ static void fixPhis(BasicBlock *SuccBB, BasicBlock *OrigBB, BasicBlock *NewBB, } // Remove incoming values in the reverse order to prevent invalidating // *successive* index. - for (unsigned III : reverse(Indices)) + for (unsigned III : llvm::reverse(Indices)) PN->removeIncomingValue(III); } } @@ -294,8 +317,7 @@ LowerSwitch::switchConvert(CaseItr Begin, CaseItr End, ConstantInt *LowerBound, /// value, so the jump to the "default" branch is warranted. BasicBlock* LowerSwitch::newLeafBlock(CaseRange& Leaf, Value* Val, BasicBlock* OrigBlock, - BasicBlock* Default) -{ + BasicBlock* Default) { Function* F = OrigBlock->getParent(); BasicBlock* NewLeaf = BasicBlock::Create(Val->getContext(), "LeafBlock"); F->getBasicBlockList().insert(++OrigBlock->getIterator(), NewLeaf); @@ -442,7 +464,8 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, unsigned MaxPop = 0; BasicBlock *PopSucc = nullptr; - IntRange R = { INT64_MIN, INT64_MAX }; + IntRange R = {std::numeric_limits<int64_t>::min(), + std::numeric_limits<int64_t>::max()}; UnreachableRanges.push_back(R); for (const auto &I : Cases) { int64_t Low = I.Low->getSExtValue(); @@ -457,8 +480,8 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, assert(Low > LastRange.Low); LastRange.High = Low - 1; } - if (High != INT64_MAX) { - IntRange R = { High + 1, INT64_MAX }; + if (High != std::numeric_limits<int64_t>::max()) { + IntRange R = { High + 1, std::numeric_limits<int64_t>::max() }; UnreachableRanges.push_back(R); } @@ -487,8 +510,8 @@ void LowerSwitch::processSwitchInst(SwitchInst *SI, assert(MaxPop > 0 && PopSucc); Default = PopSucc; Cases.erase( - remove_if(Cases, - [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }), + llvm::remove_if( + Cases, [PopSucc](const CaseRange &R) { return R.BB == PopSucc; }), Cases.end()); // If there are no cases left, just branch. 
diff --git a/lib/Transforms/Utils/Mem2Reg.cpp b/lib/Transforms/Utils/Mem2Reg.cpp index b659a2e4463f..29f289b62da0 100644 --- a/lib/Transforms/Utils/Mem2Reg.cpp +++ b/lib/Transforms/Utils/Mem2Reg.cpp @@ -15,12 +15,17 @@ #include "llvm/Transforms/Utils/Mem2Reg.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" -#include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" +#include <vector> + using namespace llvm; #define DEBUG_TYPE "mem2reg" @@ -33,7 +38,7 @@ static bool promoteMemoryToRegister(Function &F, DominatorTree &DT, BasicBlock &BB = F.getEntryBlock(); // Get the entry node for the function bool Changed = false; - while (1) { + while (true) { Allocas.clear(); // Find allocas that are safe to promote, by looking at all instructions in @@ -65,15 +70,17 @@ PreservedAnalyses PromotePass::run(Function &F, FunctionAnalysisManager &AM) { } namespace { + struct PromoteLegacyPass : public FunctionPass { - static char ID; // Pass identification, replacement for typeid + // Pass identification, replacement for typeid + static char ID; + PromoteLegacyPass() : FunctionPass(ID) { initializePromoteLegacyPassPass(*PassRegistry::getPassRegistry()); } // runOnFunction - To run this pass, first we calculate the alloca // instructions that are safe for promotion, then we promote each one. - // bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; @@ -89,10 +96,12 @@ struct PromoteLegacyPass : public FunctionPass { AU.addRequired<DominatorTreeWrapperPass>(); AU.setPreservesCFG(); } - }; -} // end of anonymous namespace +}; + +} // end anonymous namespace char PromoteLegacyPass::ID = 0; + INITIALIZE_PASS_BEGIN(PromoteLegacyPass, "mem2reg", "Promote Memory to " "Register", false, false) @@ -102,7 +111,6 @@ INITIALIZE_PASS_END(PromoteLegacyPass, "mem2reg", "Promote Memory to Register", false, false) // createPromoteMemoryToRegister - Provide an entry point to create this pass. -// FunctionPass *llvm::createPromoteMemoryToRegisterPass() { return new PromoteLegacyPass(); } diff --git a/lib/Transforms/Utils/MetaRenamer.cpp b/lib/Transforms/Utils/MetaRenamer.cpp index 9f2ad540c83d..0f7bd76c03ca 100644 --- a/lib/Transforms/Utils/MetaRenamer.cpp +++ b/lib/Transforms/Utils/MetaRenamer.cpp @@ -15,16 +15,30 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallString.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" #include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" #include "llvm/IR/TypeFinder.h" #include "llvm/Pass.h" #include "llvm/Transforms/IPO.h" + using namespace llvm; +static const char *const metaNames[] = { + // See http://en.wikipedia.org/wiki/Metasyntactic_variable + "foo", "bar", "baz", "quux", "barney", "snork", "zot", "blam", "hoge", + "wibble", "wobble", "widget", "wombat", "ham", "eggs", "pluto", "spam" +}; + namespace { // This PRNG is from the ISO C spec. 
It is intentionally simple and @@ -43,12 +57,6 @@ namespace { } }; - static const char *const metaNames[] = { - // See http://en.wikipedia.org/wiki/Metasyntactic_variable - "foo", "bar", "baz", "quux", "barney", "snork", "zot", "blam", "hoge", - "wibble", "wobble", "widget", "wombat", "ham", "eggs", "pluto", "spam" - }; - struct Renamer { Renamer(unsigned int seed) { prng.srand(seed); @@ -62,7 +70,9 @@ namespace { }; struct MetaRenamer : public ModulePass { - static char ID; // Pass identification, replacement for typeid + // Pass identification, replacement for typeid + static char ID; + MetaRenamer() : ModulePass(ID) { initializeMetaRenamerPass(*PassRegistry::getPassRegistry()); } @@ -123,7 +133,11 @@ namespace { TLI.getLibFunc(F, Tmp)) continue; - F.setName(renamer.newName()); + // Leave @main alone. The output of -metarenamer might be passed to + // lli for execution and the latter needs a main entry point. + if (Name != "main") + F.setName(renamer.newName()); + runOnFunction(F); } return true; @@ -144,14 +158,17 @@ namespace { return true; } }; -} + +} // end anonymous namespace char MetaRenamer::ID = 0; + INITIALIZE_PASS_BEGIN(MetaRenamer, "metarenamer", "Assign new names to everything", false, false) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_END(MetaRenamer, "metarenamer", "Assign new names to everything", false, false) + //===----------------------------------------------------------------------===// // // MetaRenamer - Rename everything with metasyntactic names. diff --git a/lib/Transforms/Utils/ModuleUtils.cpp b/lib/Transforms/Utils/ModuleUtils.cpp index 2ef3d6336ae2..ba4b7f3cc263 100644 --- a/lib/Transforms/Utils/ModuleUtils.cpp +++ b/lib/Transforms/Utils/ModuleUtils.cpp @@ -243,7 +243,7 @@ std::string llvm::getUniqueModuleId(Module *M) { bool ExportsSymbols = false; auto AddGlobal = [&](GlobalValue &GV) { if (GV.isDeclaration() || GV.getName().startswith("llvm.") || - !GV.hasExternalLinkage()) + !GV.hasExternalLinkage() || GV.hasComdat()) return; ExportsSymbols = true; Md5.update(GV.getName()); diff --git a/lib/Transforms/Utils/PredicateInfo.cpp b/lib/Transforms/Utils/PredicateInfo.cpp index d4cdaede6b86..d47be6ea566b 100644 --- a/lib/Transforms/Utils/PredicateInfo.cpp +++ b/lib/Transforms/Utils/PredicateInfo.cpp @@ -49,9 +49,10 @@ INITIALIZE_PASS_END(PredicateInfoPrinterLegacyPass, "print-predicateinfo", static cl::opt<bool> VerifyPredicateInfo( "verify-predicateinfo", cl::init(false), cl::Hidden, cl::desc("Verify PredicateInfo in legacy printer pass.")); -namespace { DEBUG_COUNTER(RenameCounter, "predicateinfo-rename", - "Controls which variables are renamed with predicateinfo") + "Controls which variables are renamed with predicateinfo"); + +namespace { // Given a predicate info that is a type of branching terminator, get the // branching block. const BasicBlock *getBranchBlock(const PredicateBase *PB) { @@ -610,7 +611,12 @@ void PredicateInfo::renameUses(SmallPtrSetImpl<Value *> &OpSet) { } convertUsesToDFSOrdered(Op, OrderedUses); - std::sort(OrderedUses.begin(), OrderedUses.end(), Compare); + // Here we require a stable sort because we do not bother to try to + // assign an order to the operands the uses represent. Thus, two + // uses in the same instruction do not have a strict sort order + // currently and will be considered equal. We could get rid of the + // stable sort by creating one if we wanted. 
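The comment above explains why renameUses needs the std::stable_sort call that follows: two uses in the same instruction compare equal under the DFS ordering, so only a stable sort keeps their relative order, and with it the rename order, deterministic. A small standalone illustration of that guarantee, using hypothetical data rather than the real ValueDFS records:

#include <algorithm>
#include <cstdio>
#include <vector>

struct UseRec { int DFSIn; int OperandNo; }; // hypothetical stand-in for ValueDFS

int main() {
  std::vector<UseRec> Uses = {{2, 0}, {1, 0}, {2, 1}, {2, 2}};
  std::stable_sort(Uses.begin(), Uses.end(),
                   [](const UseRec &A, const UseRec &B) { return A.DFSIn < B.DFSIn; });
  // Entries with the equal key DFSIn == 2 keep their original relative order,
  // so the output is always (1,0) (2,0) (2,1) (2,2). std::sort would be free
  // to permute the three equal entries between runs or implementations.
  for (const UseRec &U : Uses)
    std::printf("(%d,%d) ", U.DFSIn, U.OperandNo);
  std::printf("\n");
  return 0;
}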
+ std::stable_sort(OrderedUses.begin(), OrderedUses.end(), Compare); SmallVector<ValueDFS, 8> RenameStack; // For each use, sorted into dfs order, push values and replaces uses with // top of stack, which will represent the reaching def. diff --git a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp index cdba982e6641..fcd3bd08482a 100644 --- a/lib/Transforms/Utils/PromoteMemoryToRegister.cpp +++ b/lib/Transforms/Utils/PromoteMemoryToRegister.cpp @@ -21,25 +21,38 @@ #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/ADT/TinyPtrVector.h" +#include "llvm/ADT/Twine.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/IteratedDominanceFrontier.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DIBuilder.h" -#include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" -#include "llvm/IR/Metadata.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/Support/Casting.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/PromoteMemToReg.h" #include <algorithm> +#include <cassert> +#include <iterator> +#include <utility> +#include <vector> + using namespace llvm; #define DEBUG_TYPE "mem2reg" @@ -103,7 +116,7 @@ struct AllocaInfo { bool OnlyUsedInOneBlock; Value *AllocaPointerVal; - DbgDeclareInst *DbgDeclare; + TinyPtrVector<DbgInfoIntrinsic *> DbgDeclares; void clear() { DefiningBlocks.clear(); @@ -112,7 +125,7 @@ struct AllocaInfo { OnlyBlock = nullptr; OnlyUsedInOneBlock = true; AllocaPointerVal = nullptr; - DbgDeclare = nullptr; + DbgDeclares.clear(); } /// Scan the uses of the specified alloca, filling in the AllocaInfo used @@ -147,27 +160,21 @@ struct AllocaInfo { } } - DbgDeclare = FindAllocaDbgDeclare(AI); + DbgDeclares = FindDbgAddrUses(AI); } }; // Data package used by RenamePass() class RenamePassData { public: - typedef std::vector<Value *> ValVector; + using ValVector = std::vector<Value *>; + + RenamePassData(BasicBlock *B, BasicBlock *P, ValVector V) + : BB(B), Pred(P), Values(std::move(V)) {} - RenamePassData() : BB(nullptr), Pred(nullptr), Values() {} - RenamePassData(BasicBlock *B, BasicBlock *P, const ValVector &V) - : BB(B), Pred(P), Values(V) {} BasicBlock *BB; BasicBlock *Pred; ValVector Values; - - void swap(RenamePassData &RHS) { - std::swap(BB, RHS.BB); - std::swap(Pred, RHS.Pred); - Values.swap(RHS.Values); - } }; /// \brief This assigns and keeps a per-bb relative ordering of load/store @@ -223,12 +230,15 @@ public: struct PromoteMem2Reg { /// The alloca instructions being promoted. std::vector<AllocaInst *> Allocas; + DominatorTree &DT; DIBuilder DIB; + /// A cache of @llvm.assume intrinsics used by SimplifyInstruction. AssumptionCache *AC; const SimplifyQuery SQ; + /// Reverse mapping of Allocas. 
DenseMap<AllocaInst *, unsigned> AllocaLookup; @@ -252,10 +262,9 @@ struct PromoteMem2Reg { /// For each alloca, we keep track of the dbg.declare intrinsic that /// describes it, if any, so that we can convert it to a dbg.value /// intrinsic if the alloca gets promoted. - SmallVector<DbgDeclareInst *, 8> AllocaDbgDeclares; + SmallVector<TinyPtrVector<DbgInfoIntrinsic *>, 8> AllocaDbgDeclares; /// The set of basic blocks the renamer has already visited. - /// SmallPtrSet<BasicBlock *, 16> Visited; /// Contains a stable numbering of basic blocks to avoid non-determinstic @@ -298,7 +307,7 @@ private: bool QueuePhiNode(BasicBlock *BB, unsigned AllocaIdx, unsigned &Version); }; -} // end of anonymous namespace +} // end anonymous namespace /// Given a LoadInst LI this adds assume(LI != null) after it. static void addAssumeNonNull(AssumptionCache *AC, LoadInst *LI) { @@ -345,8 +354,8 @@ static void removeLifetimeIntrinsicUsers(AllocaInst *AI) { /// and thus must be phi-ed with undef. We fall back to the standard alloca /// promotion algorithm in that case. static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, - LargeBlockInfo &LBI, DominatorTree &DT, - AssumptionCache *AC) { + LargeBlockInfo &LBI, const DataLayout &DL, + DominatorTree &DT, AssumptionCache *AC) { StoreInst *OnlyStore = Info.OnlyStore; bool StoringGlobalVal = !isa<Instruction>(OnlyStore->getOperand(0)); BasicBlock *StoreBB = OnlyStore->getParent(); @@ -380,7 +389,6 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, Info.UsingBlocks.push_back(StoreBB); continue; } - } else if (LI->getParent() != StoreBB && !DT.dominates(StoreBB, LI->getParent())) { // If the load and store are in different blocks, use BB dominance to @@ -402,7 +410,7 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, // that information when we erase this Load. So we preserve // it with an assume. if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && - !llvm::isKnownNonNullAt(ReplVal, LI, &DT)) + !isKnownNonZero(ReplVal, DL, 0, AC, LI, &DT)) addAssumeNonNull(AC, LI); LI->replaceAllUsesWith(ReplVal); @@ -416,11 +424,11 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, // Record debuginfo for the store and remove the declaration's // debuginfo. - if (DbgDeclareInst *DDI = Info.DbgDeclare) { + for (DbgInfoIntrinsic *DII : Info.DbgDeclares) { DIBuilder DIB(*AI->getModule(), /*AllowUnresolved*/ false); - ConvertDebugDeclareToDebugValue(DDI, Info.OnlyStore, DIB); - DDI->eraseFromParent(); - LBI.deleteValue(DDI); + ConvertDebugDeclareToDebugValue(DII, Info.OnlyStore, DIB); + DII->eraseFromParent(); + LBI.deleteValue(DII); } // Remove the (now dead) store and alloca. Info.OnlyStore->eraseFromParent(); @@ -449,6 +457,7 @@ static bool rewriteSingleStoreAlloca(AllocaInst *AI, AllocaInfo &Info, /// } static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, LargeBlockInfo &LBI, + const DataLayout &DL, DominatorTree &DT, AssumptionCache *AC) { // The trickiest case to handle is when we have large blocks. Because of this, @@ -457,7 +466,7 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, // make it efficient to get the index of various operations in the block. // Walk the use-def list of the alloca, getting the locations of all stores. 
- typedef SmallVector<std::pair<unsigned, StoreInst *>, 64> StoresByIndexTy; + using StoresByIndexTy = SmallVector<std::pair<unsigned, StoreInst *>, 64>; StoresByIndexTy StoresByIndex; for (User *U : AI->users()) @@ -497,7 +506,7 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, // information when we erase it. So we preserve it with an assume. Value *ReplVal = std::prev(I)->second->getOperand(0); if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && - !llvm::isKnownNonNullAt(ReplVal, LI, &DT)) + !isKnownNonZero(ReplVal, DL, 0, AC, LI, &DT)) addAssumeNonNull(AC, LI); LI->replaceAllUsesWith(ReplVal); @@ -511,9 +520,9 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, while (!AI->use_empty()) { StoreInst *SI = cast<StoreInst>(AI->user_back()); // Record debuginfo for the store before removing it. - if (DbgDeclareInst *DDI = Info.DbgDeclare) { + for (DbgInfoIntrinsic *DII : Info.DbgDeclares) { DIBuilder DIB(*AI->getModule(), /*AllowUnresolved*/ false); - ConvertDebugDeclareToDebugValue(DDI, SI, DIB); + ConvertDebugDeclareToDebugValue(DII, SI, DIB); } SI->eraseFromParent(); LBI.deleteValue(SI); @@ -523,9 +532,9 @@ static bool promoteSingleBlockAlloca(AllocaInst *AI, const AllocaInfo &Info, LBI.deleteValue(AI); // The alloca's debuginfo can be removed as well. - if (DbgDeclareInst *DDI = Info.DbgDeclare) { - DDI->eraseFromParent(); - LBI.deleteValue(DDI); + for (DbgInfoIntrinsic *DII : Info.DbgDeclares) { + DII->eraseFromParent(); + LBI.deleteValue(DII); } ++NumLocalPromoted; @@ -567,7 +576,7 @@ void PromoteMem2Reg::run() { // If there is only a single store to this value, replace any loads of // it that are directly dominated by the definition with the value stored. if (Info.DefiningBlocks.size() == 1) { - if (rewriteSingleStoreAlloca(AI, Info, LBI, DT, AC)) { + if (rewriteSingleStoreAlloca(AI, Info, LBI, SQ.DL, DT, AC)) { // The alloca has been processed, move on. RemoveFromAllocasList(AllocaNum); ++NumSingleStore; @@ -578,7 +587,7 @@ void PromoteMem2Reg::run() { // If the alloca is only read and written in one basic block, just perform a // linear sweep over the block to eliminate it. if (Info.OnlyUsedInOneBlock && - promoteSingleBlockAlloca(AI, Info, LBI, DT, AC)) { + promoteSingleBlockAlloca(AI, Info, LBI, SQ.DL, DT, AC)) { // The alloca has been processed, move on. RemoveFromAllocasList(AllocaNum); continue; @@ -593,8 +602,8 @@ void PromoteMem2Reg::run() { } // Remember the dbg.declare intrinsic describing this alloca, if any. - if (Info.DbgDeclare) - AllocaDbgDeclares[AllocaNum] = Info.DbgDeclare; + if (!Info.DbgDeclares.empty()) + AllocaDbgDeclares[AllocaNum] = Info.DbgDeclares; // Keep the reverse mapping of the 'Allocas' array for the rename pass. AllocaLookup[Allocas[AllocaNum]] = AllocaNum; @@ -604,7 +613,6 @@ void PromoteMem2Reg::run() { // nodes and see if we can optimize out some work by avoiding insertion of // dead phi nodes. - // Unique the set of defining blocks for efficient lookup. 
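rewriteSingleStoreAlloca, invoked from the run() loop above, covers the common fast path where an alloca has exactly one store: every load dominated by that store is rewritten to use the stored value directly, and no PHI construction is needed. A hypothetical source-level before/after picture of that effect (plain C++, not LLVM IR, purely for illustration):

// Before: one store to the stack slot, and a load that the store dominates.
static int beforePromotion(int X) {
  int Slot;         // the "alloca"
  Slot = X + 1;     // the only store
  return Slot * 2;  // load rewritten to use X + 1 directly
}

// After: the load uses the stored value and the slot disappears entirely.
static int afterPromotion(int X) {
  int V = X + 1;
  return V * 2;
}

int main() { return beforePromotion(3) == afterPromotion(3) ? 0 : 1; }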
SmallPtrSet<BasicBlock *, 32> DefBlocks; DefBlocks.insert(Info.DefiningBlocks.begin(), Info.DefiningBlocks.end()); @@ -629,8 +637,8 @@ void PromoteMem2Reg::run() { }); unsigned CurrentVersion = 0; - for (unsigned i = 0, e = PHIBlocks.size(); i != e; ++i) - QueuePhiNode(PHIBlocks[i], AllocaNum, CurrentVersion); + for (BasicBlock *BB : PHIBlocks) + QueuePhiNode(BB, AllocaNum, CurrentVersion); } if (Allocas.empty()) @@ -641,19 +649,16 @@ void PromoteMem2Reg::run() { // Set the incoming values for the basic block to be null values for all of // the alloca's. We do this in case there is a load of a value that has not // been stored yet. In this case, it will get this null value. - // RenamePassData::ValVector Values(Allocas.size()); for (unsigned i = 0, e = Allocas.size(); i != e; ++i) Values[i] = UndefValue::get(Allocas[i]->getAllocatedType()); // Walks all basic blocks in the function performing the SSA rename algorithm // and inserting the phi nodes we marked as necessary - // std::vector<RenamePassData> RenamePassWorkList; RenamePassWorkList.emplace_back(&F.front(), nullptr, std::move(Values)); do { - RenamePassData RPD; - RPD.swap(RenamePassWorkList.back()); + RenamePassData RPD = std::move(RenamePassWorkList.back()); RenamePassWorkList.pop_back(); // RenamePass may add new worklist entries. RenamePass(RPD.BB, RPD.Pred, RPD.Values, RenamePassWorkList); @@ -663,9 +668,7 @@ void PromoteMem2Reg::run() { Visited.clear(); // Remove the allocas themselves from the function. - for (unsigned i = 0, e = Allocas.size(); i != e; ++i) { - Instruction *A = Allocas[i]; - + for (Instruction *A : Allocas) { // If there are any uses of the alloca instructions left, they must be in // unreachable basic blocks that were not processed by walking the dominator // tree. Just delete the users now. @@ -675,9 +678,9 @@ void PromoteMem2Reg::run() { } // Remove alloca's dbg.declare instrinsics from the function. - for (unsigned i = 0, e = AllocaDbgDeclares.size(); i != e; ++i) - if (DbgDeclareInst *DDI = AllocaDbgDeclares[i]) - DDI->eraseFromParent(); + for (auto &Declares : AllocaDbgDeclares) + for (auto *DII : Declares) + DII->eraseFromParent(); // Loop over all of the PHI nodes and see if there are any that we can get // rid of because they merge all of the same incoming values. This can @@ -714,7 +717,6 @@ void PromoteMem2Reg::run() { // hasn't traversed. If this is the case, the PHI nodes may not // have incoming values for all predecessors. Loop over all PHI nodes we have // created, inserting undef values if they are missing any incoming values. - // for (DenseMap<std::pair<unsigned, unsigned>, PHINode *>::iterator I = NewPhiNodes.begin(), E = NewPhiNodes.end(); @@ -762,8 +764,8 @@ void PromoteMem2Reg::run() { while ((SomePHI = dyn_cast<PHINode>(BBI++)) && SomePHI->getNumIncomingValues() == NumBadPreds) { Value *UndefVal = UndefValue::get(SomePHI->getType()); - for (unsigned pred = 0, e = Preds.size(); pred != e; ++pred) - SomePHI->addIncoming(UndefVal, Preds[pred]); + for (BasicBlock *Pred : Preds) + SomePHI->addIncoming(UndefVal, Pred); } } @@ -779,7 +781,6 @@ void PromoteMem2Reg::ComputeLiveInBlocks( AllocaInst *AI, AllocaInfo &Info, const SmallPtrSetImpl<BasicBlock *> &DefBlocks, SmallPtrSetImpl<BasicBlock *> &LiveInBlocks) { - // To determine liveness, we must iterate through the predecessors of blocks // where the def is live. Blocks are added to the worklist if we need to // check their predecessors. Start with all the using blocks. 
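The paragraph above describes the worklist used by ComputeLiveInBlocks: start from the blocks that use the value and walk predecessor edges backwards, stopping at blocks that define it. The toy sketch below captures that idea over a hypothetical CFG of integer block ids; it deliberately omits the refinement in the real code for blocks that both define and use the value, and none of the names are LLVM types.

#include <map>
#include <set>
#include <vector>

using Block = int;

static std::set<Block>
computeLiveIn(const std::map<Block, std::vector<Block>> &Preds,
              const std::set<Block> &DefBlocks,
              const std::set<Block> &UseBlocks) {
  std::set<Block> LiveIn;
  std::vector<Block> Worklist(UseBlocks.begin(), UseBlocks.end());
  while (!Worklist.empty()) {
    Block BB = Worklist.back();
    Worklist.pop_back();
    if (!LiveIn.insert(BB).second)
      continue; // already known to be live-in here
    auto It = Preds.find(BB);
    if (It == Preds.end())
      continue; // entry block: no predecessors
    for (Block P : It->second) {
      if (DefBlocks.count(P))
        continue; // the value is not live into a block that defines it
      Worklist.push_back(P);
    }
  }
  return LiveIn;
}

int main() {
  // Hypothetical CFG 1 -> 2 -> 3: the value is defined in block 1, used in 3.
  std::map<Block, std::vector<Block>> Preds = {{2, {1}}, {3, {2}}};
  std::set<Block> LiveIn = computeLiveIn(Preds, /*DefBlocks=*/{1}, /*UseBlocks=*/{3});
  return (LiveIn.count(2) && LiveIn.count(3) && !LiveIn.count(1)) ? 0 : 1;
}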
@@ -834,9 +835,7 @@ void PromoteMem2Reg::ComputeLiveInBlocks( // Since the value is live into BB, it is either defined in a predecessor or // live into it to. Add the preds to the worklist unless they are a // defining block. - for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) { - BasicBlock *P = *PI; - + for (BasicBlock *P : predecessors(BB)) { // The value is not live into a predecessor if it defines the value. if (DefBlocks.count(P)) continue; @@ -906,8 +905,8 @@ NextIteration: // The currently active variable for this block is now the PHI. IncomingVals[AllocaNo] = APN; - if (DbgDeclareInst *DDI = AllocaDbgDeclares[AllocaNo]) - ConvertDebugDeclareToDebugValue(DDI, APN, DIB); + for (DbgInfoIntrinsic *DII : AllocaDbgDeclares[AllocaNo]) + ConvertDebugDeclareToDebugValue(DII, APN, DIB); // Get the next phi node. ++PNI; @@ -943,7 +942,7 @@ NextIteration: // that information when we erase this Load. So we preserve // it with an assume. if (AC && LI->getMetadata(LLVMContext::MD_nonnull) && - !llvm::isKnownNonNullAt(V, LI, &DT)) + !isKnownNonZero(V, SQ.DL, 0, AC, LI, &DT)) addAssumeNonNull(AC, LI); // Anything using the load now uses the current value. @@ -963,8 +962,8 @@ NextIteration: // what value were we writing? IncomingVals[ai->second] = SI->getOperand(0); // Record debuginfo for the store before removing it. - if (DbgDeclareInst *DDI = AllocaDbgDeclares[ai->second]) - ConvertDebugDeclareToDebugValue(DDI, SI, DIB); + for (DbgInfoIntrinsic *DII : AllocaDbgDeclares[ai->second]) + ConvertDebugDeclareToDebugValue(DII, SI, DIB); BB->getInstList().erase(SI); } } diff --git a/lib/Transforms/Utils/SSAUpdater.cpp b/lib/Transforms/Utils/SSAUpdater.cpp index 6ccf54e49dd3..e4b20b0faa15 100644 --- a/lib/Transforms/Utils/SSAUpdater.cpp +++ b/lib/Transforms/Utils/SSAUpdater.cpp @@ -15,7 +15,6 @@ #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" #include "llvm/ADT/TinyPtrVector.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/IR/BasicBlock.h" @@ -39,12 +38,13 @@ using namespace llvm; #define DEBUG_TYPE "ssaupdater" -typedef DenseMap<BasicBlock*, Value*> AvailableValsTy; +using AvailableValsTy = DenseMap<BasicBlock *, Value *>; + static AvailableValsTy &getAvailableVals(void *AV) { return *static_cast<AvailableValsTy*>(AV); } -SSAUpdater::SSAUpdater(SmallVectorImpl<PHINode*> *NewPHI) +SSAUpdater::SSAUpdater(SmallVectorImpl<PHINode *> *NewPHI) : InsertedPHIs(NewPHI) {} SSAUpdater::~SSAUpdater() { @@ -72,7 +72,7 @@ void SSAUpdater::AddAvailableValue(BasicBlock *BB, Value *V) { } static bool IsEquivalentPHI(PHINode *PHI, - SmallDenseMap<BasicBlock*, Value*, 8> &ValueMapping) { + SmallDenseMap<BasicBlock *, Value *, 8> &ValueMapping) { unsigned PHINumValues = PHI->getNumIncomingValues(); if (PHINumValues != ValueMapping.size()) return false; @@ -100,7 +100,7 @@ Value *SSAUpdater::GetValueInMiddleOfBlock(BasicBlock *BB) { // Otherwise, we have the hard case. Get the live-in values for each // predecessor. - SmallVector<std::pair<BasicBlock*, Value*>, 8> PredValues; + SmallVector<std::pair<BasicBlock *, Value *>, 8> PredValues; Value *SingularValue = nullptr; // We can get our predecessor info by walking the pred_iterator list, but it @@ -145,8 +145,8 @@ Value *SSAUpdater::GetValueInMiddleOfBlock(BasicBlock *BB) { // Otherwise, we do need a PHI: check to see if we already have one available // in this block that produces the right value. 
if (isa<PHINode>(BB->begin())) { - SmallDenseMap<BasicBlock*, Value*, 8> ValueMapping(PredValues.begin(), - PredValues.end()); + SmallDenseMap<BasicBlock *, Value *, 8> ValueMapping(PredValues.begin(), + PredValues.end()); PHINode *SomePHI; for (BasicBlock::iterator It = BB->begin(); (SomePHI = dyn_cast<PHINode>(It)); ++It) { @@ -218,11 +218,11 @@ namespace llvm { template<> class SSAUpdaterTraits<SSAUpdater> { public: - typedef BasicBlock BlkT; - typedef Value *ValT; - typedef PHINode PhiT; + using BlkT = BasicBlock; + using ValT = Value *; + using PhiT = PHINode; + using BlkSucc_iterator = succ_iterator; - typedef succ_iterator BlkSucc_iterator; static BlkSucc_iterator BlkSucc_begin(BlkT *BB) { return succ_begin(BB); } static BlkSucc_iterator BlkSucc_end(BlkT *BB) { return succ_end(BB); } @@ -253,7 +253,7 @@ public: /// FindPredecessorBlocks - Put the predecessors of Info->BB into the Preds /// vector, set Info->NumPreds, and allocate space in Info->Preds. static void FindPredecessorBlocks(BasicBlock *BB, - SmallVectorImpl<BasicBlock*> *Preds) { + SmallVectorImpl<BasicBlock *> *Preds) { // We can get our predecessor info by walking the pred_iterator list, // but it is relatively slow. If we already have PHI nodes in this // block, walk one of them to get the predecessor list instead. @@ -293,7 +293,6 @@ public: } /// ValueIsPHI - Check if a value is a PHI. - /// static PHINode *ValueIsPHI(Value *Val, SSAUpdater *Updater) { return dyn_cast<PHINode>(Val); } @@ -333,7 +332,7 @@ Value *SSAUpdater::GetValueAtEndOfBlockInternal(BasicBlock *BB) { //===----------------------------------------------------------------------===// LoadAndStorePromoter:: -LoadAndStorePromoter(ArrayRef<const Instruction*> Insts, +LoadAndStorePromoter(ArrayRef<const Instruction *> Insts, SSAUpdater &S, StringRef BaseName) : SSA(S) { if (Insts.empty()) return; @@ -349,11 +348,11 @@ LoadAndStorePromoter(ArrayRef<const Instruction*> Insts, } void LoadAndStorePromoter:: -run(const SmallVectorImpl<Instruction*> &Insts) const { +run(const SmallVectorImpl<Instruction *> &Insts) const { // First step: bucket up uses of the alloca by the block they occur in. // This is important because we have to handle multiple defs/uses in a block // ourselves: SSAUpdater is purely for cross-block references. - DenseMap<BasicBlock*, TinyPtrVector<Instruction*>> UsesByBlock; + DenseMap<BasicBlock *, TinyPtrVector<Instruction *>> UsesByBlock; for (Instruction *User : Insts) UsesByBlock[User->getParent()].push_back(User); @@ -361,12 +360,12 @@ run(const SmallVectorImpl<Instruction*> &Insts) const { // Okay, now we can iterate over all the blocks in the function with uses, // processing them. Keep track of which loads are loading a live-in value. // Walk the uses in the use-list order to be determinstic. - SmallVector<LoadInst*, 32> LiveInLoads; - DenseMap<Value*, Value*> ReplacedLoads; + SmallVector<LoadInst *, 32> LiveInLoads; + DenseMap<Value *, Value *> ReplacedLoads; for (Instruction *User : Insts) { BasicBlock *BB = User->getParent(); - TinyPtrVector<Instruction*> &BlockUses = UsesByBlock[BB]; + TinyPtrVector<Instruction *> &BlockUses = UsesByBlock[BB]; // If this block has already been processed, ignore this repeat use. 
if (BlockUses.empty()) continue; @@ -489,7 +488,7 @@ run(const SmallVectorImpl<Instruction*> &Insts) const { bool LoadAndStorePromoter::isInstInList(Instruction *I, - const SmallVectorImpl<Instruction*> &Insts) + const SmallVectorImpl<Instruction *> &Insts) const { return is_contained(Insts, I); } diff --git a/lib/Transforms/Utils/SimplifyCFG.cpp b/lib/Transforms/Utils/SimplifyCFG.cpp index 8784b9702141..f02f80cc1b78 100644 --- a/lib/Transforms/Utils/SimplifyCFG.cpp +++ b/lib/Transforms/Utils/SimplifyCFG.cpp @@ -22,12 +22,14 @@ #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/ConstantFolding.h" #include "llvm/Analysis/EHPersonalities.h" #include "llvm/Analysis/InstructionSimplify.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Attributes.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/CFG.h" #include "llvm/IR/CallSite.h" @@ -35,8 +37,8 @@ #include "llvm/IR/ConstantRange.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfo.h" #include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" #include "llvm/IR/GlobalValue.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/IRBuilder.h" @@ -53,6 +55,7 @@ #include "llvm/IR/Operator.h" #include "llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/Support/Casting.h" @@ -73,6 +76,7 @@ #include <iterator> #include <map> #include <set> +#include <tuple> #include <utility> #include <vector> @@ -141,12 +145,13 @@ namespace { // The first field contains the value that the switch produces when a certain // case group is selected, and the second field is a vector containing the // cases composing the case group. -typedef SmallVector<std::pair<Constant *, SmallVector<ConstantInt *, 4>>, 2> - SwitchCaseResultVectorTy; +using SwitchCaseResultVectorTy = + SmallVector<std::pair<Constant *, SmallVector<ConstantInt *, 4>>, 2>; + // The first field contains the phi node that generates a result of the switch // and the second field contains the value generated for a certain case in the // switch for that PHI. -typedef SmallVector<std::pair<PHINode *, Constant *>, 4> SwitchCaseResultsTy; +using SwitchCaseResultsTy = SmallVector<std::pair<PHINode *, Constant *>, 4>; /// ValueEqualityComparisonCase - Represents a case of a switch. struct ValueEqualityComparisonCase { @@ -167,11 +172,9 @@ struct ValueEqualityComparisonCase { class SimplifyCFGOpt { const TargetTransformInfo &TTI; const DataLayout &DL; - unsigned BonusInstThreshold; - AssumptionCache *AC; SmallPtrSetImpl<BasicBlock *> *LoopHeaders; - // See comments in SimplifyCFGOpt::SimplifySwitch. 
- bool LateSimplifyCFG; + const SimplifyCFGOptions &Options; + Value *isValueEqualityComparison(TerminatorInst *TI); BasicBlock *GetValueEqualityComparisonCases( TerminatorInst *TI, std::vector<ValueEqualityComparisonCase> &Cases); @@ -194,11 +197,9 @@ class SimplifyCFGOpt { public: SimplifyCFGOpt(const TargetTransformInfo &TTI, const DataLayout &DL, - unsigned BonusInstThreshold, AssumptionCache *AC, SmallPtrSetImpl<BasicBlock *> *LoopHeaders, - bool LateSimplifyCFG) - : TTI(TTI), DL(DL), BonusInstThreshold(BonusInstThreshold), AC(AC), - LoopHeaders(LoopHeaders), LateSimplifyCFG(LateSimplifyCFG) {} + const SimplifyCFGOptions &Opts) + : TTI(TTI), DL(DL), LoopHeaders(LoopHeaders), Options(Opts) {} bool run(BasicBlock *BB); }; @@ -438,18 +439,24 @@ namespace { /// fail. struct ConstantComparesGatherer { const DataLayout &DL; - Value *CompValue; /// Value found for the switch comparison - Value *Extra; /// Extra clause to be checked before the switch - SmallVector<ConstantInt *, 8> Vals; /// Set of integers to match in switch - unsigned UsedICmps; /// Number of comparisons matched in the and/or chain + + /// Value found for the switch comparison + Value *CompValue = nullptr; + + /// Extra clause to be checked before the switch + Value *Extra = nullptr; + + /// Set of integers to match in switch + SmallVector<ConstantInt *, 8> Vals; + + /// Number of comparisons matched in the and/or chain + unsigned UsedICmps = 0; /// Construct and compute the result for the comparison instruction Cond - ConstantComparesGatherer(Instruction *Cond, const DataLayout &DL) - : DL(DL), CompValue(nullptr), Extra(nullptr), UsedICmps(0) { + ConstantComparesGatherer(Instruction *Cond, const DataLayout &DL) : DL(DL) { gather(Cond); } - /// Prevent copy ConstantComparesGatherer(const ConstantComparesGatherer &) = delete; ConstantComparesGatherer & operator=(const ConstantComparesGatherer &) = delete; @@ -487,7 +494,6 @@ private: // (x & ~2^z) == y --> x == y || x == y|2^z // This undoes a transformation done by instcombine to fuse 2 compares. if (ICI->getPredicate() == (isEQ ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE)) { - // It's a little bit hard to see why the following transformations are // correct. Here is a CVC3 program to verify them for 64-bit values: @@ -770,6 +776,31 @@ static bool ValuesOverlap(std::vector<ValueEqualityComparisonCase> &C1, return false; } +// Set branch weights on SwitchInst. This sets the metadata if there is at +// least one non-zero weight. +static void setBranchWeights(SwitchInst *SI, ArrayRef<uint32_t> Weights) { + // Check that there is at least one non-zero weight. Otherwise, pass + // nullptr to setMetadata which will erase the existing metadata. + MDNode *N = nullptr; + if (llvm::any_of(Weights, [](uint32_t W) { return W != 0; })) + N = MDBuilder(SI->getParent()->getContext()).createBranchWeights(Weights); + SI->setMetadata(LLVMContext::MD_prof, N); +} + +// Similar to the above, but for branch and select instructions that take +// exactly 2 weights. +static void setBranchWeights(Instruction *I, uint32_t TrueWeight, + uint32_t FalseWeight) { + assert(isa<BranchInst>(I) || isa<SelectInst>(I)); + // Check that there is at least one non-zero weight. Otherwise, pass + // nullptr to setMetadata which will erase the existing metadata. 
+ MDNode *N = nullptr; + if (TrueWeight || FalseWeight) + N = MDBuilder(I->getParent()->getContext()) + .createBranchWeights(TrueWeight, FalseWeight); + I->setMetadata(LLVMContext::MD_prof, N); +} + /// If TI is known to be a terminator instruction and its block is known to /// only have a single predecessor block, check to see if that predecessor is /// also a value comparison with the same value, and if that comparison @@ -859,9 +890,7 @@ bool SimplifyCFGOpt::SimplifyEqualityComparisonWithOnlyPredecessor( } } if (HasWeight && Weights.size() >= 2) - SI->setMetadata(LLVMContext::MD_prof, - MDBuilder(SI->getParent()->getContext()) - .createBranchWeights(Weights)); + setBranchWeights(SI, Weights); DEBUG(dbgs() << "Leaving: " << *TI << "\n"); return true; @@ -1166,9 +1195,7 @@ bool SimplifyCFGOpt::FoldValueComparisonIntoPredecessors(TerminatorInst *TI, SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end()); - NewSI->setMetadata( - LLVMContext::MD_prof, - MDBuilder(BB->getContext()).createBranchWeights(MDWeights)); + setBranchWeights(NewSI, MDWeights); } EraseTerminatorInstAndDCECond(PTI); @@ -1279,9 +1306,7 @@ static bool HoistThenElseCodeToIf(BranchInst *BI, // I1 and I2 are being combined into a single instruction. Its debug // location is the merged locations of the original instructions. - if (!isa<CallInst>(I1)) - I1->setDebugLoc( - DILocation::getMergedLocation(I1->getDebugLoc(), I2->getDebugLoc())); + I1->applyMergedLocation(I1->getDebugLoc(), I2->getDebugLoc()); I2->eraseFromParent(); Changed = true; @@ -1535,20 +1560,20 @@ static bool sinkLastInstruction(ArrayRef<BasicBlock*> Blocks) { I0->getOperandUse(O).set(NewOperands[O]); I0->moveBefore(&*BBEnd->getFirstInsertionPt()); - // The debug location for the "common" instruction is the merged locations of - // all the commoned instructions. We start with the original location of the - // "common" instruction and iteratively merge each location in the loop below. - const DILocation *Loc = I0->getDebugLoc(); - // Update metadata and IR flags, and merge debug locations. for (auto *I : Insts) if (I != I0) { - Loc = DILocation::getMergedLocation(Loc, I->getDebugLoc()); + // The debug location for the "common" instruction is the merged locations + // of all the commoned instructions. We start with the original location + // of the "common" instruction and iteratively merge each location in the + // loop below. + // This is an N-way merge, which will be inefficient if I0 is a CallInst. + // However, as N-way merge for CallInst is rare, so we use simplified API + // instead of using complex API for N-way merge. + I0->applyMergedLocation(I0->getDebugLoc(), I->getDebugLoc()); combineMetadataForCSE(I0, I); I0->andIRFlags(I); } - if (!isa<CallInst>(I0)) - I0->setDebugLoc(Loc); if (!isa<StoreInst>(I0)) { // canSinkLastInstruction checked that all instructions were used by @@ -1582,9 +1607,9 @@ namespace { ArrayRef<BasicBlock*> Blocks; SmallVector<Instruction*,4> Insts; bool Fail; + public: - LockstepReverseIterator(ArrayRef<BasicBlock*> Blocks) : - Blocks(Blocks) { + LockstepReverseIterator(ArrayRef<BasicBlock*> Blocks) : Blocks(Blocks) { reset(); } @@ -1608,7 +1633,7 @@ namespace { return !Fail; } - void operator -- () { + void operator--() { if (Fail) return; for (auto *&Inst : Insts) { @@ -1916,6 +1941,8 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, // - All of their uses are in CondBB. 
SmallDenseMap<Instruction *, unsigned, 4> SinkCandidateUseCounts; + SmallVector<Instruction *, 4> SpeculatedDbgIntrinsics; + unsigned SpeculationCost = 0; Value *SpeculatedStoreValue = nullptr; StoreInst *SpeculatedStore = nullptr; @@ -1924,8 +1951,10 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, BBI != BBE; ++BBI) { Instruction *I = &*BBI; // Skip debug info. - if (isa<DbgInfoIntrinsic>(I)) + if (isa<DbgInfoIntrinsic>(I)) { + SpeculatedDbgIntrinsics.push_back(I); continue; + } // Only speculatively execute a single instruction (not counting the // terminator) for now. @@ -2030,11 +2059,10 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, if (Invert) std::swap(TrueV, FalseV); Value *S = Builder.CreateSelect( - BrCond, TrueV, FalseV, TrueV->getName() + "." + FalseV->getName(), BI); + BrCond, TrueV, FalseV, "spec.store.select", BI); SpeculatedStore->setOperand(0, S); - SpeculatedStore->setDebugLoc( - DILocation::getMergedLocation( - BI->getDebugLoc(), SpeculatedStore->getDebugLoc())); + SpeculatedStore->applyMergedLocation(BI->getDebugLoc(), + SpeculatedStore->getDebugLoc()); } // Metadata can be dependent on the condition we are hoisting above. @@ -2066,11 +2094,17 @@ static bool SpeculativelyExecuteBB(BranchInst *BI, BasicBlock *ThenBB, if (Invert) std::swap(TrueV, FalseV); Value *V = Builder.CreateSelect( - BrCond, TrueV, FalseV, TrueV->getName() + "." + FalseV->getName(), BI); + BrCond, TrueV, FalseV, "spec.select", BI); PN->setIncomingValue(OrigI, V); PN->setIncomingValue(ThenI, V); } + // Remove speculated dbg intrinsics. + // FIXME: Is it possible to do this in a more elegant way? Moving/merging the + // dbg value for the different flows and inserting it after the select. + for (Instruction *I : SpeculatedDbgIntrinsics) + I->eraseFromParent(); + ++NumSpeculations; return true; } @@ -2507,7 +2541,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { else { // For unconditional branch, check for a simple CFG pattern, where // BB has a single predecessor and BB's successor is also its predecessor's - // successor. If such pattern exisits, check for CSE between BB and its + // successor. If such pattern exists, check for CSE between BB and its // predecessor. 
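SpeculativelyExecuteBB, updated above to name its selects "spec.select" and "spec.store.select", performs a simple if-conversion: when the conditional block holds only a cheap, side-effect-free instruction, that instruction is hoisted above the branch and the PHI (or the conditional store) collapses into a select. A hypothetical source-level view of that transformation, as plain C++ rather than IR:

// Before: a conditional block with a single cheap, side-effect-free
// instruction feeding a PHI in the join block.
static int beforeSpeculation(bool Cond, int A, int B) {
  int R = B;
  if (Cond)
    R = A + 1;      // "ThenBB": one speculatable instruction
  return R;         // PHI of (A + 1) and B at the join point
}

// After: the instruction is hoisted unconditionally and the PHI becomes a
// select (the value the code above names "spec.select").
static int afterSpeculation(bool Cond, int A, int B) {
  int T = A + 1;    // speculated above the branch
  return Cond ? T : B;
}

int main() { return beforeSpeculation(true, 2, 5) == afterSpeculation(true, 2, 5) ? 0 : 1; }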
if (BasicBlock *PB = BB->getSinglePredecessor()) if (BranchInst *PBI = dyn_cast<BranchInst>(PB->getTerminator())) @@ -2725,9 +2759,7 @@ bool llvm::FoldBranchToCommonDest(BranchInst *BI, unsigned BonusInstThreshold) { SmallVector<uint32_t, 8> MDWeights(NewWeights.begin(), NewWeights.end()); - PBI->setMetadata( - LLVMContext::MD_prof, - MDBuilder(BI->getContext()).createBranchWeights(MDWeights)); + setBranchWeights(PBI, MDWeights[0], MDWeights[1]); } else PBI->setMetadata(LLVMContext::MD_prof, nullptr); } else { @@ -2860,7 +2892,8 @@ static Value *ensureValueAvailableInSuccessor(Value *V, BasicBlock *BB, static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB, BasicBlock *QTB, BasicBlock *QFB, BasicBlock *PostBB, Value *Address, - bool InvertPCond, bool InvertQCond) { + bool InvertPCond, bool InvertQCond, + const DataLayout &DL) { auto IsaBitcastOfPointerType = [](const Instruction &I) { return Operator::getOpcode(&I) == Instruction::BitCast && I.getType()->isPointerTy(); @@ -2887,7 +2920,9 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB, else return false; } - return N <= PHINodeFoldingThreshold; + // The store we want to merge is counted in N, so add 1 to make sure + // we're counting the instructions that would be left. + return N <= (PHINodeFoldingThreshold + 1); }; if (!MergeCondStoresAggressively && @@ -2966,6 +3001,29 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB, PStore->getAAMetadata(AAMD, /*Merge=*/false); PStore->getAAMetadata(AAMD, /*Merge=*/true); SI->setAAMetadata(AAMD); + unsigned PAlignment = PStore->getAlignment(); + unsigned QAlignment = QStore->getAlignment(); + unsigned TypeAlignment = + DL.getABITypeAlignment(SI->getValueOperand()->getType()); + unsigned MinAlignment; + unsigned MaxAlignment; + std::tie(MinAlignment, MaxAlignment) = std::minmax(PAlignment, QAlignment); + // Choose the minimum alignment. If we could prove both stores execute, we + // could use biggest one. In this case, though, we only know that one of the + // stores executes. And we don't know it's safe to take the alignment from a + // store that doesn't execute. + if (MinAlignment != 0) { + // Choose the minimum of all non-zero alignments. + SI->setAlignment(MinAlignment); + } else if (MaxAlignment != 0) { + // Choose the minimal alignment between the non-zero alignment and the ABI + // default alignment for the type of the stored value. + SI->setAlignment(std::min(MaxAlignment, TypeAlignment)); + } else { + // If both alignments are zero, use ABI default alignment for the type of + // the stored value. + SI->setAlignment(TypeAlignment); + } QStore->eraseFromParent(); PStore->eraseFromParent(); @@ -2973,7 +3031,8 @@ static bool mergeConditionalStoreToAddress(BasicBlock *PTB, BasicBlock *PFB, return true; } -static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI) { +static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI, + const DataLayout &DL) { // The intention here is to find diamonds or triangles (see below) where each // conditional block contains a store to the same address. Both of these // stores are conditional, so they can't be unconditionally sunk. But it may @@ -3001,7 +3060,6 @@ static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI) { // We model triangles as a type of diamond with a nullptr "true" block. // Triangles are canonicalized so that the fallthrough edge is represented by // a true condition, as in the diagram above. 
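The alignment handling added to mergeConditionalStoreToAddress above has to be conservative: only one of the two original stores is known to execute, so the merged store may only claim the smaller alignment, and an unspecified (zero) alignment falls back to the ABI type alignment. The standalone sketch below restates that decision; chooseAlignment and the sample values are illustrative only, not LLVM API.

#include <algorithm>
#include <cstdio>
#include <tuple>

static unsigned chooseAlignment(unsigned PAlign, unsigned QAlign,
                                unsigned ABITypeAlign) {
  unsigned MinAlign, MaxAlign;
  std::tie(MinAlign, MaxAlign) = std::minmax(PAlign, QAlign);
  if (MinAlign != 0)
    return MinAlign;                         // both known: take the smaller one
  if (MaxAlign != 0)
    return std::min(MaxAlign, ABITypeAlign); // one known: cap it by the ABI default
  return ABITypeAlign;                       // neither known: use the ABI default
}

int main() {
  std::printf("%u %u %u\n",
              chooseAlignment(8, 4, 4),   // -> 4
              chooseAlignment(0, 16, 8),  // -> 8
              chooseAlignment(0, 0, 8));  // -> 8
  return 0;
}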
- // BasicBlock *PTB = PBI->getSuccessor(0); BasicBlock *PFB = PBI->getSuccessor(1); BasicBlock *QTB = QBI->getSuccessor(0); @@ -3076,7 +3134,7 @@ static bool mergeConditionalStores(BranchInst *PBI, BranchInst *QBI) { bool Changed = false; for (auto *Address : CommonAddresses) Changed |= mergeConditionalStoreToAddress( - PTB, PFB, QTB, QFB, PostBB, Address, InvertPCond, InvertQCond); + PTB, PFB, QTB, QFB, PostBB, Address, InvertPCond, InvertQCond, DL); return Changed; } @@ -3141,7 +3199,7 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, // If both branches are conditional and both contain stores to the same // address, remove the stores from the conditionals and create a conditional // merged store at the end. - if (MergeCondStores && mergeConditionalStores(PBI, BI)) + if (MergeCondStores && mergeConditionalStores(PBI, BI, DL)) return true; // If this is a conditional branch in an empty block, and if any @@ -3270,9 +3328,7 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, // Halve the weights if any of them cannot fit in an uint32_t FitWeights(NewWeights); - PBI->setMetadata(LLVMContext::MD_prof, - MDBuilder(BI->getContext()) - .createBranchWeights(NewWeights[0], NewWeights[1])); + setBranchWeights(PBI, NewWeights[0], NewWeights[1]); } // OtherDest may have phi nodes. If so, add an entry from PBI's @@ -3310,9 +3366,7 @@ static bool SimplifyCondBranchToCondBranch(BranchInst *PBI, BranchInst *BI, FitWeights(NewWeights); - NV->setMetadata(LLVMContext::MD_prof, - MDBuilder(BI->getContext()) - .createBranchWeights(NewWeights[0], NewWeights[1])); + setBranchWeights(NV, NewWeights[0], NewWeights[1]); } } } @@ -3367,9 +3421,7 @@ static bool SimplifyTerminatorOnSelect(TerminatorInst *OldTerm, Value *Cond, // Create a conditional branch sharing the condition of the select. BranchInst *NewBI = Builder.CreateCondBr(Cond, TrueBB, FalseBB); if (TrueWeight != FalseWeight) - NewBI->setMetadata(LLVMContext::MD_prof, - MDBuilder(OldTerm->getContext()) - .createBranchWeights(TrueWeight, FalseWeight)); + setBranchWeights(NewBI, TrueWeight, FalseWeight); } } else if (KeepEdge1 && (KeepEdge2 || TrueBB == FalseBB)) { // Neither of the selected blocks were successors, so this @@ -3464,10 +3516,9 @@ static bool SimplifyIndirectBrOnSelect(IndirectBrInst *IBI, SelectInst *SI) { /// /// We prefer to split the edge to 'end' so that there is a true/false entry to /// the PHI, merging the third icmp into the switch. -static bool TryToSimplifyUncondBranchWithICmpInIt( +static bool tryToSimplifyUncondBranchWithICmpInIt( ICmpInst *ICI, IRBuilder<> &Builder, const DataLayout &DL, - const TargetTransformInfo &TTI, unsigned BonusInstThreshold, - AssumptionCache *AC) { + const TargetTransformInfo &TTI, const SimplifyCFGOptions &Options) { BasicBlock *BB = ICI->getParent(); // If the block has any PHIs in it or the icmp has multiple uses, it is too @@ -3502,7 +3553,7 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( ICI->eraseFromParent(); } // BB is now empty, so it is likely to simplify away. - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; } // Ok, the block is reachable from the default dest. If the constant we're @@ -3518,7 +3569,7 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( ICI->replaceAllUsesWith(V); ICI->eraseFromParent(); // BB is now empty, so it is likely to simplify away. 
- return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; } // The use of the icmp has to be in the 'end' block, by the only PHI node in @@ -3556,9 +3607,7 @@ static bool TryToSimplifyUncondBranchWithICmpInIt( Weights.push_back(Weights[0]); SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end()); - SI->setMetadata( - LLVMContext::MD_prof, - MDBuilder(SI->getContext()).createBranchWeights(MDWeights)); + setBranchWeights(SI, MDWeights); } } SI->addCase(Cst, NewBB); @@ -4285,10 +4334,7 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) { TrueWeight /= 2; FalseWeight /= 2; } - NewBI->setMetadata(LLVMContext::MD_prof, - MDBuilder(SI->getContext()) - .createBranchWeights((uint32_t)TrueWeight, - (uint32_t)FalseWeight)); + setBranchWeights(NewBI, TrueWeight, FalseWeight); } } @@ -4316,7 +4362,7 @@ static bool TurnSwitchRangeIntoICmp(SwitchInst *SI, IRBuilder<> &Builder) { /// Compute masked bits for the condition of a switch /// and use it to remove dead cases. -static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, +static bool eliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, const DataLayout &DL) { Value *Cond = SI->getCondition(); unsigned Bits = Cond->getType()->getIntegerBitWidth(); @@ -4385,9 +4431,7 @@ static bool EliminateDeadSwitchCases(SwitchInst *SI, AssumptionCache *AC, } if (HasWeight && Weights.size() >= 2) { SmallVector<uint32_t, 8> MDWeights(Weights.begin(), Weights.end()); - SI->setMetadata(LLVMContext::MD_prof, - MDBuilder(SI->getParent()->getContext()) - .createBranchWeights(MDWeights)); + setBranchWeights(SI, MDWeights); } return !DeadCases.empty(); @@ -4429,38 +4473,59 @@ static PHINode *FindPHIForConditionForwarding(ConstantInt *CaseValue, /// Try to forward the condition of a switch instruction to a phi node /// dominated by the switch, if that would mean that some of the destination -/// blocks of the switch can be folded away. -/// Returns true if a change is made. +/// blocks of the switch can be folded away. Return true if a change is made. static bool ForwardSwitchConditionToPHI(SwitchInst *SI) { - typedef DenseMap<PHINode *, SmallVector<int, 4>> ForwardingNodesMap; - ForwardingNodesMap ForwardingNodes; + using ForwardingNodesMap = DenseMap<PHINode *, SmallVector<int, 4>>; - for (auto Case : SI->cases()) { + ForwardingNodesMap ForwardingNodes; + BasicBlock *SwitchBlock = SI->getParent(); + bool Changed = false; + for (auto &Case : SI->cases()) { ConstantInt *CaseValue = Case.getCaseValue(); BasicBlock *CaseDest = Case.getCaseSuccessor(); - int PhiIndex; - PHINode *PHI = - FindPHIForConditionForwarding(CaseValue, CaseDest, &PhiIndex); - if (!PHI) - continue; + // Replace phi operands in successor blocks that are using the constant case + // value rather than the switch condition variable: + // switchbb: + // switch i32 %x, label %default [ + // i32 17, label %succ + // ... + // succ: + // %r = phi i32 ... [ 17, %switchbb ] ... + // --> + // %r = phi i32 ... [ %x, %switchbb ] ... + + for (Instruction &InstInCaseDest : *CaseDest) { + auto *Phi = dyn_cast<PHINode>(&InstInCaseDest); + if (!Phi) break; + + // This only works if there is exactly 1 incoming edge from the switch to + // a phi. If there is >1, that means multiple cases of the switch map to 1 + // value in the phi, and that phi value is not the switch condition. 
Thus, + // this transform would not make sense (the phi would be invalid because + // a phi can't have different incoming values from the same block). + int SwitchBBIdx = Phi->getBasicBlockIndex(SwitchBlock); + if (Phi->getIncomingValue(SwitchBBIdx) == CaseValue && + count(Phi->blocks(), SwitchBlock) == 1) { + Phi->setIncomingValue(SwitchBBIdx, SI->getCondition()); + Changed = true; + } + } - ForwardingNodes[PHI].push_back(PhiIndex); + // Collect phi nodes that are indirectly using this switch's case constants. + int PhiIdx; + if (auto *Phi = FindPHIForConditionForwarding(CaseValue, CaseDest, &PhiIdx)) + ForwardingNodes[Phi].push_back(PhiIdx); } - bool Changed = false; - - for (ForwardingNodesMap::iterator I = ForwardingNodes.begin(), - E = ForwardingNodes.end(); - I != E; ++I) { - PHINode *Phi = I->first; - SmallVectorImpl<int> &Indexes = I->second; - + for (auto &ForwardingNode : ForwardingNodes) { + PHINode *Phi = ForwardingNode.first; + SmallVectorImpl<int> &Indexes = ForwardingNode.second; if (Indexes.size() < 2) continue; - for (size_t I = 0, E = Indexes.size(); I != E; ++I) - Phi->setIncomingValue(Indexes[I], SI->getCondition()); + for (int Index : Indexes) + Phi->setIncomingValue(Index, SI->getCondition()); Changed = true; } @@ -4743,8 +4808,8 @@ static void RemoveSwitchAfterSelectConversion(SwitchInst *SI, PHINode *PHI, /// If the switch is only used to initialize one or more /// phi nodes in a common successor block with only two different /// constant values, replace the switch with select. -static bool SwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder, - AssumptionCache *AC, const DataLayout &DL, +static bool switchToSelect(SwitchInst *SI, IRBuilder<> &Builder, + const DataLayout &DL, const TargetTransformInfo &TTI) { Value *const Cond = SI->getCondition(); PHINode *PHI = nullptr; @@ -4816,18 +4881,18 @@ private: } Kind; // For SingleValueKind, this is the single value. - Constant *SingleValue; + Constant *SingleValue = nullptr; // For BitMapKind, this is the bitmap. - ConstantInt *BitMap; - IntegerType *BitMapElementTy; + ConstantInt *BitMap = nullptr; + IntegerType *BitMapElementTy = nullptr; // For LinearMapKind, these are the constants used to derive the value. - ConstantInt *LinearOffset; - ConstantInt *LinearMultiplier; + ConstantInt *LinearOffset = nullptr; + ConstantInt *LinearMultiplier = nullptr; // For ArrayKind, this is the array. 
- GlobalVariable *Array; + GlobalVariable *Array = nullptr; }; } // end anonymous namespace @@ -4835,9 +4900,7 @@ private: SwitchLookupTable::SwitchLookupTable( Module &M, uint64_t TableSize, ConstantInt *Offset, const SmallVectorImpl<std::pair<ConstantInt *, Constant *>> &Values, - Constant *DefaultValue, const DataLayout &DL, const StringRef &FuncName) - : SingleValue(nullptr), BitMap(nullptr), BitMapElementTy(nullptr), - LinearOffset(nullptr), LinearMultiplier(nullptr), Array(nullptr) { + Constant *DefaultValue, const DataLayout &DL, const StringRef &FuncName) { assert(Values.size() && "Can't build lookup table without values!"); assert(TableSize >= Values.size() && "Can't fit values in table!"); @@ -5083,7 +5146,6 @@ static void reuseTableCompare( User *PhiUser, BasicBlock *PhiBlock, BranchInst *RangeCheckBranch, Constant *DefaultValue, const SmallVectorImpl<std::pair<ConstantInt *, Constant *>> &Values) { - ICmpInst *CmpInst = dyn_cast<ICmpInst>(PhiUser); if (!CmpInst) return; @@ -5112,7 +5174,7 @@ static void reuseTableCompare( for (auto ValuePair : Values) { Constant *CaseConst = ConstantExpr::getICmp(CmpInst->getPredicate(), ValuePair.second, CmpOp1, true); - if (!CaseConst || CaseConst == DefaultConst) + if (!CaseConst || CaseConst == DefaultConst || isa<UndefValue>(CaseConst)) return; assert((CaseConst == TrueConst || CaseConst == FalseConst) && "Expect true or false as compare result."); @@ -5151,8 +5213,11 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, const TargetTransformInfo &TTI) { assert(SI->getNumCases() > 1 && "Degenerate switch?"); - // Only build lookup table when we have a target that supports it. - if (!TTI.shouldBuildLookupTables()) + Function *Fn = SI->getParent()->getParent(); + // Only build lookup table when we have a target that supports it or the + // attribute is not set. + if (!TTI.shouldBuildLookupTables() || + (Fn->getFnAttribute("no-jump-tables").getValueAsString() == "true")) return false; // FIXME: If the switch is too sparse for a lookup table, perhaps we could @@ -5163,8 +5228,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // string and lookup indices into that. // Ignore switches with less than three cases. Lookup tables will not make - // them - // faster, so we don't analyze them. + // them faster, so we don't analyze them. if (SI->getNumCases() < 3) return false; @@ -5176,8 +5240,10 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, ConstantInt *MaxCaseVal = CI->getCaseValue(); BasicBlock *CommonDest = nullptr; - typedef SmallVector<std::pair<ConstantInt *, Constant *>, 4> ResultListTy; + + using ResultListTy = SmallVector<std::pair<ConstantInt *, Constant *>, 4>; SmallDenseMap<PHINode *, ResultListTy> ResultLists; + SmallDenseMap<PHINode *, Constant *> DefaultResults; SmallDenseMap<PHINode *, Type *> ResultTypes; SmallVector<PHINode *, 4> PHIs; @@ -5190,7 +5256,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, MaxCaseVal = CaseVal; // Resulting value at phi nodes for this case value. - typedef SmallVector<std::pair<PHINode *, Constant *>, 4> ResultsTy; + using ResultsTy = SmallVector<std::pair<PHINode *, Constant *>, 4>; ResultsTy Results; if (!GetCaseResults(SI, CaseVal, CI->getCaseSuccessor(), &CommonDest, Results, DL, TTI)) @@ -5248,8 +5314,12 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // Compute the table index value. 
Builder.SetInsertPoint(SI); - Value *TableIndex = - Builder.CreateSub(SI->getCondition(), MinCaseVal, "switch.tableidx"); + Value *TableIndex; + if (MinCaseVal->isNullValue()) + TableIndex = SI->getCondition(); + else + TableIndex = Builder.CreateSub(SI->getCondition(), MinCaseVal, + "switch.tableidx"); // Compute the maximum table size representable by the integer type we are // switching upon. @@ -5282,15 +5352,14 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, Builder.SetInsertPoint(LookupBB); if (NeedMask) { - // Before doing the lookup we do the hole check. - // The LookupBB is therefore re-purposed to do the hole check - // and we create a new LookupBB. + // Before doing the lookup, we do the hole check. The LookupBB is therefore + // re-purposed to do the hole check, and we create a new LookupBB. BasicBlock *MaskBB = LookupBB; MaskBB->setName("switch.hole_check"); LookupBB = BasicBlock::Create(Mod.getContext(), "switch.lookup", CommonDest->getParent(), CommonDest); - // Make the mask's bitwidth at least 8bit and a power-of-2 to avoid + // Make the mask's bitwidth at least 8-bit and a power-of-2 to avoid // unnecessary illegal types. uint64_t TableSizePowOf2 = NextPowerOf2(std::max(7ULL, TableSize - 1ULL)); APInt MaskInt(TableSizePowOf2, 0); @@ -5320,7 +5389,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, } if (!DefaultIsReachable || GeneratingCoveredLookupTable) { - // We cached PHINodes in PHIs, to avoid accessing deleted PHINodes later, + // We cached PHINodes in PHIs. To avoid accessing deleted PHINodes later, // do not delete PHINodes here. SI->getDefaultDest()->removePredecessor(SI->getParent(), /*DontDeleteUselessPHIs=*/true); @@ -5333,7 +5402,7 @@ static bool SwitchToLookupTable(SwitchInst *SI, IRBuilder<> &Builder, // If using a bitmask, use any value to fill the lookup table holes. Constant *DV = NeedMask ? ResultLists[PHI][0].second : DefaultResults[PHI]; - StringRef FuncName = SI->getParent()->getParent()->getName(); + StringRef FuncName = Fn->getName(); SwitchLookupTable Table(Mod, TableSize, MinCaseVal, ResultList, DV, DL, FuncName); @@ -5391,14 +5460,14 @@ static bool isSwitchDense(ArrayRef<int64_t> Values) { return NumCases * 100 >= Range * MinDensity; } -// Try and transform a switch that has "holes" in it to a contiguous sequence -// of cases. -// -// A switch such as: switch(i) {case 5: case 9: case 13: case 17:} can be -// range-reduced to: switch ((i-5) / 4) {case 0: case 1: case 2: case 3:}. -// -// This converts a sparse switch into a dense switch which allows better -// lowering and could also allow transforming into a lookup table. +/// Try to transform a switch that has "holes" in it to a contiguous sequence +/// of cases. +/// +/// A switch such as: switch(i) {case 5: case 9: case 13: case 17:} can be +/// range-reduced to: switch ((i-5) / 4) {case 0: case 1: case 2: case 3:}. +/// +/// This converts a sparse switch into a dense switch which allows better +/// lowering and could also allow transforming into a lookup table. static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, const DataLayout &DL, const TargetTransformInfo &TTI) { @@ -5427,7 +5496,7 @@ static bool ReduceSwitchRange(SwitchInst *SI, IRBuilder<> &Builder, // First, transform the values such that they start at zero and ascend. int64_t Base = Values[0]; for (auto &V : Values) - V -= Base; + V -= (uint64_t)(Base); // Now we have signed numbers that have been shifted so that, given enough // precision, there are no negative values. 
Since the rest of the transform @@ -5492,12 +5561,12 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { // see if that predecessor totally determines the outcome of this switch. if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) if (SimplifyEqualityComparisonWithOnlyPredecessor(SI, OnlyPred, Builder)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; Value *Cond = SI->getCondition(); if (SelectInst *Select = dyn_cast<SelectInst>(Cond)) if (SimplifySwitchOnSelect(SI, Select)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; // If the block only contains the switch, see if we can fold the block // away into any preds. @@ -5507,33 +5576,34 @@ bool SimplifyCFGOpt::SimplifySwitch(SwitchInst *SI, IRBuilder<> &Builder) { ++BBI; if (SI == &*BBI) if (FoldValueComparisonIntoPredecessors(SI, Builder)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; } // Try to transform the switch into an icmp and a branch. if (TurnSwitchRangeIntoICmp(SI, Builder)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; // Remove unreachable cases. - if (EliminateDeadSwitchCases(SI, AC, DL)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + if (eliminateDeadSwitchCases(SI, Options.AC, DL)) + return simplifyCFG(BB, TTI, Options) | true; - if (SwitchToSelect(SI, Builder, AC, DL, TTI)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + if (switchToSelect(SI, Builder, DL, TTI)) + return simplifyCFG(BB, TTI, Options) | true; - if (ForwardSwitchConditionToPHI(SI)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + if (Options.ForwardSwitchCondToPhi && ForwardSwitchConditionToPHI(SI)) + return simplifyCFG(BB, TTI, Options) | true; - // The conversion from switch to lookup tables results in difficult - // to analyze code and makes pruning branches much harder. - // This is a problem of the switch expression itself can still be - // restricted as a result of inlining or CVP. There only apply this - // transformation during late steps of the optimisation chain. - if (LateSimplifyCFG && SwitchToLookupTable(SI, Builder, DL, TTI)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + // The conversion from switch to lookup tables results in difficult-to-analyze + // code and makes pruning branches much harder. This is a problem if the + // switch expression itself can still be restricted as a result of inlining or + // CVP. Therefore, only apply this transformation during late stages of the + // optimisation pipeline. 
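The check that follows reads Options.ConvertSwitchToLookupTable: the whole SimplifySwitch chain in this hunk has moved from the loose (BonusInstThreshold, AC, LateSimplifyCFG) parameter list to a single SimplifyCFGOptions object. A rough sketch of that shape, using only the field names visible in these hunks (the real struct is defined in an LLVM header and also carries the AssumptionCache; the defaults here are guesses):

struct SimplifyCFGOptionsSketch {
  unsigned BonusInstThreshold = 1;
  bool ForwardSwitchCondToPhi = false;
  bool ConvertSwitchToLookupTable = false;
  bool NeedCanonicalLoop = true;
  bool SinkCommonInsts = false;
};

// A late-pipeline caller enables the aggressive switch transforms; an early
// caller leaves them off so the switch stays easy to analyze, as the comment
// above explains.
SimplifyCFGOptionsSketch lateOptions() {
  SimplifyCFGOptionsSketch Opts;
  Opts.ConvertSwitchToLookupTable = true;
  Opts.ForwardSwitchCondToPhi = true;
  Opts.SinkCommonInsts = true;
  return Opts;
}

The gating code below consults exactly these fields.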
+ if (Options.ConvertSwitchToLookupTable && + SwitchToLookupTable(SI, Builder, DL, TTI)) + return simplifyCFG(BB, TTI, Options) | true; if (ReduceSwitchRange(SI, Builder, DL, TTI)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; return false; } @@ -5571,7 +5641,7 @@ bool SimplifyCFGOpt::SimplifyIndirectBr(IndirectBrInst *IBI) { if (SelectInst *SI = dyn_cast<SelectInst>(IBI->getAddress())) { if (SimplifyIndirectBrOnSelect(IBI, SI)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; } return Changed; } @@ -5613,8 +5683,8 @@ static bool TryToMergeLandingPad(LandingPadInst *LPad, BranchInst *BI, LandingPadInst *LPad2 = dyn_cast<LandingPadInst>(I); if (!LPad2 || !LPad2->isIdenticalTo(LPad)) continue; - for (++I; isa<DbgInfoIntrinsic>(I); ++I) { - } + for (++I; isa<DbgInfoIntrinsic>(I); ++I) + ; BranchInst *BI2 = dyn_cast<BranchInst>(I); if (!BI2 || !BI2->isIdenticalTo(BI)) continue; @@ -5658,39 +5728,38 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, BasicBlock *BB = BI->getParent(); BasicBlock *Succ = BI->getSuccessor(0); - if (SinkCommon && SinkThenElseCodeToEnd(BI)) + if (SinkCommon && Options.SinkCommonInsts && SinkThenElseCodeToEnd(BI)) return true; // If the Terminator is the only non-phi instruction, simplify the block. - // if LoopHeader is provided, check if the block or its successor is a loop - // header (This is for early invocations before loop simplify and + // If LoopHeader is provided, check if the block or its successor is a loop + // header. (This is for early invocations before loop simplify and // vectorization to keep canonical loop forms for nested loops. These blocks // can be eliminated when the pass is invoked later in the back-end.) bool NeedCanonicalLoop = - !LateSimplifyCFG && + Options.NeedCanonicalLoop && (LoopHeaders && (LoopHeaders->count(BB) || LoopHeaders->count(Succ))); BasicBlock::iterator I = BB->getFirstNonPHIOrDbg()->getIterator(); if (I->isTerminator() && BB != &BB->getParent()->getEntryBlock() && !NeedCanonicalLoop && TryToSimplifyUncondBranchFromEmptyBlock(BB)) return true; - // If the only instruction in the block is a seteq/setne comparison - // against a constant, try to simplify the block. + // If the only instruction in the block is a seteq/setne comparison against a + // constant, try to simplify the block. if (ICmpInst *ICI = dyn_cast<ICmpInst>(I)) if (ICI->isEquality() && isa<ConstantInt>(ICI->getOperand(1))) { for (++I; isa<DbgInfoIntrinsic>(I); ++I) ; if (I->isTerminator() && - TryToSimplifyUncondBranchWithICmpInIt(ICI, Builder, DL, TTI, - BonusInstThreshold, AC)) + tryToSimplifyUncondBranchWithICmpInIt(ICI, Builder, DL, TTI, Options)) return true; } // See if we can merge an empty landing pad block with another which is // equivalent. if (LandingPadInst *LPad = dyn_cast<LandingPadInst>(I)) { - for (++I; isa<DbgInfoIntrinsic>(I); ++I) { - } + for (++I; isa<DbgInfoIntrinsic>(I); ++I) + ; if (I->isTerminator() && TryToMergeLandingPad(LPad, BI, BB)) return true; } @@ -5699,8 +5768,8 @@ bool SimplifyCFGOpt::SimplifyUncondBranch(BranchInst *BI, // branches to us and our successor, fold the comparison into the // predecessor and use logical operations to update the incoming value // for PHI nodes in common successor. 
- if (FoldBranchToCommonDest(BI, BonusInstThreshold)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + if (FoldBranchToCommonDest(BI, Options.BonusInstThreshold)) + return simplifyCFG(BB, TTI, Options) | true; return false; } @@ -5725,7 +5794,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // switch. if (BasicBlock *OnlyPred = BB->getSinglePredecessor()) if (SimplifyEqualityComparisonWithOnlyPredecessor(BI, OnlyPred, Builder)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; // This block must be empty, except for the setcond inst, if it exists. // Ignore dbg intrinsics. @@ -5735,14 +5804,14 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { ++I; if (&*I == BI) { if (FoldValueComparisonIntoPredecessors(BI, Builder)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; } else if (&*I == cast<Instruction>(BI->getCondition())) { ++I; // Ignore dbg intrinsics. while (isa<DbgInfoIntrinsic>(I)) ++I; if (&*I == BI && FoldValueComparisonIntoPredecessors(BI, Builder)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; } } @@ -5758,9 +5827,9 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (PBI && PBI->isConditional() && PBI->getSuccessor(0) != PBI->getSuccessor(1)) { assert(PBI->getSuccessor(0) == BB || PBI->getSuccessor(1) == BB); - bool CondIsFalse = PBI->getSuccessor(1) == BB; + bool CondIsTrue = PBI->getSuccessor(0) == BB; Optional<bool> Implication = isImpliedCondition( - PBI->getCondition(), BI->getCondition(), DL, CondIsFalse); + PBI->getCondition(), BI->getCondition(), DL, CondIsTrue); if (Implication) { // Turn this into a branch on constant. auto *OldCond = BI->getCondition(); @@ -5769,7 +5838,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { : ConstantInt::getFalse(BB->getContext()); BI->setCondition(CI); RecursivelyDeleteTriviallyDeadInstructions(OldCond); - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; } } } @@ -5777,8 +5846,8 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { // If this basic block is ONLY a compare and a branch, and if a predecessor // branches to us and one of our successors, fold the comparison into the // predecessor and use logical operations to pick the right destination. - if (FoldBranchToCommonDest(BI, BonusInstThreshold)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + if (FoldBranchToCommonDest(BI, Options.BonusInstThreshold)) + return simplifyCFG(BB, TTI, Options) | true; // We have a conditional branch to two blocks that are only reachable // from BI. We know that the condbr dominates the two blocks, so see if @@ -5787,7 +5856,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (BI->getSuccessor(0)->getSinglePredecessor()) { if (BI->getSuccessor(1)->getSinglePredecessor()) { if (HoistThenElseCodeToIf(BI, TTI)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; } else { // If Successor #1 has multiple preds, we may be able to conditionally // execute Successor #0 if it branches to Successor #1. 
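The CondIsTrue change in this hunk hands isImpliedCondition the actual polarity of the predecessor's branch; when the outer condition implies the inner one, the inner branch is rewritten to branch on a constant and the dead edge is cleaned up. A purely illustrative source-level case (g() is a made-up callee):

void g();
void implied(int x) {
  if (x > 10) {   // predecessor conditional branch (PBI)
    if (x > 5)    // BI: implied true whenever this point is reached
      g();        // the inner branch becomes a branch on 'true'
  }
}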
@@ -5795,7 +5864,7 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (Succ0TI->getNumSuccessors() == 1 && Succ0TI->getSuccessor(0) == BI->getSuccessor(1)) if (SpeculativelyExecuteBB(BI, BI->getSuccessor(0), TTI)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; } } else if (BI->getSuccessor(1)->getSinglePredecessor()) { // If Successor #0 has multiple preds, we may be able to conditionally @@ -5804,30 +5873,30 @@ bool SimplifyCFGOpt::SimplifyCondBranch(BranchInst *BI, IRBuilder<> &Builder) { if (Succ1TI->getNumSuccessors() == 1 && Succ1TI->getSuccessor(0) == BI->getSuccessor(0)) if (SpeculativelyExecuteBB(BI, BI->getSuccessor(1), TTI)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; } // If this is a branch on a phi node in the current block, thread control // through this block if any PHI node entries are constants. if (PHINode *PN = dyn_cast<PHINode>(BI->getCondition())) if (PN->getParent() == BI->getParent()) - if (FoldCondBranchOnPHI(BI, DL, AC)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + if (FoldCondBranchOnPHI(BI, DL, Options.AC)) + return simplifyCFG(BB, TTI, Options) | true; // Scan predecessor blocks for conditional branches. for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) if (BranchInst *PBI = dyn_cast<BranchInst>((*PI)->getTerminator())) if (PBI != BI && PBI->isConditional()) if (SimplifyCondBranchToCondBranch(PBI, BI, DL)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + return simplifyCFG(BB, TTI, Options) | true; // Look for diamond patterns. if (MergeCondStores) if (BasicBlock *PrevBB = allPredecessorsComeFromSameSource(BB)) if (BranchInst *PBI = dyn_cast<BranchInst>(PrevBB->getTerminator())) if (PBI != BI && PBI->isConditional()) - if (mergeConditionalStores(PBI, BI)) - return SimplifyCFG(BB, TTI, BonusInstThreshold, AC) | true; + if (mergeConditionalStores(PBI, BI, DL)) + return simplifyCFG(BB, TTI, Options) | true; return false; } @@ -5936,7 +6005,6 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { // Merge basic blocks into their predecessor if there is only one distinct // pred, and if there is only one distinct successor of the predecessor, and // if there are no PHI nodes. - // if (MergeBlockIntoPredecessor(BB)) return true; @@ -5944,12 +6012,12 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { // If there is a trivial two-entry PHI node in this basic block, and we can // eliminate it, do so now. 
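At the source level, the two-entry PHI elimination invoked just below (FoldTwoEntryPHINode) is the familiar if-to-select fold; a minimal illustrative example:

int pickMax(int a, int b) {
  int r;
  if (a > b)
    r = a;
  else
    r = b;
  return r;   // the two-entry phi for r collapses to (a > b) ? a : b
}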
- if (PHINode *PN = dyn_cast<PHINode>(BB->begin())) + if (auto *PN = dyn_cast<PHINode>(BB->begin())) if (PN->getNumIncomingValues() == 2) Changed |= FoldTwoEntryPHINode(PN, TTI, DL); Builder.SetInsertPoint(BB->getTerminator()); - if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) { + if (auto *BI = dyn_cast<BranchInst>(BB->getTerminator())) { if (BI->isUnconditional()) { if (SimplifyUncondBranch(BI, Builder)) return true; @@ -5957,25 +6025,22 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { if (SimplifyCondBranch(BI, Builder)) return true; } - } else if (ReturnInst *RI = dyn_cast<ReturnInst>(BB->getTerminator())) { + } else if (auto *RI = dyn_cast<ReturnInst>(BB->getTerminator())) { if (SimplifyReturn(RI, Builder)) return true; - } else if (ResumeInst *RI = dyn_cast<ResumeInst>(BB->getTerminator())) { + } else if (auto *RI = dyn_cast<ResumeInst>(BB->getTerminator())) { if (SimplifyResume(RI, Builder)) return true; - } else if (CleanupReturnInst *RI = - dyn_cast<CleanupReturnInst>(BB->getTerminator())) { + } else if (auto *RI = dyn_cast<CleanupReturnInst>(BB->getTerminator())) { if (SimplifyCleanupReturn(RI)) return true; - } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) { + } else if (auto *SI = dyn_cast<SwitchInst>(BB->getTerminator())) { if (SimplifySwitch(SI, Builder)) return true; - } else if (UnreachableInst *UI = - dyn_cast<UnreachableInst>(BB->getTerminator())) { + } else if (auto *UI = dyn_cast<UnreachableInst>(BB->getTerminator())) { if (SimplifyUnreachable(UI)) return true; - } else if (IndirectBrInst *IBI = - dyn_cast<IndirectBrInst>(BB->getTerminator())) { + } else if (auto *IBI = dyn_cast<IndirectBrInst>(BB->getTerminator())) { if (SimplifyIndirectBr(IBI)) return true; } @@ -5983,16 +6048,10 @@ bool SimplifyCFGOpt::run(BasicBlock *BB) { return Changed; } -/// This function is used to do simplification of a CFG. -/// For example, it adjusts branches to branches to eliminate the extra hop, -/// eliminates unreachable basic blocks, and does other "peephole" optimization -/// of the CFG. It returns true if a modification was made. 
-/// -bool llvm::SimplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, - unsigned BonusInstThreshold, AssumptionCache *AC, - SmallPtrSetImpl<BasicBlock *> *LoopHeaders, - bool LateSimplifyCFG) { - return SimplifyCFGOpt(TTI, BB->getModule()->getDataLayout(), - BonusInstThreshold, AC, LoopHeaders, LateSimplifyCFG) +bool llvm::simplifyCFG(BasicBlock *BB, const TargetTransformInfo &TTI, + const SimplifyCFGOptions &Options, + SmallPtrSetImpl<BasicBlock *> *LoopHeaders) { + return SimplifyCFGOpt(TTI, BB->getModule()->getDataLayout(), LoopHeaders, + Options) .run(BB); } diff --git a/lib/Transforms/Utils/SimplifyIndVar.cpp b/lib/Transforms/Utils/SimplifyIndVar.cpp index 6d90e6b48358..ad1faea0a7ae 100644 --- a/lib/Transforms/Utils/SimplifyIndVar.cpp +++ b/lib/Transforms/Utils/SimplifyIndVar.cpp @@ -18,13 +18,11 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/LoopInfo.h" -#include "llvm/Analysis/LoopPass.h" -#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/Instructions.h" -#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/PatternMatch.h" #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" @@ -35,10 +33,14 @@ using namespace llvm; STATISTIC(NumElimIdentity, "Number of IV identities eliminated"); STATISTIC(NumElimOperand, "Number of IV operands folded into a use"); +STATISTIC(NumFoldedUser, "Number of IV users folded into a constant"); STATISTIC(NumElimRem , "Number of IV remainder operations eliminated"); STATISTIC( NumSimplifiedSDiv, "Number of IV signed division operations converted to unsigned division"); +STATISTIC( + NumSimplifiedSRem, + "Number of IV signed remainder operations converted to unsigned remainder"); STATISTIC(NumElimCmp , "Number of IV comparisons eliminated"); namespace { @@ -51,15 +53,17 @@ namespace { LoopInfo *LI; ScalarEvolution *SE; DominatorTree *DT; - + SCEVExpander &Rewriter; SmallVectorImpl<WeakTrackingVH> &DeadInsts; bool Changed; public: SimplifyIndvar(Loop *Loop, ScalarEvolution *SE, DominatorTree *DT, - LoopInfo *LI, SmallVectorImpl<WeakTrackingVH> &Dead) - : L(Loop), LI(LI), SE(SE), DT(DT), DeadInsts(Dead), Changed(false) { + LoopInfo *LI, SCEVExpander &Rewriter, + SmallVectorImpl<WeakTrackingVH> &Dead) + : L(Loop), LI(LI), SE(SE), DT(DT), Rewriter(Rewriter), DeadInsts(Dead), + Changed(false) { assert(LI && "IV simplification requires LoopInfo"); } @@ -73,12 +77,17 @@ namespace { Value *foldIVUser(Instruction *UseInst, Instruction *IVOperand); bool eliminateIdentitySCEV(Instruction *UseInst, Instruction *IVOperand); + bool replaceIVUserWithLoopInvariant(Instruction *UseInst); bool eliminateOverflowIntrinsic(CallInst *CI); bool eliminateIVUser(Instruction *UseInst, Instruction *IVOperand); + bool makeIVComparisonInvariant(ICmpInst *ICmp, Value *IVOperand); void eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand); - void eliminateIVRemainder(BinaryOperator *Rem, Value *IVOperand, - bool IsSigned); + void simplifyIVRemainder(BinaryOperator *Rem, Value *IVOperand, + bool IsSigned); + void replaceRemWithNumerator(BinaryOperator *Rem); + void replaceRemWithNumeratorOrZero(BinaryOperator *Rem); + void replaceSRemWithURem(BinaryOperator *Rem); bool eliminateSDiv(BinaryOperator *SDiv); bool strengthenOverflowingOperation(BinaryOperator *OBO, Value *IVOperand); bool strengthenRightShift(BinaryOperator *BO, Value *IVOperand); @@ -151,6 
+160,74 @@ Value *SimplifyIndvar::foldIVUser(Instruction *UseInst, Instruction *IVOperand) return IVSrc; } +bool SimplifyIndvar::makeIVComparisonInvariant(ICmpInst *ICmp, + Value *IVOperand) { + unsigned IVOperIdx = 0; + ICmpInst::Predicate Pred = ICmp->getPredicate(); + if (IVOperand != ICmp->getOperand(0)) { + // Swapped + assert(IVOperand == ICmp->getOperand(1) && "Can't find IVOperand"); + IVOperIdx = 1; + Pred = ICmpInst::getSwappedPredicate(Pred); + } + + // Get the SCEVs for the ICmp operands (in the specific context of the + // current loop) + const Loop *ICmpLoop = LI->getLoopFor(ICmp->getParent()); + const SCEV *S = SE->getSCEVAtScope(ICmp->getOperand(IVOperIdx), ICmpLoop); + const SCEV *X = SE->getSCEVAtScope(ICmp->getOperand(1 - IVOperIdx), ICmpLoop); + + ICmpInst::Predicate InvariantPredicate; + const SCEV *InvariantLHS, *InvariantRHS; + + auto *PN = dyn_cast<PHINode>(IVOperand); + if (!PN) + return false; + if (!SE->isLoopInvariantPredicate(Pred, S, X, L, InvariantPredicate, + InvariantLHS, InvariantRHS)) + return false; + + // Rewrite the comparison to a loop invariant comparison if it can be done + // cheaply, where cheaply means "we don't need to emit any new + // instructions". + + SmallDenseMap<const SCEV*, Value*> CheapExpansions; + CheapExpansions[S] = ICmp->getOperand(IVOperIdx); + CheapExpansions[X] = ICmp->getOperand(1 - IVOperIdx); + + // TODO: Support multiple entry loops? (We currently bail out of these in + // the IndVarSimplify pass) + if (auto *BB = L->getLoopPredecessor()) { + const int Idx = PN->getBasicBlockIndex(BB); + if (Idx >= 0) { + Value *Incoming = PN->getIncomingValue(Idx); + const SCEV *IncomingS = SE->getSCEV(Incoming); + CheapExpansions[IncomingS] = Incoming; + } + } + Value *NewLHS = CheapExpansions[InvariantLHS]; + Value *NewRHS = CheapExpansions[InvariantRHS]; + + if (!NewLHS) + if (auto *ConstLHS = dyn_cast<SCEVConstant>(InvariantLHS)) + NewLHS = ConstLHS->getValue(); + if (!NewRHS) + if (auto *ConstRHS = dyn_cast<SCEVConstant>(InvariantRHS)) + NewRHS = ConstRHS->getValue(); + + if (!NewLHS || !NewRHS) + // We could not find an existing value to replace either LHS or RHS. + // Generating new instructions has subtler tradeoffs, so avoid doing that + // for now. + return false; + + DEBUG(dbgs() << "INDVARS: Simplified comparison: " << *ICmp << '\n'); + ICmp->setPredicate(InvariantPredicate); + ICmp->setOperand(0, NewLHS); + ICmp->setOperand(1, NewRHS); + return true; +} + /// SimplifyIVUsers helper for eliminating useless /// comparisons against an induction variable. void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) { @@ -164,17 +241,11 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) { Pred = ICmpInst::getSwappedPredicate(Pred); } - // Get the SCEVs for the ICmp operands. - const SCEV *S = SE->getSCEV(ICmp->getOperand(IVOperIdx)); - const SCEV *X = SE->getSCEV(ICmp->getOperand(1 - IVOperIdx)); - - // Simplify unnecessary loops away. + // Get the SCEVs for the ICmp operands (in the specific context of the + // current loop) const Loop *ICmpLoop = LI->getLoopFor(ICmp->getParent()); - S = SE->getSCEVAtScope(S, ICmpLoop); - X = SE->getSCEVAtScope(X, ICmpLoop); - - ICmpInst::Predicate InvariantPredicate; - const SCEV *InvariantLHS, *InvariantRHS; + const SCEV *S = SE->getSCEVAtScope(ICmp->getOperand(IVOperIdx), ICmpLoop); + const SCEV *X = SE->getSCEVAtScope(ICmp->getOperand(1 - IVOperIdx), ICmpLoop); // If the condition is always true or always false, replace it with // a constant value. 
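The simplest comparisons are the ones SCEV can decide outright, as the comment above says; the new makeIVComparisonInvariant helper only kicks in afterwards, and even then reuses existing values (the phi's preheader incoming value or a SCEV constant) rather than emitting new instructions. An illustrative always-true case:

void clear(int *a, int n) {
  for (int i = 0; i < n; ++i)
    if (i < n)     // redundant: provably true on every iteration
      a[i] = 0;    // the compare folds to 'true' and is queued for deletion
}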
@@ -186,85 +257,8 @@ void SimplifyIndvar::eliminateIVComparison(ICmpInst *ICmp, Value *IVOperand) { ICmp->replaceAllUsesWith(ConstantInt::getFalse(ICmp->getContext())); DeadInsts.emplace_back(ICmp); DEBUG(dbgs() << "INDVARS: Eliminated comparison: " << *ICmp << '\n'); - } else if (isa<PHINode>(IVOperand) && - SE->isLoopInvariantPredicate(Pred, S, X, L, InvariantPredicate, - InvariantLHS, InvariantRHS)) { - - // Rewrite the comparison to a loop invariant comparison if it can be done - // cheaply, where cheaply means "we don't need to emit any new - // instructions". - - Value *NewLHS = nullptr, *NewRHS = nullptr; - - if (S == InvariantLHS || X == InvariantLHS) - NewLHS = - ICmp->getOperand(S == InvariantLHS ? IVOperIdx : (1 - IVOperIdx)); - - if (S == InvariantRHS || X == InvariantRHS) - NewRHS = - ICmp->getOperand(S == InvariantRHS ? IVOperIdx : (1 - IVOperIdx)); - - auto *PN = cast<PHINode>(IVOperand); - for (unsigned i = 0, e = PN->getNumIncomingValues(); - i != e && (!NewLHS || !NewRHS); - ++i) { - - // If this is a value incoming from the backedge, then it cannot be a loop - // invariant value (since we know that IVOperand is an induction variable). - if (L->contains(PN->getIncomingBlock(i))) - continue; - - // NB! This following assert does not fundamentally have to be true, but - // it is true today given how SCEV analyzes induction variables. - // Specifically, today SCEV will *not* recognize %iv as an induction - // variable in the following case: - // - // define void @f(i32 %k) { - // entry: - // br i1 undef, label %r, label %l - // - // l: - // %k.inc.l = add i32 %k, 1 - // br label %loop - // - // r: - // %k.inc.r = add i32 %k, 1 - // br label %loop - // - // loop: - // %iv = phi i32 [ %k.inc.l, %l ], [ %k.inc.r, %r ], [ %iv.inc, %loop ] - // %iv.inc = add i32 %iv, 1 - // br label %loop - // } - // - // but if it starts to, at some point, then the assertion below will have - // to be changed to a runtime check. - - Value *Incoming = PN->getIncomingValue(i); - -#ifndef NDEBUG - if (auto *I = dyn_cast<Instruction>(Incoming)) - assert(DT->dominates(I, ICmp) && "Should be a unique loop dominating value!"); -#endif - - const SCEV *IncomingS = SE->getSCEV(Incoming); - - if (!NewLHS && IncomingS == InvariantLHS) - NewLHS = Incoming; - if (!NewRHS && IncomingS == InvariantRHS) - NewRHS = Incoming; - } - - if (!NewLHS || !NewRHS) - // We could not find an existing value to replace either LHS or RHS. - // Generating new instructions has subtler tradeoffs, so avoid doing that - // for now. - return; - - DEBUG(dbgs() << "INDVARS: Simplified comparison: " << *ICmp << '\n'); - ICmp->setPredicate(InvariantPredicate); - ICmp->setOperand(0, NewLHS); - ICmp->setOperand(1, NewRHS); + } else if (makeIVComparisonInvariant(ICmp, IVOperand)) { + // fallthrough to end of function } else if (ICmpInst::isSigned(OriginalPred) && SE->isKnownNonNegative(S) && SE->isKnownNonNegative(X)) { // If we were unable to make anything above, all we can is to canonicalize @@ -309,54 +303,90 @@ bool SimplifyIndvar::eliminateSDiv(BinaryOperator *SDiv) { return false; } -/// SimplifyIVUsers helper for eliminating useless -/// remainder operations operating on an induction variable. 
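The hunk continuing below retires eliminateIVRemainder in favour of three focused helpers (replaceSRemWithURem, replaceRemWithNumerator, replaceRemWithNumeratorOrZero). Their source-level effect, in an illustrative loop:

void wrap(int *a, int *b, int n) {
  for (int i = 0; i < n; ++i) {
    a[i % n] = 1;        // i is in [0,n): the rem folds to plain i
    b[(i + 1) % n] = 1;  // (i+1) is in [0,n]: becomes (i+1)==n ? 0 : i+1
  }
}
// When neither pattern applies but both operands are known non-negative, a
// signed rem is still rewritten to the cheaper unsigned rem.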
-void SimplifyIndvar::eliminateIVRemainder(BinaryOperator *Rem, - Value *IVOperand, - bool IsSigned) { +// i %s n -> i %u n if i >= 0 and n >= 0 +void SimplifyIndvar::replaceSRemWithURem(BinaryOperator *Rem) { + auto *N = Rem->getOperand(0), *D = Rem->getOperand(1); + auto *URem = BinaryOperator::Create(BinaryOperator::URem, N, D, + Rem->getName() + ".urem", Rem); + Rem->replaceAllUsesWith(URem); + DEBUG(dbgs() << "INDVARS: Simplified srem: " << *Rem << '\n'); + ++NumSimplifiedSRem; + Changed = true; + DeadInsts.emplace_back(Rem); +} + +// i % n --> i if i is in [0,n). +void SimplifyIndvar::replaceRemWithNumerator(BinaryOperator *Rem) { + Rem->replaceAllUsesWith(Rem->getOperand(0)); + DEBUG(dbgs() << "INDVARS: Simplified rem: " << *Rem << '\n'); + ++NumElimRem; + Changed = true; + DeadInsts.emplace_back(Rem); +} + +// (i+1) % n --> (i+1)==n?0:(i+1) if i is in [0,n). +void SimplifyIndvar::replaceRemWithNumeratorOrZero(BinaryOperator *Rem) { + auto *T = Rem->getType(); + auto *N = Rem->getOperand(0), *D = Rem->getOperand(1); + ICmpInst *ICmp = new ICmpInst(Rem, ICmpInst::ICMP_EQ, N, D); + SelectInst *Sel = + SelectInst::Create(ICmp, ConstantInt::get(T, 0), N, "iv.rem", Rem); + Rem->replaceAllUsesWith(Sel); + DEBUG(dbgs() << "INDVARS: Simplified rem: " << *Rem << '\n'); + ++NumElimRem; + Changed = true; + DeadInsts.emplace_back(Rem); +} + +/// SimplifyIVUsers helper for eliminating useless remainder operations +/// operating on an induction variable or replacing srem by urem. +void SimplifyIndvar::simplifyIVRemainder(BinaryOperator *Rem, Value *IVOperand, + bool IsSigned) { + auto *NValue = Rem->getOperand(0); + auto *DValue = Rem->getOperand(1); // We're only interested in the case where we know something about - // the numerator. - if (IVOperand != Rem->getOperand(0)) + // the numerator, unless it is a srem, because we want to replace srem by urem + // in general. + bool UsedAsNumerator = IVOperand == NValue; + if (!UsedAsNumerator && !IsSigned) return; - // Get the SCEVs for the ICmp operands. - const SCEV *S = SE->getSCEV(Rem->getOperand(0)); - const SCEV *X = SE->getSCEV(Rem->getOperand(1)); + const SCEV *N = SE->getSCEV(NValue); // Simplify unnecessary loops away. const Loop *ICmpLoop = LI->getLoopFor(Rem->getParent()); - S = SE->getSCEVAtScope(S, ICmpLoop); - X = SE->getSCEVAtScope(X, ICmpLoop); - - // i % n --> i if i is in [0,n). - if ((!IsSigned || SE->isKnownNonNegative(S)) && - SE->isKnownPredicate(IsSigned ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, - S, X)) - Rem->replaceAllUsesWith(Rem->getOperand(0)); - else { - // (i+1) % n --> (i+1)==n?0:(i+1) if i is in [0,n). - const SCEV *LessOne = SE->getMinusSCEV(S, SE->getOne(S->getType())); - if (IsSigned && !SE->isKnownNonNegative(LessOne)) - return; + N = SE->getSCEVAtScope(N, ICmpLoop); + + bool IsNumeratorNonNegative = !IsSigned || SE->isKnownNonNegative(N); + + // Do not proceed if the Numerator may be negative + if (!IsNumeratorNonNegative) + return; - if (!SE->isKnownPredicate(IsSigned ? - ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, - LessOne, X)) + const SCEV *D = SE->getSCEV(DValue); + D = SE->getSCEVAtScope(D, ICmpLoop); + + if (UsedAsNumerator) { + auto LT = IsSigned ? 
ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; + if (SE->isKnownPredicate(LT, N, D)) { + replaceRemWithNumerator(Rem); return; + } - ICmpInst *ICmp = new ICmpInst(Rem, ICmpInst::ICMP_EQ, - Rem->getOperand(0), Rem->getOperand(1)); - SelectInst *Sel = - SelectInst::Create(ICmp, - ConstantInt::get(Rem->getType(), 0), - Rem->getOperand(0), "tmp", Rem); - Rem->replaceAllUsesWith(Sel); + auto *T = Rem->getType(); + const auto *NLessOne = SE->getMinusSCEV(N, SE->getOne(T)); + if (SE->isKnownPredicate(LT, NLessOne, D)) { + replaceRemWithNumeratorOrZero(Rem); + return; + } } - DEBUG(dbgs() << "INDVARS: Simplified rem: " << *Rem << '\n'); - ++NumElimRem; - Changed = true; - DeadInsts.emplace_back(Rem); + // Try to replace SRem with URem, if both N and D are known non-negative. + // Since we had already check N, we only need to check D now + if (!IsSigned || !SE->isKnownNonNegative(D)) + return; + + replaceSRemWithURem(Rem); } bool SimplifyIndvar::eliminateOverflowIntrinsic(CallInst *CI) { @@ -474,7 +504,7 @@ bool SimplifyIndvar::eliminateIVUser(Instruction *UseInst, if (BinaryOperator *Bin = dyn_cast<BinaryOperator>(UseInst)) { bool IsSRem = Bin->getOpcode() == Instruction::SRem; if (IsSRem || Bin->getOpcode() == Instruction::URem) { - eliminateIVRemainder(Bin, IVOperand, IsSRem); + simplifyIVRemainder(Bin, IVOperand, IsSRem); return true; } @@ -492,6 +522,40 @@ bool SimplifyIndvar::eliminateIVUser(Instruction *UseInst, return false; } +static Instruction *GetLoopInvariantInsertPosition(Loop *L, Instruction *Hint) { + if (auto *BB = L->getLoopPreheader()) + return BB->getTerminator(); + + return Hint; +} + +/// Replace the UseInst with a constant if possible. +bool SimplifyIndvar::replaceIVUserWithLoopInvariant(Instruction *I) { + if (!SE->isSCEVable(I->getType())) + return false; + + // Get the symbolic expression for this instruction. + const SCEV *S = SE->getSCEV(I); + + if (!SE->isLoopInvariant(S, L)) + return false; + + // Do not generate something ridiculous even if S is loop invariant. + if (Rewriter.isHighCostExpansion(S, L, I)) + return false; + + auto *IP = GetLoopInvariantInsertPosition(L, I); + auto *Invariant = Rewriter.expandCodeFor(S, I->getType(), IP); + + I->replaceAllUsesWith(Invariant); + DEBUG(dbgs() << "INDVARS: Replace IV user: " << *I + << " with loop invariant: " << *S << '\n'); + ++NumFoldedUser; + Changed = true; + DeadInsts.emplace_back(I); + return true; +} + /// Eliminate any operation that SCEV can prove is an identity function. bool SimplifyIndvar::eliminateIdentitySCEV(Instruction *UseInst, Instruction *IVOperand) { @@ -627,7 +691,7 @@ bool SimplifyIndvar::strengthenRightShift(BinaryOperator *BO, /// Add all uses of Def to the current IV's worklist. static void pushIVUsers( - Instruction *Def, + Instruction *Def, Loop *L, SmallPtrSet<Instruction*,16> &Simplified, SmallVectorImpl< std::pair<Instruction*,Instruction*> > &SimpleIVUsers) { @@ -638,8 +702,19 @@ static void pushIVUsers( // Also ensure unique worklist users. // If Def is a LoopPhi, it may not be in the Simplified set, so check for // self edges first. - if (UI != Def && Simplified.insert(UI).second) - SimpleIVUsers.push_back(std::make_pair(UI, Def)); + if (UI == Def) + continue; + + // Only change the current Loop, do not change the other parts (e.g. other + // Loops). + if (!L->contains(UI)) + continue; + + // Do not push the same instruction more than once. 
+ if (!Simplified.insert(UI).second) + continue; + + SimpleIVUsers.push_back(std::make_pair(UI, Def)); } } @@ -689,7 +764,7 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { // Push users of the current LoopPhi. In rare cases, pushIVUsers may be // called multiple times for the same LoopPhi. This is the proper thing to // do for loop header phis that use each other. - pushIVUsers(CurrIV, Simplified, SimpleIVUsers); + pushIVUsers(CurrIV, L, Simplified, SimpleIVUsers); while (!SimpleIVUsers.empty()) { std::pair<Instruction*, Instruction*> UseOper = @@ -699,6 +774,11 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { // Bypass back edges to avoid extra work. if (UseInst == CurrIV) continue; + // Try to replace UseInst with a loop invariant before any other + // simplifications. + if (replaceIVUserWithLoopInvariant(UseInst)) + continue; + Instruction *IVOperand = UseOper.second; for (unsigned N = 0; IVOperand; ++N) { assert(N <= Simplified.size() && "runaway iteration"); @@ -712,7 +792,7 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { continue; if (eliminateIVUser(UseOper.first, IVOperand)) { - pushIVUsers(IVOperand, Simplified, SimpleIVUsers); + pushIVUsers(IVOperand, L, Simplified, SimpleIVUsers); continue; } @@ -722,7 +802,7 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { (isa<ShlOperator>(BO) && strengthenRightShift(BO, IVOperand))) { // re-queue uses of the now modified binary operator and fall // through to the checks that remain. - pushIVUsers(IVOperand, Simplified, SimpleIVUsers); + pushIVUsers(IVOperand, L, Simplified, SimpleIVUsers); } } @@ -732,7 +812,7 @@ void SimplifyIndvar::simplifyUsers(PHINode *CurrIV, IVVisitor *V) { continue; } if (isSimpleIVUser(UseOper.first, L, SE)) { - pushIVUsers(UseOper.first, Simplified, SimpleIVUsers); + pushIVUsers(UseOper.first, L, Simplified, SimpleIVUsers); } } } @@ -745,8 +825,9 @@ void IVVisitor::anchor() { } /// by using ScalarEvolution to analyze the IV's recurrence. bool simplifyUsersOfIV(PHINode *CurrIV, ScalarEvolution *SE, DominatorTree *DT, LoopInfo *LI, SmallVectorImpl<WeakTrackingVH> &Dead, - IVVisitor *V) { - SimplifyIndvar SIV(LI->getLoopFor(CurrIV->getParent()), SE, DT, LI, Dead); + SCEVExpander &Rewriter, IVVisitor *V) { + SimplifyIndvar SIV(LI->getLoopFor(CurrIV->getParent()), SE, DT, LI, Rewriter, + Dead); SIV.simplifyUsers(CurrIV, V); return SIV.hasChanged(); } @@ -755,9 +836,13 @@ bool simplifyUsersOfIV(PHINode *CurrIV, ScalarEvolution *SE, DominatorTree *DT, /// loop. This does not actually change or add IVs. 
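simplifyLoopIVs, shown next, now constructs the SCEVExpander that replaceIVUserWithLoopInvariant uses to materialize values outside the loop. A source-level case this enables (illustrative):

void fill(int *a, int n, int k) {
  for (int i = 0, j = k; i < n; ++i, ++j)
    a[i] = j - i;   // SCEV: {k,+,1} - {0,+,1} == k, loop-invariant, so the
                    // subtraction is replaced by k expanded in the preheader
}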
bool simplifyLoopIVs(Loop *L, ScalarEvolution *SE, DominatorTree *DT, LoopInfo *LI, SmallVectorImpl<WeakTrackingVH> &Dead) { + SCEVExpander Rewriter(*SE, SE->getDataLayout(), "indvars"); +#ifndef NDEBUG + Rewriter.setDebugType(DEBUG_TYPE); +#endif bool Changed = false; for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) { - Changed |= simplifyUsersOfIV(cast<PHINode>(I), SE, DT, LI, Dead); + Changed |= simplifyUsersOfIV(cast<PHINode>(I), SE, DT, LI, Dead, Rewriter); } return Changed; } diff --git a/lib/Transforms/Utils/SimplifyInstructions.cpp b/lib/Transforms/Utils/SimplifyInstructions.cpp index 2ea15f65cef9..f3d4f2ef38d7 100644 --- a/lib/Transforms/Utils/SimplifyInstructions.cpp +++ b/lib/Transforms/Utils/SimplifyInstructions.cpp @@ -20,7 +20,7 @@ #include "llvm/ADT/Statistic.h" #include "llvm/Analysis/AssumptionCache.h" #include "llvm/Analysis/InstructionSimplify.h" -#include "llvm/Analysis/OptimizationDiagnosticInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/IR/DataLayout.h" #include "llvm/IR/Dominators.h" diff --git a/lib/Transforms/Utils/SimplifyLibCalls.cpp b/lib/Transforms/Utils/SimplifyLibCalls.cpp index 77c0a41929ac..03a1d55ddc30 100644 --- a/lib/Transforms/Utils/SimplifyLibCalls.cpp +++ b/lib/Transforms/Utils/SimplifyLibCalls.cpp @@ -18,10 +18,11 @@ #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringMap.h" #include "llvm/ADT/Triple.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/IR/IntrinsicInst.h" @@ -484,10 +485,10 @@ Value *LibCallSimplifier::optimizeStringLength(CallInst *CI, IRBuilder<> &B, uint64_t LenTrue = GetStringLength(SI->getTrueValue(), CharSize); uint64_t LenFalse = GetStringLength(SI->getFalseValue(), CharSize); if (LenTrue && LenFalse) { - Function *Caller = CI->getParent()->getParent(); - emitOptimizationRemark(CI->getContext(), "simplify-libcalls", *Caller, - SI->getDebugLoc(), - "folded strlen(select) to select of constants"); + ORE.emit([&]() { + return OptimizationRemark("instcombine", "simplify-libcalls", CI) + << "folded strlen(select) to select of constants"; + }); return B.CreateSelect(SI->getCondition(), ConstantInt::get(CI->getType(), LenTrue - 1), ConstantInt::get(CI->getType(), LenFalse - 1)); @@ -509,6 +510,9 @@ Value *LibCallSimplifier::optimizeStrLen(CallInst *CI, IRBuilder<> &B) { Value *LibCallSimplifier::optimizeWcslen(CallInst *CI, IRBuilder<> &B) { Module &M = *CI->getParent()->getParent()->getParent(); unsigned WCharSize = TLI->getWCharSize(M) * 8; + // We cannot perform this optimization without wchar_size metadata. + if (WCharSize == 0) + return nullptr; return optimizeStringLength(CI, B, WCharSize); } @@ -753,29 +757,44 @@ Value *LibCallSimplifier::optimizeMemCmp(CallInst *CI, IRBuilder<> &B) { } // memcmp(S1,S2,N/8)==0 -> (*(intN_t*)S1 != *(intN_t*)S2)==0 + // TODO: The case where both inputs are constants does not need to be limited + // to legal integers or equality comparison. See block below this. 
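The block referred to by the TODO above turns an equality-only memcmp of a legal integer width into one load-and-compare, and the new code additionally folds either operand that is constant data, skipping the alignment check for that side since it is never loaded. Illustrative user code (on a little-endian target the key side becomes the immediate 0x64636261):

#include <string.h>

static const char key[4] = {'a', 'b', 'c', 'd'};

int isKey(const char *p) {
  return memcmp(p, key, 4) == 0;   // one 32-bit load of p, compared to a constant
}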
if (DL.isLegalInteger(Len * 8) && isOnlyUsedInZeroEqualityComparison(CI)) { - IntegerType *IntType = IntegerType::get(CI->getContext(), Len * 8); unsigned PrefAlignment = DL.getPrefTypeAlignment(IntType); - if (getKnownAlignment(LHS, DL, CI) >= PrefAlignment && - getKnownAlignment(RHS, DL, CI) >= PrefAlignment) { - - Type *LHSPtrTy = - IntType->getPointerTo(LHS->getType()->getPointerAddressSpace()); - Type *RHSPtrTy = - IntType->getPointerTo(RHS->getType()->getPointerAddressSpace()); - - Value *LHSV = - B.CreateLoad(B.CreateBitCast(LHS, LHSPtrTy, "lhsc"), "lhsv"); - Value *RHSV = - B.CreateLoad(B.CreateBitCast(RHS, RHSPtrTy, "rhsc"), "rhsv"); + // First, see if we can fold either argument to a constant. + Value *LHSV = nullptr; + if (auto *LHSC = dyn_cast<Constant>(LHS)) { + LHSC = ConstantExpr::getBitCast(LHSC, IntType->getPointerTo()); + LHSV = ConstantFoldLoadFromConstPtr(LHSC, IntType, DL); + } + Value *RHSV = nullptr; + if (auto *RHSC = dyn_cast<Constant>(RHS)) { + RHSC = ConstantExpr::getBitCast(RHSC, IntType->getPointerTo()); + RHSV = ConstantFoldLoadFromConstPtr(RHSC, IntType, DL); + } + // Don't generate unaligned loads. If either source is constant data, + // alignment doesn't matter for that source because there is no load. + if ((LHSV || getKnownAlignment(LHS, DL, CI) >= PrefAlignment) && + (RHSV || getKnownAlignment(RHS, DL, CI) >= PrefAlignment)) { + if (!LHSV) { + Type *LHSPtrTy = + IntType->getPointerTo(LHS->getType()->getPointerAddressSpace()); + LHSV = B.CreateLoad(B.CreateBitCast(LHS, LHSPtrTy), "lhsv"); + } + if (!RHSV) { + Type *RHSPtrTy = + IntType->getPointerTo(RHS->getType()->getPointerAddressSpace()); + RHSV = B.CreateLoad(B.CreateBitCast(RHS, RHSPtrTy), "rhsv"); + } return B.CreateZExt(B.CreateICmpNE(LHSV, RHSV), CI->getType(), "memcmp"); } } - // Constant folding: memcmp(x, y, l) -> cnst (all arguments are constant) + // Constant folding: memcmp(x, y, Len) -> constant (all arguments are const). + // TODO: This is limited to i8 arrays. StringRef LHSStr, RHSStr; if (getConstantStringInfo(LHS, LHSStr) && getConstantStringInfo(RHS, RHSStr)) { @@ -1014,6 +1033,35 @@ static Value *optimizeBinaryDoubleFP(CallInst *CI, IRBuilder<> &B) { return B.CreateFPExt(V, B.getDoubleTy()); } +// cabs(z) -> sqrt((creal(z)*creal(z)) + (cimag(z)*cimag(z))) +Value *LibCallSimplifier::optimizeCAbs(CallInst *CI, IRBuilder<> &B) { + if (!CI->isFast()) + return nullptr; + + // Propagate fast-math flags from the existing call to new instructions. 
+ IRBuilder<>::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + + Value *Real, *Imag; + if (CI->getNumArgOperands() == 1) { + Value *Op = CI->getArgOperand(0); + assert(Op->getType()->isArrayTy() && "Unexpected signature for cabs!"); + Real = B.CreateExtractValue(Op, 0, "real"); + Imag = B.CreateExtractValue(Op, 1, "imag"); + } else { + assert(CI->getNumArgOperands() == 2 && "Unexpected signature for cabs!"); + Real = CI->getArgOperand(0); + Imag = CI->getArgOperand(1); + } + + Value *RealReal = B.CreateFMul(Real, Real); + Value *ImagImag = B.CreateFMul(Imag, Imag); + + Function *FSqrt = Intrinsic::getDeclaration(CI->getModule(), Intrinsic::sqrt, + CI->getType()); + return B.CreateCall(FSqrt, B.CreateFAdd(RealReal, ImagImag), "cabs"); +} + Value *LibCallSimplifier::optimizeCos(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); Value *Ret = nullptr; @@ -1055,6 +1103,51 @@ static Value *getPow(Value *InnerChain[33], unsigned Exp, IRBuilder<> &B) { return InnerChain[Exp]; } +/// Use square root in place of pow(x, +/-0.5). +Value *LibCallSimplifier::replacePowWithSqrt(CallInst *Pow, IRBuilder<> &B) { + // TODO: There is some subset of 'fast' under which these transforms should + // be allowed. + if (!Pow->isFast()) + return nullptr; + + const APFloat *Arg1C; + if (!match(Pow->getArgOperand(1), m_APFloat(Arg1C))) + return nullptr; + if (!Arg1C->isExactlyValue(0.5) && !Arg1C->isExactlyValue(-0.5)) + return nullptr; + + // Fast-math flags from the pow() are propagated to all replacement ops. + IRBuilder<>::FastMathFlagGuard Guard(B); + B.setFastMathFlags(Pow->getFastMathFlags()); + Type *Ty = Pow->getType(); + Value *Sqrt; + if (Pow->hasFnAttr(Attribute::ReadNone)) { + // We know that errno is never set, so replace with an intrinsic: + // pow(x, 0.5) --> llvm.sqrt(x) + // llvm.pow(x, 0.5) --> llvm.sqrt(x) + auto *F = Intrinsic::getDeclaration(Pow->getModule(), Intrinsic::sqrt, Ty); + Sqrt = B.CreateCall(F, Pow->getArgOperand(0)); + } else if (hasUnaryFloatFn(TLI, Ty, LibFunc_sqrt, LibFunc_sqrtf, + LibFunc_sqrtl)) { + // Errno could be set, so we must use a sqrt libcall. + // TODO: We also should check that the target can in fact lower the sqrt + // libcall. We currently have no way to ask this question, so we ask + // whether the target has a sqrt libcall which is not exactly the same. + Sqrt = emitUnaryFloatFnCall(Pow->getArgOperand(0), + TLI->getName(LibFunc_sqrt), B, + Pow->getCalledFunction()->getAttributes()); + } else { + // We can't replace with an intrinsic or a libcall. + return nullptr; + } + + // If this is pow(x, -0.5), get the reciprocal. + if (Arg1C->isExactlyValue(-0.5)) + Sqrt = B.CreateFDiv(ConstantFP::get(Ty, 1.0), Sqrt); + + return Sqrt; +} + Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { Function *Callee = CI->getCalledFunction(); Value *Ret = nullptr; @@ -1092,7 +1185,7 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { // Example: x = 1000, y = 0.001. // pow(exp(x), y) = pow(inf, 0.001) = inf, whereas exp(x*y) = exp(1). 
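The new replacePowWithSqrt above handles the half-exponent cases before the constant-exponent logic that follows: an intrinsic sqrt when the call is known not to touch errno, otherwise the sqrt libcall. Illustrative user code, assuming the calls carry the 'fast' flag (e.g. built with -ffast-math):

#include <math.h>

double halfPow(double x)    { return pow(x,  0.5); }  // -> sqrt(x)
double negHalfPow(double x) { return pow(x, -0.5); }  // -> 1.0 / sqrt(x)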
auto *OpC = dyn_cast<CallInst>(Op1); - if (OpC && OpC->hasUnsafeAlgebra() && CI->hasUnsafeAlgebra()) { + if (OpC && OpC->isFast() && CI->isFast()) { LibFunc Func; Function *OpCCallee = OpC->getCalledFunction(); if (OpCCallee && TLI->getLibFunc(OpCCallee->getName(), Func) && @@ -1105,6 +1198,9 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { } } + if (Value *Sqrt = replacePowWithSqrt(CI, B)) + return Sqrt; + ConstantFP *Op2C = dyn_cast<ConstantFP>(Op2); if (!Op2C) return Ret; @@ -1112,42 +1208,10 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { if (Op2C->getValueAPF().isZero()) // pow(x, 0.0) -> 1.0 return ConstantFP::get(CI->getType(), 1.0); - if (Op2C->isExactlyValue(-0.5) && - hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf, - LibFunc_sqrtl)) { - // If -ffast-math: - // pow(x, -0.5) -> 1.0 / sqrt(x) - if (CI->hasUnsafeAlgebra()) { - IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - // TODO: If the pow call is an intrinsic, we should lower to the sqrt - // intrinsic, so we match errno semantics. We also should check that the - // target can in fact lower the sqrt intrinsic -- we currently have no way - // to ask this question other than asking whether the target has a sqrt - // libcall, which is a sufficient but not necessary condition. - Value *Sqrt = emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc_sqrt), B, - Callee->getAttributes()); - - return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Sqrt, "sqrtrecip"); - } - } - + // FIXME: Correct the transforms and pull this into replacePowWithSqrt(). if (Op2C->isExactlyValue(0.5) && hasUnaryFloatFn(TLI, Op2->getType(), LibFunc_sqrt, LibFunc_sqrtf, LibFunc_sqrtl)) { - - // In -ffast-math, pow(x, 0.5) -> sqrt(x). - if (CI->hasUnsafeAlgebra()) { - IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - - // TODO: As above, we should lower to the sqrt intrinsic if the pow is an - // intrinsic, to match errno semantics. - return emitUnaryFloatFnCall(Op1, TLI->getName(LibFunc_sqrt), B, - Callee->getAttributes()); - } - // Expand pow(x, 0.5) to (x == -infinity ? +infinity : fabs(sqrt(x))). // This is faster than calling pow, and still handles negative zero // and negative infinity correctly. @@ -1169,15 +1233,21 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { return Sel; } - if (Op2C->isExactlyValue(1.0)) // pow(x, 1.0) -> x + // Propagate fast-math-flags from the call to any created instructions. + IRBuilder<>::FastMathFlagGuard Guard(B); + B.setFastMathFlags(CI->getFastMathFlags()); + // pow(x, 1.0) --> x + if (Op2C->isExactlyValue(1.0)) return Op1; - if (Op2C->isExactlyValue(2.0)) // pow(x, 2.0) -> x*x + // pow(x, 2.0) --> x * x + if (Op2C->isExactlyValue(2.0)) return B.CreateFMul(Op1, Op1, "pow2"); - if (Op2C->isExactlyValue(-1.0)) // pow(x, -1.0) -> 1.0/x + // pow(x, -1.0) --> 1.0 / x + if (Op2C->isExactlyValue(-1.0)) return B.CreateFDiv(ConstantFP::get(CI->getType(), 1.0), Op1, "powrecip"); // In -ffast-math, generate repeated fmul instead of generating pow(x, n). - if (CI->hasUnsafeAlgebra()) { + if (CI->isFast()) { APFloat V = abs(Op2C->getValueAPF()); // We limit to a max of 7 fmul(s). Thus max exponent is 32. // This transformation applies to integer exponents only. @@ -1185,10 +1255,6 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { !V.isInteger()) return nullptr; - // Propagate fast math flags. 
- IRBuilder<>::FastMathFlagGuard Guard(B); - B.setFastMathFlags(CI->getFastMathFlags()); - // We will memoize intermediate products of the Addition Chain. Value *InnerChain[33] = {nullptr}; InnerChain[1] = Op1; @@ -1196,8 +1262,8 @@ Value *LibCallSimplifier::optimizePow(CallInst *CI, IRBuilder<> &B) { // We cannot readily convert a non-double type (like float) to a double. // So we first convert V to something which could be converted to double. - bool ignored; - V.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &ignored); + bool Ignored; + V.convert(APFloat::IEEEdouble(), APFloat::rmTowardZero, &Ignored); Value *FMul = getPow(InnerChain, V.convertToDouble(), B); // For negative exponents simply compute the reciprocal. @@ -1265,9 +1331,9 @@ Value *LibCallSimplifier::optimizeFMinFMax(CallInst *CI, IRBuilder<> &B) { IRBuilder<>::FastMathFlagGuard Guard(B); FastMathFlags FMF; - if (CI->hasUnsafeAlgebra()) { - // Unsafe algebra sets all fast-math-flags to true. - FMF.setUnsafeAlgebra(); + if (CI->isFast()) { + // If the call is 'fast', then anything we create here will also be 'fast'. + FMF.setFast(); } else { // At a minimum, no-nans-fp-math must be true. if (!CI->hasNoNaNs()) @@ -1298,13 +1364,13 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) { if (UnsafeFPShrink && hasFloatVersion(Name)) Ret = optimizeUnaryDoubleFP(CI, B, true); - if (!CI->hasUnsafeAlgebra()) + if (!CI->isFast()) return Ret; Value *Op1 = CI->getArgOperand(0); auto *OpC = dyn_cast<CallInst>(Op1); - // The earlier call must also be unsafe in order to do these transforms. - if (!OpC || !OpC->hasUnsafeAlgebra()) + // The earlier call must also be 'fast' in order to do these transforms. + if (!OpC || !OpC->isFast()) return Ret; // log(pow(x,y)) -> y*log(x) @@ -1314,7 +1380,7 @@ Value *LibCallSimplifier::optimizeLog(CallInst *CI, IRBuilder<> &B) { IRBuilder<>::FastMathFlagGuard Guard(B); FastMathFlags FMF; - FMF.setUnsafeAlgebra(); + FMF.setFast(); B.setFastMathFlags(FMF); LibFunc Func; @@ -1346,11 +1412,11 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) { Callee->getIntrinsicID() == Intrinsic::sqrt)) Ret = optimizeUnaryDoubleFP(CI, B, true); - if (!CI->hasUnsafeAlgebra()) + if (!CI->isFast()) return Ret; Instruction *I = dyn_cast<Instruction>(CI->getArgOperand(0)); - if (!I || I->getOpcode() != Instruction::FMul || !I->hasUnsafeAlgebra()) + if (!I || I->getOpcode() != Instruction::FMul || !I->isFast()) return Ret; // We're looking for a repeated factor in a multiplication tree, @@ -1372,8 +1438,7 @@ Value *LibCallSimplifier::optimizeSqrt(CallInst *CI, IRBuilder<> &B) { Value *OtherMul0, *OtherMul1; if (match(Op0, m_FMul(m_Value(OtherMul0), m_Value(OtherMul1)))) { // Pattern: sqrt((x * y) * z) - if (OtherMul0 == OtherMul1 && - cast<Instruction>(Op0)->hasUnsafeAlgebra()) { + if (OtherMul0 == OtherMul1 && cast<Instruction>(Op0)->isFast()) { // Matched: sqrt((x * x) * z) RepeatOp = OtherMul0; OtherOp = Op1; @@ -1418,8 +1483,8 @@ Value *LibCallSimplifier::optimizeTan(CallInst *CI, IRBuilder<> &B) { if (!OpC) return Ret; - // Both calls must allow unsafe optimizations in order to remove them. - if (!CI->hasUnsafeAlgebra() || !OpC->hasUnsafeAlgebra()) + // Both calls must be 'fast' in order to remove them. 
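optimizeLog above rewrites a composition only when both calls are 'fast', the same requirement the tan/atan code just below now spells with isFast(). Illustrative, under -ffast-math:

#include <math.h>

double logPow(double x, double y) {
  return log(pow(x, y));   // -> y * log(x)
}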
+ if (!CI->isFast() || !OpC->isFast()) return Ret; // tan(atan(x)) -> x @@ -2043,13 +2108,107 @@ Value *LibCallSimplifier::optimizeStringMemoryLibCall(CallInst *CI, return nullptr; } +Value *LibCallSimplifier::optimizeFloatingPointLibCall(CallInst *CI, + LibFunc Func, + IRBuilder<> &Builder) { + // Don't optimize calls that require strict floating point semantics. + if (CI->isStrictFP()) + return nullptr; + + switch (Func) { + case LibFunc_cosf: + case LibFunc_cos: + case LibFunc_cosl: + return optimizeCos(CI, Builder); + case LibFunc_sinpif: + case LibFunc_sinpi: + case LibFunc_cospif: + case LibFunc_cospi: + return optimizeSinCosPi(CI, Builder); + case LibFunc_powf: + case LibFunc_pow: + case LibFunc_powl: + return optimizePow(CI, Builder); + case LibFunc_exp2l: + case LibFunc_exp2: + case LibFunc_exp2f: + return optimizeExp2(CI, Builder); + case LibFunc_fabsf: + case LibFunc_fabs: + case LibFunc_fabsl: + return replaceUnaryCall(CI, Builder, Intrinsic::fabs); + case LibFunc_sqrtf: + case LibFunc_sqrt: + case LibFunc_sqrtl: + return optimizeSqrt(CI, Builder); + case LibFunc_log: + case LibFunc_log10: + case LibFunc_log1p: + case LibFunc_log2: + case LibFunc_logb: + return optimizeLog(CI, Builder); + case LibFunc_tan: + case LibFunc_tanf: + case LibFunc_tanl: + return optimizeTan(CI, Builder); + case LibFunc_ceil: + return replaceUnaryCall(CI, Builder, Intrinsic::ceil); + case LibFunc_floor: + return replaceUnaryCall(CI, Builder, Intrinsic::floor); + case LibFunc_round: + return replaceUnaryCall(CI, Builder, Intrinsic::round); + case LibFunc_nearbyint: + return replaceUnaryCall(CI, Builder, Intrinsic::nearbyint); + case LibFunc_rint: + return replaceUnaryCall(CI, Builder, Intrinsic::rint); + case LibFunc_trunc: + return replaceUnaryCall(CI, Builder, Intrinsic::trunc); + case LibFunc_acos: + case LibFunc_acosh: + case LibFunc_asin: + case LibFunc_asinh: + case LibFunc_atan: + case LibFunc_atanh: + case LibFunc_cbrt: + case LibFunc_cosh: + case LibFunc_exp: + case LibFunc_exp10: + case LibFunc_expm1: + case LibFunc_sin: + case LibFunc_sinh: + case LibFunc_tanh: + if (UnsafeFPShrink && hasFloatVersion(CI->getCalledFunction()->getName())) + return optimizeUnaryDoubleFP(CI, Builder, true); + return nullptr; + case LibFunc_copysign: + if (hasFloatVersion(CI->getCalledFunction()->getName())) + return optimizeBinaryDoubleFP(CI, Builder); + return nullptr; + case LibFunc_fminf: + case LibFunc_fmin: + case LibFunc_fminl: + case LibFunc_fmaxf: + case LibFunc_fmax: + case LibFunc_fmaxl: + return optimizeFMinFMax(CI, Builder); + case LibFunc_cabs: + case LibFunc_cabsf: + case LibFunc_cabsl: + return optimizeCAbs(CI, Builder); + default: + return nullptr; + } +} + Value *LibCallSimplifier::optimizeCall(CallInst *CI) { + // TODO: Split out the code below that operates on FP calls so that + // we can all non-FP calls with the StrictFP attribute to be + // optimized. if (CI->isNoBuiltin()) return nullptr; LibFunc Func; Function *Callee = CI->getCalledFunction(); - StringRef FuncName = Callee->getName(); SmallVector<OperandBundleDef, 2> OpBundles; CI->getOperandBundlesAsDefs(OpBundles); @@ -2057,15 +2216,19 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { bool isCallingConvC = isCallingConvCCompatible(CI); // Command-line parameter overrides instruction attribute. + // This can't be moved to optimizeFloatingPointLibCall() because it may be + // used by the intrinsic optimizations. 
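Several entries in the new optimizeFloatingPointLibCall dispatch above are straight libcall-to-intrinsic replacements that need no fast-math at all; illustrative user code:

#include <math.h>

double up(double x)   { return ceil(x);  }  // -> llvm.ceil.f64
double down(double x) { return floor(x); }  // -> llvm.floor.f64
double chop(double x) { return trunc(x); }  // -> llvm.trunc.f64

The UnsafeFPShrink handling that follows stays in optimizeCall because, as the comment notes, the intrinsic paths consult it too.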
if (EnableUnsafeFPShrink.getNumOccurrences() > 0) UnsafeFPShrink = EnableUnsafeFPShrink; - else if (isa<FPMathOperator>(CI) && CI->hasUnsafeAlgebra()) + else if (isa<FPMathOperator>(CI) && CI->isFast()) UnsafeFPShrink = true; // First, check for intrinsics. if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI)) { if (!isCallingConvC) return nullptr; + // The FP intrinsics have corresponding constrained versions so we don't + // need to check for the StrictFP attribute here. switch (II->getIntrinsicID()) { case Intrinsic::pow: return optimizePow(CI, Builder); @@ -2106,32 +2269,9 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { return nullptr; if (Value *V = optimizeStringMemoryLibCall(CI, Builder)) return V; + if (Value *V = optimizeFloatingPointLibCall(CI, Func, Builder)) + return V; switch (Func) { - case LibFunc_cosf: - case LibFunc_cos: - case LibFunc_cosl: - return optimizeCos(CI, Builder); - case LibFunc_sinpif: - case LibFunc_sinpi: - case LibFunc_cospif: - case LibFunc_cospi: - return optimizeSinCosPi(CI, Builder); - case LibFunc_powf: - case LibFunc_pow: - case LibFunc_powl: - return optimizePow(CI, Builder); - case LibFunc_exp2l: - case LibFunc_exp2: - case LibFunc_exp2f: - return optimizeExp2(CI, Builder); - case LibFunc_fabsf: - case LibFunc_fabs: - case LibFunc_fabsl: - return replaceUnaryCall(CI, Builder, Intrinsic::fabs); - case LibFunc_sqrtf: - case LibFunc_sqrt: - case LibFunc_sqrtl: - return optimizeSqrt(CI, Builder); case LibFunc_ffs: case LibFunc_ffsl: case LibFunc_ffsll: @@ -2160,18 +2300,8 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { return optimizeFWrite(CI, Builder); case LibFunc_fputs: return optimizeFPuts(CI, Builder); - case LibFunc_log: - case LibFunc_log10: - case LibFunc_log1p: - case LibFunc_log2: - case LibFunc_logb: - return optimizeLog(CI, Builder); case LibFunc_puts: return optimizePuts(CI, Builder); - case LibFunc_tan: - case LibFunc_tanf: - case LibFunc_tanl: - return optimizeTan(CI, Builder); case LibFunc_perror: return optimizeErrorReporting(CI, Builder); case LibFunc_vfprintf: @@ -2179,46 +2309,6 @@ Value *LibCallSimplifier::optimizeCall(CallInst *CI) { return optimizeErrorReporting(CI, Builder, 0); case LibFunc_fputc: return optimizeErrorReporting(CI, Builder, 1); - case LibFunc_ceil: - return replaceUnaryCall(CI, Builder, Intrinsic::ceil); - case LibFunc_floor: - return replaceUnaryCall(CI, Builder, Intrinsic::floor); - case LibFunc_round: - return replaceUnaryCall(CI, Builder, Intrinsic::round); - case LibFunc_nearbyint: - return replaceUnaryCall(CI, Builder, Intrinsic::nearbyint); - case LibFunc_rint: - return replaceUnaryCall(CI, Builder, Intrinsic::rint); - case LibFunc_trunc: - return replaceUnaryCall(CI, Builder, Intrinsic::trunc); - case LibFunc_acos: - case LibFunc_acosh: - case LibFunc_asin: - case LibFunc_asinh: - case LibFunc_atan: - case LibFunc_atanh: - case LibFunc_cbrt: - case LibFunc_cosh: - case LibFunc_exp: - case LibFunc_exp10: - case LibFunc_expm1: - case LibFunc_sin: - case LibFunc_sinh: - case LibFunc_tanh: - if (UnsafeFPShrink && hasFloatVersion(FuncName)) - return optimizeUnaryDoubleFP(CI, Builder, true); - return nullptr; - case LibFunc_copysign: - if (hasFloatVersion(FuncName)) - return optimizeBinaryDoubleFP(CI, Builder); - return nullptr; - case LibFunc_fminf: - case LibFunc_fmin: - case LibFunc_fminl: - case LibFunc_fmaxf: - case LibFunc_fmax: - case LibFunc_fmaxl: - return optimizeFMinFMax(CI, Builder); default: return nullptr; } @@ -2228,9 +2318,10 @@ Value *LibCallSimplifier::optimizeCall(CallInst 
*CI) { LibCallSimplifier::LibCallSimplifier( const DataLayout &DL, const TargetLibraryInfo *TLI, + OptimizationRemarkEmitter &ORE, function_ref<void(Instruction *, Value *)> Replacer) - : FortifiedSimplifier(TLI), DL(DL), TLI(TLI), UnsafeFPShrink(false), - Replacer(Replacer) {} + : FortifiedSimplifier(TLI), DL(DL), TLI(TLI), ORE(ORE), + UnsafeFPShrink(false), Replacer(Replacer) {} void LibCallSimplifier::replaceAllUsesWith(Instruction *I, Value *With) { // Indirect through the replacer used in this instance. diff --git a/lib/Transforms/Utils/SplitModule.cpp b/lib/Transforms/Utils/SplitModule.cpp index e9a368f4faa4..968eb0208f43 100644 --- a/lib/Transforms/Utils/SplitModule.cpp +++ b/lib/Transforms/Utils/SplitModule.cpp @@ -13,32 +13,51 @@ // //===----------------------------------------------------------------------===// -#define DEBUG_TYPE "split-module" - #include "llvm/Transforms/Utils/SplitModule.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/EquivalenceClasses.h" -#include "llvm/ADT/Hashing.h" -#include "llvm/ADT/MapVector.h" -#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/IR/Comdat.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalAlias.h" #include "llvm/IR/GlobalObject.h" +#include "llvm/IR/GlobalIndirectSymbol.h" #include "llvm/IR/GlobalValue.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Module.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/MD5.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include <algorithm> +#include <cassert> +#include <iterator> +#include <memory> #include <queue> +#include <utility> +#include <vector> using namespace llvm; +#define DEBUG_TYPE "split-module" + namespace { -typedef EquivalenceClasses<const GlobalValue *> ClusterMapType; -typedef DenseMap<const Comdat *, const GlobalValue *> ComdatMembersType; -typedef DenseMap<const GlobalValue *, unsigned> ClusterIDMapType; -} + +using ClusterMapType = EquivalenceClasses<const GlobalValue *>; +using ComdatMembersType = DenseMap<const Comdat *, const GlobalValue *>; +using ClusterIDMapType = DenseMap<const GlobalValue *, unsigned>; + +} // end anonymous namespace static void addNonConstUser(ClusterMapType &GVtoClusterMap, const GlobalValue *GV, const User *U) { @@ -125,9 +144,9 @@ static void findPartitions(Module *M, ClusterIDMapType &ClusterIDMap, addAllGlobalValueUsers(GVtoClusterMap, &GV, &GV); }; - std::for_each(M->begin(), M->end(), recordGVSet); - std::for_each(M->global_begin(), M->global_end(), recordGVSet); - std::for_each(M->alias_begin(), M->alias_end(), recordGVSet); + llvm::for_each(M->functions(), recordGVSet); + llvm::for_each(M->globals(), recordGVSet); + llvm::for_each(M->aliases(), recordGVSet); // Assigned all GVs to merged clusters while balancing number of objects in // each. 
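A rough, self-contained sketch of the balancing scheme findPartitions uses (not the LLVM code itself, and ignoring the size-descending sort and comdat handling it also does): each cluster of globals goes to whichever of the N output modules currently holds the fewest members, via a min-queue playing the role of BalancinQueue in the hunk that follows.

#include <cstddef>
#include <functional>
#include <queue>
#include <utility>
#include <vector>

std::vector<unsigned> assignClusters(const std::vector<unsigned> &ClusterSizes,
                                     unsigned N) {
  using Entry = std::pair<unsigned, unsigned>; // (members so far, module id)
  std::priority_queue<Entry, std::vector<Entry>, std::greater<Entry>> Queue;
  for (unsigned I = 0; I != N; ++I)
    Queue.push({0, I});

  std::vector<unsigned> ModuleOfCluster(ClusterSizes.size());
  for (std::size_t C = 0; C != ClusterSizes.size(); ++C) {
    Entry Least = Queue.top();        // least-loaded module so far
    Queue.pop();
    ModuleOfCluster[C] = Least.second;
    Queue.push({Least.first + ClusterSizes[C], Least.second});
  }
  return ModuleOfCluster;
}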
@@ -147,7 +166,8 @@ static void findPartitions(Module *M, ClusterIDMapType &ClusterIDMap, for (unsigned i = 0; i < N; ++i) BalancinQueue.push(std::make_pair(i, 0)); - typedef std::pair<unsigned, ClusterMapType::iterator> SortType; + using SortType = std::pair<unsigned, ClusterMapType::iterator>; + SmallVector<SortType, 64> Sets; SmallPtrSet<const GlobalValue *, 32> Visited; diff --git a/lib/Transforms/Utils/SymbolRewriter.cpp b/lib/Transforms/Utils/SymbolRewriter.cpp index 20107553665f..3640541e63cc 100644 --- a/lib/Transforms/Utils/SymbolRewriter.cpp +++ b/lib/Transforms/Utils/SymbolRewriter.cpp @@ -1,4 +1,4 @@ -//===- SymbolRewriter.cpp - Symbol Rewriter ---------------------*- C++ -*-===// +//===- SymbolRewriter.cpp - Symbol Rewriter -------------------------------===// // // The LLVM Compiler Infrastructure // @@ -57,25 +57,41 @@ // //===----------------------------------------------------------------------===// -#define DEBUG_TYPE "symbol-rewriter" #include "llvm/Transforms/Utils/SymbolRewriter.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallString.h" -#include "llvm/IR/LegacyPassManager.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/ilist.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/IR/Comdat.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalObject.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Value.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" -#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/ErrorOr.h" #include "llvm/Support/MemoryBuffer.h" #include "llvm/Support/Regex.h" #include "llvm/Support/SourceMgr.h" #include "llvm/Support/YAMLParser.h" -#include "llvm/Support/raw_ostream.h" +#include <memory> +#include <string> +#include <vector> using namespace llvm; using namespace SymbolRewriter; +#define DEBUG_TYPE "symbol-rewriter" + static cl::list<std::string> RewriteMapFiles("rewrite-map-file", cl::desc("Symbol Rewrite Map"), - cl::value_desc("filename")); + cl::value_desc("filename"), + cl::Hidden); static void rewriteComdat(Module &M, GlobalObject *GO, const std::string &Source, @@ -92,8 +108,9 @@ static void rewriteComdat(Module &M, GlobalObject *GO, } namespace { + template <RewriteDescriptor::Type DT, typename ValueType, - ValueType *(llvm::Module::*Get)(StringRef) const> + ValueType *(Module::*Get)(StringRef) const> class ExplicitRewriteDescriptor : public RewriteDescriptor { public: const std::string Source; @@ -110,8 +127,10 @@ public: } }; +} // end anonymous namespace + template <RewriteDescriptor::Type DT, typename ValueType, - ValueType *(llvm::Module::*Get)(StringRef) const> + ValueType *(Module::*Get)(StringRef) const> bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) { bool Changed = false; if (ValueType *S = (M.*Get)(Source)) { @@ -128,10 +147,12 @@ bool ExplicitRewriteDescriptor<DT, ValueType, Get>::performOnModule(Module &M) { return Changed; } +namespace { + template <RewriteDescriptor::Type DT, typename ValueType, - ValueType *(llvm::Module::*Get)(StringRef) const, + ValueType *(Module::*Get)(StringRef) const, iterator_range<typename iplist<ValueType>::iterator> - (llvm::Module::*Iterator)()> + (Module::*Iterator)()> class PatternRewriteDescriptor : public RewriteDescriptor { public: const std::string Pattern; @@ -147,10 +168,12 @@ public: } }; +} // end anonymous namespace + template <RewriteDescriptor::Type DT, 
typename ValueType, - ValueType *(llvm::Module::*Get)(StringRef) const, + ValueType *(Module::*Get)(StringRef) const, iterator_range<typename iplist<ValueType>::iterator> - (llvm::Module::*Iterator)()> + (Module::*Iterator)()> bool PatternRewriteDescriptor<DT, ValueType, Get, Iterator>:: performOnModule(Module &M) { bool Changed = false; @@ -178,55 +201,52 @@ performOnModule(Module &M) { return Changed; } +namespace { + /// Represents a rewrite for an explicitly named (function) symbol. Both the /// source function name and target function name of the transformation are /// explicitly spelt out. -typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::Function, - llvm::Function, &llvm::Module::getFunction> - ExplicitRewriteFunctionDescriptor; +using ExplicitRewriteFunctionDescriptor = + ExplicitRewriteDescriptor<RewriteDescriptor::Type::Function, Function, + &Module::getFunction>; /// Represents a rewrite for an explicitly named (global variable) symbol. Both /// the source variable name and target variable name are spelt out. This /// applies only to module level variables. -typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable, - llvm::GlobalVariable, - &llvm::Module::getGlobalVariable> - ExplicitRewriteGlobalVariableDescriptor; +using ExplicitRewriteGlobalVariableDescriptor = + ExplicitRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable, + GlobalVariable, &Module::getGlobalVariable>; /// Represents a rewrite for an explicitly named global alias. Both the source /// and target name are explicitly spelt out. -typedef ExplicitRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, - llvm::GlobalAlias, - &llvm::Module::getNamedAlias> - ExplicitRewriteNamedAliasDescriptor; +using ExplicitRewriteNamedAliasDescriptor = + ExplicitRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, GlobalAlias, + &Module::getNamedAlias>; /// Represents a rewrite for a regular expression based pattern for functions. /// A pattern for the function name is provided and a transformation for that /// pattern to determine the target function name create the rewrite rule. -typedef PatternRewriteDescriptor<RewriteDescriptor::Type::Function, - llvm::Function, &llvm::Module::getFunction, - &llvm::Module::functions> - PatternRewriteFunctionDescriptor; +using PatternRewriteFunctionDescriptor = + PatternRewriteDescriptor<RewriteDescriptor::Type::Function, Function, + &Module::getFunction, &Module::functions>; /// Represents a rewrite for a global variable based upon a matching pattern. /// Each global variable matching the provided pattern will be transformed as /// described in the transformation pattern for the target. Applies only to /// module level variables. -typedef PatternRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable, - llvm::GlobalVariable, - &llvm::Module::getGlobalVariable, - &llvm::Module::globals> - PatternRewriteGlobalVariableDescriptor; +using PatternRewriteGlobalVariableDescriptor = + PatternRewriteDescriptor<RewriteDescriptor::Type::GlobalVariable, + GlobalVariable, &Module::getGlobalVariable, + &Module::globals>; /// PatternRewriteNamedAliasDescriptor - represents a rewrite for global /// aliases which match a given pattern. The provided transformation will be /// applied to each of the matching names. 
-typedef PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, - llvm::GlobalAlias, - &llvm::Module::getNamedAlias, - &llvm::Module::aliases> - PatternRewriteNamedAliasDescriptor; -} // namespace +using PatternRewriteNamedAliasDescriptor = + PatternRewriteDescriptor<RewriteDescriptor::Type::NamedAlias, GlobalAlias, + &Module::getNamedAlias, &Module::aliases>; + +} // end anonymous namespace bool RewriteMapParser::parse(const std::string &MapFile, RewriteDescriptorList *DL) { @@ -497,6 +517,7 @@ parseRewriteGlobalAliasDescriptor(yaml::Stream &YS, yaml::ScalarNode *K, } namespace { + class RewriteSymbolsLegacyPass : public ModulePass { public: static char ID; // Pass identification, replacement for typeid @@ -510,9 +531,11 @@ private: RewriteSymbolPass Impl; }; +} // end anonymous namespace + char RewriteSymbolsLegacyPass::ID = 0; -RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass() : ModulePass(ID), Impl() { +RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass() : ModulePass(ID) { initializeRewriteSymbolsLegacyPassPass(*PassRegistry::getPassRegistry()); } @@ -523,9 +546,7 @@ RewriteSymbolsLegacyPass::RewriteSymbolsLegacyPass( bool RewriteSymbolsLegacyPass::runOnModule(Module &M) { return Impl.runImpl(M); } -} -namespace llvm { PreservedAnalyses RewriteSymbolPass::run(Module &M, ModuleAnalysisManager &AM) { if (!runImpl(M)) return PreservedAnalyses::all(); @@ -550,7 +571,6 @@ void RewriteSymbolPass::loadAndParseMapFiles() { for (const auto &MapFile : MapFiles) Parser.parse(MapFile, &Descriptors); } -} INITIALIZE_PASS(RewriteSymbolsLegacyPass, "rewrite-symbols", "Rewrite Symbols", false, false) diff --git a/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp b/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp index 9385f825523c..ed444e4cf43c 100644 --- a/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp +++ b/lib/Transforms/Utils/UnifyFunctionExitNodes.cpp @@ -15,7 +15,6 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h" -#include "llvm/ADT/StringExtras.h" #include "llvm/IR/BasicBlock.h" #include "llvm/IR/Function.h" #include "llvm/IR/Instructions.h" diff --git a/lib/Transforms/Utils/ValueMapper.cpp b/lib/Transforms/Utils/ValueMapper.cpp index 930972924c3c..8c9ecbc3503e 100644 --- a/lib/Transforms/Utils/ValueMapper.cpp +++ b/lib/Transforms/Utils/ValueMapper.cpp @@ -13,17 +13,36 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Utils/ValueMapper.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" #include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" -#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Function.h" #include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/GlobalObject.h" #include "llvm/IR/GlobalVariable.h" #include "llvm/IR/InlineAsm.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Metadata.h" #include "llvm/IR/Operator.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include <cassert> +#include <limits> +#include <memory> +#include <utility> + using namespace llvm; // Out of line method to get vtable etc for class. 
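The SymbolRewriter rewrite descriptors a few hunks above (now spelled as using aliases) all rely on the same device: the module lookup routine is a member-function pointer supplied as a non-type template parameter, so ExplicitRewriteDescriptor<..., &Module::getFunction> and ExplicitRewriteDescriptor<..., &Module::getGlobalVariable> share one implementation of performOnModule. A standalone C++ sketch of that shape, using purely illustrative names rather than the real LLVM types:

#include <iostream>
#include <string>

struct Registry {
  int getFunction(const std::string &Name) const { return Name == "foo" ? 1 : 0; }
  int getGlobal(const std::string &Name) const { return Name == "bar" ? 2 : 0; }
};

// The lookup strategy is part of the type, mirroring how
// &Module::getFunction vs. &Module::getGlobalVariable select the descriptor.
template <int (Registry::*Get)(const std::string &) const>
struct Rewriter {
  int lookup(const Registry &R, const std::string &Name) const {
    return (R.*Get)(Name); // call through the member-function pointer
  }
};

int main() {
  Registry R;
  Rewriter<&Registry::getFunction> ByFunction;
  Rewriter<&Registry::getGlobal> ByGlobal;
  std::cout << ByFunction.lookup(R, "foo") << ' '
            << ByGlobal.lookup(R, "bar") << '\n'; // prints: 1 2
}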
@@ -85,7 +104,6 @@ struct MappingContext { : VM(&VM), Materializer(Materializer) {} }; -class MDNodeMapper; class Mapper { friend class MDNodeMapper; @@ -175,7 +193,7 @@ class MDNodeMapper { /// Data about a node in \a UniquedGraph. struct Data { bool HasChanged = false; - unsigned ID = ~0u; + unsigned ID = std::numeric_limits<unsigned>::max(); TempMDNode Placeholder; }; @@ -316,7 +334,7 @@ private: void remapOperands(MDNode &N, OperandMapper mapOperand); }; -} // end namespace +} // end anonymous namespace Value *Mapper::mapValue(const Value *V) { ValueToValueMapTy::iterator I = getVM().find(V); @@ -579,6 +597,7 @@ void MDNodeMapper::remapOperands(MDNode &N, OperandMapper mapOperand) { } namespace { + /// An entry in the worklist for the post-order traversal. struct POTWorklistEntry { MDNode *N; ///< Current node. @@ -590,7 +609,8 @@ struct POTWorklistEntry { POTWorklistEntry(MDNode &N) : N(&N), Op(N.op_begin()) {} }; -} // end namespace + +} // end anonymous namespace bool MDNodeMapper::createPOT(UniquedGraph &G, const MDNode &FirstN) { assert(G.Info.empty() && "Expected a fresh traversal"); @@ -653,7 +673,7 @@ void MDNodeMapper::UniquedGraph::propagateChanges() { if (D.HasChanged) continue; - if (none_of(N->operands(), [&](const Metadata *Op) { + if (llvm::none_of(N->operands(), [&](const Metadata *Op) { auto Where = Info.find(Op); return Where != Info.end() && Where->second.HasChanged; })) @@ -752,10 +772,11 @@ struct MapMetadataDisabler { MapMetadataDisabler(ValueToValueMapTy &VM) : VM(VM) { VM.disableMapMetadata(); } + ~MapMetadataDisabler() { VM.enableMapMetadata(); } }; -} // end namespace +} // end anonymous namespace Optional<Metadata *> Mapper::mapSimpleMetadata(const Metadata *MD) { // If the value already exists in the map, use it. @@ -1037,11 +1058,13 @@ public: explicit FlushingMapper(void *pImpl) : M(*getAsMapper(pImpl)) { assert(!M.hasWorkToDo() && "Expected to be flushed"); } + ~FlushingMapper() { M.flush(); } + Mapper *operator->() const { return &M; } }; -} // end namespace +} // end anonymous namespace ValueMapper::ValueMapper(ValueToValueMapTy &VM, RemapFlags Flags, ValueMapTypeRemapper *TypeMapper, diff --git a/lib/Transforms/Vectorize/CMakeLists.txt b/lib/Transforms/Vectorize/CMakeLists.txt index 1aea73cd4a32..7622ed6d194f 100644 --- a/lib/Transforms/Vectorize/CMakeLists.txt +++ b/lib/Transforms/Vectorize/CMakeLists.txt @@ -3,6 +3,7 @@ add_llvm_library(LLVMVectorize LoopVectorize.cpp SLPVectorizer.cpp Vectorize.cpp + VPlan.cpp ADDITIONAL_HEADER_DIRS ${LLVM_MAIN_INCLUDE_DIR}/llvm/Transforms diff --git a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp index 9cf66382b581..dc83b6d4d292 100644 --- a/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp +++ b/lib/Transforms/Vectorize/LoadStoreVectorizer.cpp @@ -1,4 +1,4 @@ -//===----- LoadStoreVectorizer.cpp - GPU Load & Store Vectorizer ----------===// +//===- LoadStoreVectorizer.cpp - GPU Load & Store Vectorizer --------------===// // // The LLVM Compiler Infrastructure // @@ -6,47 +6,67 @@ // License. See LICENSE.TXT for details. 
// //===----------------------------------------------------------------------===// -// -//===----------------------------------------------------------------------===// +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/PostOrderIterator.h" -#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/Triple.h" +#include "llvm/ADT/iterator_range.h" #include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h" #include "llvm/Analysis/OrderedBasicBlock.h" #include "llvm/Analysis/ScalarEvolution.h" -#include "llvm/Analysis/ScalarEvolutionExpressions.h" #include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" +#include "llvm/IR/User.h" #include "llvm/IR/Value.h" -#include "llvm/Support/CommandLine.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Vectorize.h" +#include <algorithm> +#include <cassert> +#include <cstdlib> +#include <tuple> +#include <utility> using namespace llvm; #define DEBUG_TYPE "load-store-vectorizer" + STATISTIC(NumVectorInstructions, "Number of vector accesses generated"); STATISTIC(NumScalarsVectorized, "Number of scalar accesses vectorized"); -namespace { - // FIXME: Assuming stack alignment of 4 is always good enough static const unsigned StackAdjustedAlignment = 4; -typedef SmallVector<Instruction *, 8> InstrList; -typedef MapVector<Value *, InstrList> InstrListMap; + +namespace { + +using InstrList = SmallVector<Instruction *, 8>; +using InstrListMap = MapVector<Value *, InstrList>; class Vectorizer { Function &F; @@ -163,7 +183,10 @@ public: AU.setPreservesCFG(); } }; -} + +} // end anonymous namespace + +char LoadStoreVectorizer::ID = 0; INITIALIZE_PASS_BEGIN(LoadStoreVectorizer, DEBUG_TYPE, "Vectorize load and Store instructions", false, false) @@ -175,8 +198,6 @@ INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_END(LoadStoreVectorizer, DEBUG_TYPE, "Vectorize load and store instructions", false, false) -char LoadStoreVectorizer::ID = 0; - Pass *llvm::createLoadStoreVectorizerPass() { return new LoadStoreVectorizer(); } @@ -480,6 +501,10 @@ Vectorizer::getVectorizablePrefix(ArrayRef<Instruction *> Chain) { MemoryInstrs.push_back(&I); else ChainInstrs.push_back(&I); + } else if (isa<IntrinsicInst>(&I) && + cast<IntrinsicInst>(&I)->getIntrinsicID() == + Intrinsic::sideeffect) { + // Ignore llvm.sideeffect calls. } else if (IsLoadChain && (I.mayWriteToMemory() || I.mayThrow())) { DEBUG(dbgs() << "LSV: Found may-write/throw operation: " << I << '\n'); break; @@ -593,7 +618,14 @@ Vectorizer::collectInstructions(BasicBlock *BB) { // Skip weird non-byte sizes. 
They probably aren't worth the effort of // handling correctly. unsigned TySize = DL.getTypeSizeInBits(Ty); - if (TySize < 8) + if ((TySize % 8) != 0) + continue; + + // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain + // functions are currently using an integer type for the vectorized + // load/store, and does not support casting between the integer type and a + // vector of pointers (e.g. i64 to <2 x i16*>) + if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy()) continue; Value *Ptr = LI->getPointerOperand(); @@ -605,7 +637,7 @@ Vectorizer::collectInstructions(BasicBlock *BB) { continue; // Make sure all the users of a vector are constant-index extracts. - if (isa<VectorType>(Ty) && !all_of(LI->users(), [](const User *U) { + if (isa<VectorType>(Ty) && !llvm::all_of(LI->users(), [](const User *U) { const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U); return EEI && isa<ConstantInt>(EEI->getOperand(1)); })) @@ -614,7 +646,6 @@ Vectorizer::collectInstructions(BasicBlock *BB) { // Save the load locations. Value *ObjPtr = GetUnderlyingObject(Ptr, DL); LoadRefs[ObjPtr].push_back(LI); - } else if (StoreInst *SI = dyn_cast<StoreInst>(&I)) { if (!SI->isSimple()) continue; @@ -627,19 +658,28 @@ Vectorizer::collectInstructions(BasicBlock *BB) { if (!VectorType::isValidElementType(Ty->getScalarType())) continue; + // Skip vectors of pointers. The vectorizeLoadChain/vectorizeStoreChain + // functions are currently using an integer type for the vectorized + // load/store, and does not support casting between the integer type and a + // vector of pointers (e.g. i64 to <2 x i16*>) + if (Ty->isVectorTy() && Ty->isPtrOrPtrVectorTy()) + continue; + // Skip weird non-byte sizes. They probably aren't worth the effort of // handling correctly. unsigned TySize = DL.getTypeSizeInBits(Ty); - if (TySize < 8) + if ((TySize % 8) != 0) continue; Value *Ptr = SI->getPointerOperand(); unsigned AS = Ptr->getType()->getPointerAddressSpace(); unsigned VecRegSize = TTI.getLoadStoreVecRegBitWidth(AS); + + // No point in looking at these if they're too big to vectorize. if (TySize > VecRegSize / 2) continue; - if (isa<VectorType>(Ty) && !all_of(SI->users(), [](const User *U) { + if (isa<VectorType>(Ty) && !llvm::all_of(SI->users(), [](const User *U) { const ExtractElementInst *EEI = dyn_cast<ExtractElementInst>(U); return EEI && isa<ConstantInt>(EEI->getOperand(1)); })) @@ -680,8 +720,8 @@ bool Vectorizer::vectorizeInstructions(ArrayRef<Instruction *> Instrs) { SmallVector<int, 16> Heads, Tails; int ConsecutiveChain[64]; - // Do a quadratic search on all of the given stores and find all of the pairs - // of stores that follow each other. + // Do a quadratic search on all of the given loads/stores and find all of the + // pairs of loads/stores that follow each other. for (int i = 0, e = Instrs.size(); i < e; ++i) { ConsecutiveChain[i] = -1; for (int j = e - 1; j >= 0; --j) { @@ -748,7 +788,7 @@ bool Vectorizer::vectorizeStoreChain( SmallPtrSet<Instruction *, 16> *InstructionsProcessed) { StoreInst *S0 = cast<StoreInst>(Chain[0]); - // If the vector has an int element, default to int for the whole load. + // If the vector has an int element, default to int for the whole store. 
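A note on the collectInstructions guards changed above: the old test TySize < 8 only rejected sub-byte types, so odd widths such as i9 or i33 could still reach chain building, while the new (TySize % 8) != 0 test keeps only types whose bit size is a whole number of bytes; the added isVectorTy() && isPtrOrPtrVectorTy() check skips vectors of pointers because, as the comment says, vectorizeLoadChain/vectorizeStoreChain currently go through an integer type and cannot cast it to a pointer vector. A tiny standalone sketch of the size predicate (the i9/i33 widths are illustrative, not taken from the patch):

#include <cstdint>
#include <iostream>

// Mirrors the (TySize % 8) != 0 guard: accept only byte-multiple bit widths.
static bool hasByteMultipleSize(uint64_t TySizeInBits) {
  return TySizeInBits != 0 && TySizeInBits % 8 == 0;
}

int main() {
  std::cout << hasByteMultipleSize(32) << '\n'; // 1: i32 passes
  std::cout << hasByteMultipleSize(9)  << '\n'; // 0: i9 is skipped
  std::cout << hasByteMultipleSize(33) << '\n'; // 0: i33 slipped past TySize < 8
}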
Type *StoreTy; for (Instruction *I : Chain) { StoreTy = cast<StoreInst>(I)->getValueOperand()->getType(); diff --git a/lib/Transforms/Vectorize/LoopVectorize.cpp b/lib/Transforms/Vectorize/LoopVectorize.cpp index 012b10c8a9b0..fbcdc0df0f1c 100644 --- a/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -47,62 +47,97 @@ //===----------------------------------------------------------------------===// #include "llvm/Transforms/Vectorize/LoopVectorize.h" +#include "VPlan.h" +#include "VPlanBuilder.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" #include "llvm/ADT/Hashing.h" #include "llvm/ADT/MapVector.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/SCCIterator.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" -#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/LoopAnalysisManager.h" #include "llvm/Analysis/LoopInfo.h" #include "llvm/Analysis/LoopIterator.h" -#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpander.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" -#include "llvm/Analysis/ValueTracking.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" #include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" -#include "llvm/IR/DebugInfo.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DebugLoc.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/DiagnosticInfo.h" #include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" #include "llvm/IR/Module.h" -#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Operator.h" #include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" #include "llvm/IR/User.h" #include "llvm/IR/Value.h" #include "llvm/IR/ValueHandle.h" #include "llvm/IR/Verifier.h" #include "llvm/Pass.h" -#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" -#include "llvm/Transforms/Scalar.h" #include "llvm/Transforms/Utils/BasicBlockUtils.h" -#include "llvm/Transforms/Utils/Local.h" #include "llvm/Transforms/Utils/LoopSimplify.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Utils/LoopVersioning.h" 
-#include "llvm/Transforms/Vectorize.h" #include <algorithm> -#include <map> +#include <cassert> +#include <cstdint> +#include <cstdlib> +#include <functional> +#include <iterator> +#include <limits> +#include <memory> +#include <string> #include <tuple> +#include <utility> +#include <vector> using namespace llvm; -using namespace llvm::PatternMatch; #define LV_NAME "loop-vectorize" #define DEBUG_TYPE LV_NAME @@ -245,12 +280,12 @@ createMissedAnalysis(const char *PassName, StringRef RemarkName, Loop *TheLoop, namespace { -// Forward declarations. -class LoopVectorizeHints; class LoopVectorizationLegality; class LoopVectorizationCostModel; class LoopVectorizationRequirements; +} // end anonymous namespace + /// Returns true if the given loop body has a cycle, excluding the loop /// itself. static bool hasCyclesInLoopBody(const Loop &L) { @@ -324,7 +359,6 @@ static unsigned getMemInstAddressSpace(Value *I) { /// type is irregular if its allocated size doesn't equal the store size of an /// element of the corresponding vector type at the given vectorization factor. static bool hasIrregularType(Type *Ty, const DataLayout &DL, unsigned VF) { - // Determine if an array of VF elements of type Ty is "bitcast compatible" // with a <VF x Ty> vector. if (VF > 1) { @@ -349,7 +383,7 @@ static unsigned getReciprocalPredBlockProb() { return 2; } static Value *addFastMathFlag(Value *V) { if (isa<FPMathOperator>(V)) { FastMathFlags Flags; - Flags.setUnsafeAlgebra(); + Flags.setFast(); cast<Instruction>(V)->setFastMathFlags(Flags); } return V; @@ -362,6 +396,8 @@ static Constant *getSignedIntOrFpConstant(Type *Ty, int64_t C) { : ConstantFP::get(Ty, C); } +namespace llvm { + /// InnerLoopVectorizer vectorizes loops which contain only one basic /// block to a specified vectorization factor (VF). /// This class performs the widening of scalars into vectors, or multiple @@ -387,16 +423,16 @@ public: LoopVectorizationCostModel *CM) : OrigLoop(OrigLoop), PSE(PSE), LI(LI), DT(DT), TLI(TLI), TTI(TTI), AC(AC), ORE(ORE), VF(VecWidth), UF(UnrollFactor), - Builder(PSE.getSE()->getContext()), Induction(nullptr), - OldInduction(nullptr), VectorLoopValueMap(UnrollFactor, VecWidth), - TripCount(nullptr), VectorTripCount(nullptr), Legal(LVL), Cost(CM), - AddedSafetyChecks(false) {} + Builder(PSE.getSE()->getContext()), + VectorLoopValueMap(UnrollFactor, VecWidth), Legal(LVL), Cost(CM) {} + virtual ~InnerLoopVectorizer() = default; /// Create a new empty loop. Unlink the old loop and connect the new one. - void createVectorizedLoopSkeleton(); + /// Return the pre-header block of the new loop. + BasicBlock *createVectorizedLoopSkeleton(); - /// Vectorize a single instruction within the innermost loop. - void vectorizeInstruction(Instruction &I); + /// Widen a single instruction within the innermost loop. + void widenInstruction(Instruction &I); /// Fix the vectorized code, taking care of header phi's, live-outs, and more. void fixVectorizedLoop(); @@ -404,28 +440,83 @@ public: // Return true if any runtime check is added. bool areSafetyChecksAdded() { return AddedSafetyChecks; } - virtual ~InnerLoopVectorizer() {} - -protected: - /// A small list of PHINodes. - typedef SmallVector<PHINode *, 4> PhiVector; - /// A type for vectorized values in the new loop. Each value from the /// original loop, when vectorized, is represented by UF vector values in the /// new unrolled loop, where UF is the unroll factor. 
- typedef SmallVector<Value *, 2> VectorParts; + using VectorParts = SmallVector<Value *, 2>; + + /// Vectorize a single PHINode in a block. This method handles the induction + /// variable canonicalization. It supports both VF = 1 for unrolled loops and + /// arbitrary length vectors. + void widenPHIInstruction(Instruction *PN, unsigned UF, unsigned VF); + + /// A helper function to scalarize a single Instruction in the innermost loop. + /// Generates a sequence of scalar instances for each lane between \p MinLane + /// and \p MaxLane, times each part between \p MinPart and \p MaxPart, + /// inclusive.. + void scalarizeInstruction(Instruction *Instr, const VPIteration &Instance, + bool IfPredicateInstr); + + /// Widen an integer or floating-point induction variable \p IV. If \p Trunc + /// is provided, the integer induction variable will first be truncated to + /// the corresponding type. + void widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc = nullptr); + + /// getOrCreateVectorValue and getOrCreateScalarValue coordinate to generate a + /// vector or scalar value on-demand if one is not yet available. When + /// vectorizing a loop, we visit the definition of an instruction before its + /// uses. When visiting the definition, we either vectorize or scalarize the + /// instruction, creating an entry for it in the corresponding map. (In some + /// cases, such as induction variables, we will create both vector and scalar + /// entries.) Then, as we encounter uses of the definition, we derive values + /// for each scalar or vector use unless such a value is already available. + /// For example, if we scalarize a definition and one of its uses is vector, + /// we build the required vector on-demand with an insertelement sequence + /// when visiting the use. Otherwise, if the use is scalar, we can use the + /// existing scalar definition. + /// + /// Return a value in the new loop corresponding to \p V from the original + /// loop at unroll index \p Part. If the value has already been vectorized, + /// the corresponding vector entry in VectorLoopValueMap is returned. If, + /// however, the value has a scalar entry in VectorLoopValueMap, we construct + /// a new vector value on-demand by inserting the scalar values into a vector + /// with an insertelement sequence. If the value has been neither vectorized + /// nor scalarized, it must be loop invariant, so we simply broadcast the + /// value into a vector. + Value *getOrCreateVectorValue(Value *V, unsigned Part); + + /// Return a value in the new loop corresponding to \p V from the original + /// loop at unroll and vector indices \p Instance. If the value has been + /// vectorized but not scalarized, the necessary extractelement instruction + /// will be generated. + Value *getOrCreateScalarValue(Value *V, const VPIteration &Instance); + + /// Construct the vector value of a scalarized value \p V one lane at a time. + void packScalarIntoVectorValue(Value *V, const VPIteration &Instance); + + /// Try to vectorize the interleaved access group that \p Instr belongs to. + void vectorizeInterleaveGroup(Instruction *Instr); + + /// Vectorize Load and Store instructions, optionally masking the vector + /// operations if \p BlockInMask is non-null. + void vectorizeMemoryInstruction(Instruction *Instr, + VectorParts *BlockInMask = nullptr); + + /// \brief Set the debug location in the builder using the debug location in + /// the instruction. 
+ void setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr); + +protected: + friend class LoopVectorizationPlanner; + + /// A small list of PHINodes. + using PhiVector = SmallVector<PHINode *, 4>; /// A type for scalarized values in the new loop. Each value from the /// original loop, when scalarized, is represented by UF x VF scalar values /// in the new unrolled loop, where UF is the unroll factor and VF is the /// vectorization factor. - typedef SmallVector<SmallVector<Value *, 4>, 2> ScalarParts; - - // When we if-convert we need to create edge masks. We have to cache values - // so that we don't end up with exponential recursion/IR. - typedef DenseMap<std::pair<BasicBlock *, BasicBlock *>, VectorParts> - EdgeMaskCacheTy; - typedef DenseMap<BasicBlock *, VectorParts> BlockMaskCacheTy; + using ScalarParts = SmallVector<SmallVector<Value *, 4>, 2>; /// Set up the values of the IVs correctly when exiting the vector loop. void fixupIVUsers(PHINode *OrigPhi, const InductionDescriptor &II, @@ -457,40 +548,14 @@ protected: /// the block that was created for it. void sinkScalarOperands(Instruction *PredInst); - /// Predicate conditional instructions that require predication on their - /// respective conditions. - void predicateInstructions(); - /// Shrinks vector element sizes to the smallest bitwidth they can be legally /// represented as. void truncateToMinimalBitwidths(); - /// A helper function that computes the predicate of the block BB, assuming - /// that the header block of the loop is set to True. It returns the *entry* - /// mask for the block BB. - VectorParts createBlockInMask(BasicBlock *BB); - /// A helper function that computes the predicate of the edge between SRC - /// and DST. - VectorParts createEdgeMask(BasicBlock *Src, BasicBlock *Dst); - - /// Vectorize a single PHINode in a block. This method handles the induction - /// variable canonicalization. It supports both VF = 1 for unrolled loops and - /// arbitrary length vectors. - void widenPHIInstruction(Instruction *PN, unsigned UF, unsigned VF); - /// Insert the new loop to the loop hierarchy and pass manager /// and update the analysis passes. void updateAnalysis(); - /// This instruction is un-vectorizable. Implement it as a sequence - /// of scalars. If \p IfPredicateInstr is true we need to 'hide' each - /// scalarized instruction behind an if block predicated on the control - /// dependence of the instruction. - void scalarizeInstruction(Instruction *Instr, bool IfPredicateInstr = false); - - /// Vectorize Load and Store instructions, - virtual void vectorizeMemoryInstruction(Instruction *Instr); - /// Create a broadcast instruction. This method generates a broadcast /// instruction (shuffle) for loop invariant values and for the induction /// value. If this is the induction variable then we extend it to N, N+1, ... @@ -521,11 +586,6 @@ protected: void createVectorIntOrFpInductionPHI(const InductionDescriptor &II, Value *Step, Instruction *EntryVal); - /// Widen an integer or floating-point induction variable \p IV. If \p Trunc - /// is provided, the integer induction variable will first be truncated to - /// the corresponding type. - void widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc = nullptr); - /// Returns true if an instruction \p I should be scalarized instead of /// vectorized for the chosen vectorization factor. bool shouldScalarizeInstruction(Instruction *I) const; @@ -533,37 +593,19 @@ protected: /// Returns true if we should generate a scalar version of \p IV. 
bool needsScalarInduction(Instruction *IV) const; - /// getOrCreateVectorValue and getOrCreateScalarValue coordinate to generate a - /// vector or scalar value on-demand if one is not yet available. When - /// vectorizing a loop, we visit the definition of an instruction before its - /// uses. When visiting the definition, we either vectorize or scalarize the - /// instruction, creating an entry for it in the corresponding map. (In some - /// cases, such as induction variables, we will create both vector and scalar - /// entries.) Then, as we encounter uses of the definition, we derive values - /// for each scalar or vector use unless such a value is already available. - /// For example, if we scalarize a definition and one of its uses is vector, - /// we build the required vector on-demand with an insertelement sequence - /// when visiting the use. Otherwise, if the use is scalar, we can use the - /// existing scalar definition. - /// - /// Return a value in the new loop corresponding to \p V from the original - /// loop at unroll index \p Part. If the value has already been vectorized, - /// the corresponding vector entry in VectorLoopValueMap is returned. If, - /// however, the value has a scalar entry in VectorLoopValueMap, we construct - /// a new vector value on-demand by inserting the scalar values into a vector - /// with an insertelement sequence. If the value has been neither vectorized - /// nor scalarized, it must be loop invariant, so we simply broadcast the - /// value into a vector. - Value *getOrCreateVectorValue(Value *V, unsigned Part); - - /// Return a value in the new loop corresponding to \p V from the original - /// loop at unroll index \p Part and vector index \p Lane. If the value has - /// been vectorized but not scalarized, the necessary extractelement - /// instruction will be generated. - Value *getOrCreateScalarValue(Value *V, unsigned Part, unsigned Lane); - - /// Try to vectorize the interleaved access group that \p Instr belongs to. - void vectorizeInterleaveGroup(Instruction *Instr); + /// If there is a cast involved in the induction variable \p ID, which should + /// be ignored in the vectorized loop body, this function records the + /// VectorLoopValue of the respective Phi also as the VectorLoopValue of the + /// cast. We had already proved that the casted Phi is equal to the uncasted + /// Phi in the vectorized loop (under a runtime guard), and therefore + /// there is no need to vectorize the cast - the same value can be used in the + /// vector loop for both the Phi and the cast. + /// If \p VectorLoopValue is a scalarized value, \p Lane is also specified, + /// Otherwise, \p VectorLoopValue is a widened/vectorized value. + void recordVectorLoopValueForInductionCast (const InductionDescriptor &ID, + Value *VectorLoopValue, + unsigned Part, + unsigned Lane = UINT_MAX); /// Generate a shuffle sequence that will reverse the vector Vec. virtual Value *reverseVector(Value *Vec); @@ -574,12 +616,19 @@ protected: /// Returns (and creates if needed) the trip count of the widened loop. Value *getOrCreateVectorTripCount(Loop *NewLoop); + /// Returns a bitcasted value to the requested vector type. + /// Also handles bitcasts of vector<float> <-> vector<pointer> types. + Value *createBitOrPointerCast(Value *V, VectorType *DstVTy, + const DataLayout &DL); + /// Emit a bypass check to see if the vector trip count is zero, including if /// it overflows. 
void emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass); + /// Emit a bypass check to see if all of the SCEV assumptions we've /// had to make are correct. void emitSCEVChecks(Loop *L, BasicBlock *Bypass); + /// Emit bypass checks to check any memory assumptions we may have made. void emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass); @@ -601,149 +650,32 @@ protected: /// vector of instructions. void addMetadata(ArrayRef<Value *> To, Instruction *From); - /// \brief Set the debug location in the builder using the debug location in - /// the instruction. - void setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr); - - /// This is a helper class for maintaining vectorization state. It's used for - /// mapping values from the original loop to their corresponding values in - /// the new loop. Two mappings are maintained: one for vectorized values and - /// one for scalarized values. Vectorized values are represented with UF - /// vector values in the new loop, and scalarized values are represented with - /// UF x VF scalar values in the new loop. UF and VF are the unroll and - /// vectorization factors, respectively. - /// - /// Entries can be added to either map with setVectorValue and setScalarValue, - /// which assert that an entry was not already added before. If an entry is to - /// replace an existing one, call resetVectorValue. This is currently needed - /// to modify the mapped values during "fix-up" operations that occur once the - /// first phase of widening is complete. These operations include type - /// truncation and the second phase of recurrence widening. - /// - /// Entries from either map can be retrieved using the getVectorValue and - /// getScalarValue functions, which assert that the desired value exists. - - struct ValueMap { - - /// Construct an empty map with the given unroll and vectorization factors. - ValueMap(unsigned UF, unsigned VF) : UF(UF), VF(VF) {} - - /// \return True if the map has any vector entry for \p Key. - bool hasAnyVectorValue(Value *Key) const { - return VectorMapStorage.count(Key); - } - - /// \return True if the map has a vector entry for \p Key and \p Part. - bool hasVectorValue(Value *Key, unsigned Part) const { - assert(Part < UF && "Queried Vector Part is too large."); - if (!hasAnyVectorValue(Key)) - return false; - const VectorParts &Entry = VectorMapStorage.find(Key)->second; - assert(Entry.size() == UF && "VectorParts has wrong dimensions."); - return Entry[Part] != nullptr; - } - - /// \return True if the map has any scalar entry for \p Key. - bool hasAnyScalarValue(Value *Key) const { - return ScalarMapStorage.count(Key); - } - - /// \return True if the map has a scalar entry for \p Key, \p Part and - /// \p Part. - bool hasScalarValue(Value *Key, unsigned Part, unsigned Lane) const { - assert(Part < UF && "Queried Scalar Part is too large."); - assert(Lane < VF && "Queried Scalar Lane is too large."); - if (!hasAnyScalarValue(Key)) - return false; - const ScalarParts &Entry = ScalarMapStorage.find(Key)->second; - assert(Entry.size() == UF && "ScalarParts has wrong dimensions."); - assert(Entry[Part].size() == VF && "ScalarParts has wrong dimensions."); - return Entry[Part][Lane] != nullptr; - } - - /// Retrieve the existing vector value that corresponds to \p Key and - /// \p Part. 
- Value *getVectorValue(Value *Key, unsigned Part) { - assert(hasVectorValue(Key, Part) && "Getting non-existent value."); - return VectorMapStorage[Key][Part]; - } - - /// Retrieve the existing scalar value that corresponds to \p Key, \p Part - /// and \p Lane. - Value *getScalarValue(Value *Key, unsigned Part, unsigned Lane) { - assert(hasScalarValue(Key, Part, Lane) && "Getting non-existent value."); - return ScalarMapStorage[Key][Part][Lane]; - } - - /// Set a vector value associated with \p Key and \p Part. Assumes such a - /// value is not already set. If it is, use resetVectorValue() instead. - void setVectorValue(Value *Key, unsigned Part, Value *Vector) { - assert(!hasVectorValue(Key, Part) && "Vector value already set for part"); - if (!VectorMapStorage.count(Key)) { - VectorParts Entry(UF); - VectorMapStorage[Key] = Entry; - } - VectorMapStorage[Key][Part] = Vector; - } - - /// Set a scalar value associated with \p Key for \p Part and \p Lane. - /// Assumes such a value is not already set. - void setScalarValue(Value *Key, unsigned Part, unsigned Lane, - Value *Scalar) { - assert(!hasScalarValue(Key, Part, Lane) && "Scalar value already set"); - if (!ScalarMapStorage.count(Key)) { - ScalarParts Entry(UF); - for (unsigned Part = 0; Part < UF; ++Part) - Entry[Part].resize(VF, nullptr); - // TODO: Consider storing uniform values only per-part, as they occupy - // lane 0 only, keeping the other VF-1 redundant entries null. - ScalarMapStorage[Key] = Entry; - } - ScalarMapStorage[Key][Part][Lane] = Scalar; - } - - /// Reset the vector value associated with \p Key for the given \p Part. - /// This function can be used to update values that have already been - /// vectorized. This is the case for "fix-up" operations including type - /// truncation and the second phase of recurrence vectorization. - void resetVectorValue(Value *Key, unsigned Part, Value *Vector) { - assert(hasVectorValue(Key, Part) && "Vector value not set for part"); - VectorMapStorage[Key][Part] = Vector; - } - - private: - /// The unroll factor. Each entry in the vector map contains UF vector - /// values. - unsigned UF; - - /// The vectorization factor. Each entry in the scalar map contains UF x VF - /// scalar values. - unsigned VF; - - /// The vector and scalar map storage. We use std::map and not DenseMap - /// because insertions to DenseMap invalidate its iterators. - std::map<Value *, VectorParts> VectorMapStorage; - std::map<Value *, ScalarParts> ScalarMapStorage; - }; - /// The original loop. Loop *OrigLoop; + /// A wrapper around ScalarEvolution used to add runtime SCEV checks. Applies /// dynamic knowledge to simplify SCEV expressions and converts them to a /// more usable form. PredicatedScalarEvolution &PSE; + /// Loop Info. LoopInfo *LI; + /// Dominator Tree. DominatorTree *DT; + /// Alias Analysis. AliasAnalysis *AA; + /// Target Library Info. const TargetLibraryInfo *TLI; + /// Target Transform Info. const TargetTransformInfo *TTI; + /// Assumption Cache. AssumptionCache *AC; + /// Interface to emit optimization remarks. OptimizationRemarkEmitter *ORE; @@ -758,7 +690,6 @@ protected: /// vector elements. unsigned VF; -protected: /// The vectorization unroll factor to use. Each scalar is vectorized to this /// many different vector instructions. unsigned UF; @@ -770,39 +701,45 @@ protected: /// The vector-loop preheader. BasicBlock *LoopVectorPreHeader; + /// The scalar-loop preheader. BasicBlock *LoopScalarPreHeader; + /// Middle Block between the vector and the scalar. 
BasicBlock *LoopMiddleBlock; + /// The ExitBlock of the scalar loop. BasicBlock *LoopExitBlock; + /// The vector loop body. BasicBlock *LoopVectorBody; + /// The scalar loop body. BasicBlock *LoopScalarBody; + /// A list of all bypass blocks. The first block is the entry of the loop. SmallVector<BasicBlock *, 4> LoopBypassBlocks; /// The new Induction variable which was added to the new block. - PHINode *Induction; + PHINode *Induction = nullptr; + /// The induction variable of the old basic block. - PHINode *OldInduction; + PHINode *OldInduction = nullptr; /// Maps values from the original loop to their corresponding values in the /// vectorized loop. A key value can map to either vector values, scalar /// values or both kinds of values, depending on whether the key was /// vectorized and scalarized. - ValueMap VectorLoopValueMap; + VectorizerValueMap VectorLoopValueMap; + + /// Store instructions that were predicated. + SmallVector<Instruction *, 4> PredicatedInstructions; - /// Store instructions that should be predicated, as a pair - /// <StoreInst, Predicate> - SmallVector<std::pair<Instruction *, Value *>, 4> PredicatedInstructions; - EdgeMaskCacheTy EdgeMaskCache; - BlockMaskCacheTy BlockMaskCache; /// Trip count of the original loop. - Value *TripCount; + Value *TripCount = nullptr; + /// Trip count of the widened loop (TripCount - TripCount % (VF*UF)) - Value *VectorTripCount; + Value *VectorTripCount = nullptr; /// The legality analysis. LoopVectorizationLegality *Legal; @@ -811,7 +748,7 @@ protected: LoopVectorizationCostModel *Cost; // Record whether runtime checks are added. - bool AddedSafetyChecks; + bool AddedSafetyChecks = false; // Holds the end values for each induction variable. We save the end values // so we can later fix-up the external users of the induction variables. @@ -831,7 +768,6 @@ public: UnrollFactor, LVL, CM) {} private: - void vectorizeMemoryInstruction(Instruction *Instr) override; Value *getBroadcastInstrs(Value *V) override; Value *getStepVector(Value *Val, int StartIdx, Value *Step, Instruction::BinaryOps Opcode = @@ -839,6 +775,8 @@ private: Value *reverseVector(Value *Vec) override; }; +} // end namespace llvm + /// \brief Look for a meaningful debug location on the instruction or it's /// operands. static Instruction *getDebugLocFromInstOrOperands(Instruction *I) { @@ -861,7 +799,8 @@ static Instruction *getDebugLocFromInstOrOperands(Instruction *I) { void InnerLoopVectorizer::setDebugLocFromInst(IRBuilder<> &B, const Value *Ptr) { if (const Instruction *Inst = dyn_cast_or_null<Instruction>(Ptr)) { const DILocation *DIL = Inst->getDebugLoc(); - if (DIL && Inst->getFunction()->isDebugInfoForProfiling()) + if (DIL && Inst->getFunction()->isDebugInfoForProfiling() && + !isa<DbgInfoIntrinsic>(Inst)) B.SetCurrentDebugLocation(DIL->cloneWithDuplicationFactor(UF * VF)); else B.SetCurrentDebugLocation(DIL); @@ -908,6 +847,8 @@ void InnerLoopVectorizer::addMetadata(ArrayRef<Value *> To, } } +namespace llvm { + /// \brief The group of interleaved loads/stores sharing the same stride and /// close to each other. 
/// @@ -937,7 +878,7 @@ void InnerLoopVectorizer::addMetadata(ArrayRef<Value *> To, class InterleaveGroup { public: InterleaveGroup(Instruction *Instr, int Stride, unsigned Align) - : Align(Align), SmallestKey(0), LargestKey(0), InsertPos(Instr) { + : Align(Align), InsertPos(Instr) { assert(Align && "The alignment should be non-zero"); Factor = std::abs(Stride); @@ -1010,13 +951,26 @@ public: Instruction *getInsertPos() const { return InsertPos; } void setInsertPos(Instruction *Inst) { InsertPos = Inst; } + /// Add metadata (e.g. alias info) from the instructions in this group to \p + /// NewInst. + /// + /// FIXME: this function currently does not add noalias metadata a'la + /// addNewMedata. To do that we need to compute the intersection of the + /// noalias info from all members. + void addMetadata(Instruction *NewInst) const { + SmallVector<Value *, 4> VL; + std::transform(Members.begin(), Members.end(), std::back_inserter(VL), + [](std::pair<int, Instruction *> p) { return p.second; }); + propagateMetadata(NewInst, VL); + } + private: unsigned Factor; // Interleave Factor. bool Reverse; unsigned Align; DenseMap<int, Instruction *> Members; - int SmallestKey; - int LargestKey; + int SmallestKey = 0; + int LargestKey = 0; // To avoid breaking dependences, vectorized instructions of an interleave // group should be inserted at either the first load or the last store in @@ -1031,6 +985,9 @@ private: // store i32 %odd // Insert Position Instruction *InsertPos; }; +} // end namespace llvm + +namespace { /// \brief Drive the analysis of interleaved memory accesses in the loop. /// @@ -1044,8 +1001,7 @@ class InterleavedAccessInfo { public: InterleavedAccessInfo(PredicatedScalarEvolution &PSE, Loop *L, DominatorTree *DT, LoopInfo *LI) - : PSE(PSE), TheLoop(L), DT(DT), LI(LI), LAI(nullptr), - RequiresScalarEpilogue(false) {} + : PSE(PSE), TheLoop(L), DT(DT), LI(LI) {} ~InterleavedAccessInfo() { SmallSet<InterleaveGroup *, 4> DelSet; @@ -1065,14 +1021,6 @@ public: return InterleaveGroupMap.count(Instr); } - /// \brief Return the maximum interleave factor of all interleaved groups. - unsigned getMaxInterleaveFactor() const { - unsigned MaxFactor = 1; - for (auto &Entry : InterleaveGroupMap) - MaxFactor = std::max(MaxFactor, Entry.second->getFactor()); - return MaxFactor; - } - /// \brief Get the interleave group that \p Instr belongs to. /// /// \returns nullptr if doesn't have such group. @@ -1095,15 +1043,16 @@ private: /// The interleaved access analysis can also add new predicates (for example /// by versioning strides of pointers). PredicatedScalarEvolution &PSE; + Loop *TheLoop; DominatorTree *DT; LoopInfo *LI; - const LoopAccessInfo *LAI; + const LoopAccessInfo *LAI = nullptr; /// True if the loop may contain non-reversed interleaved groups with /// out-of-bounds accesses. We ensure we don't speculatively access memory /// out-of-bounds by executing at least one scalar epilogue iteration. - bool RequiresScalarEpilogue; + bool RequiresScalarEpilogue = false; /// Holds the relationships between the members and the interleave group. DenseMap<Instruction *, InterleaveGroup *> InterleaveGroupMap; @@ -1114,21 +1063,26 @@ private: /// \brief The descriptor for a strided memory access. struct StrideDescriptor { + StrideDescriptor() = default; StrideDescriptor(int64_t Stride, const SCEV *Scev, uint64_t Size, unsigned Align) : Stride(Stride), Scev(Scev), Size(Size), Align(Align) {} - StrideDescriptor() = default; - // The access's stride. It is negative for a reverse access. 
int64_t Stride = 0; - const SCEV *Scev = nullptr; // The scalar expression of this access - uint64_t Size = 0; // The size of the memory object. - unsigned Align = 0; // The alignment of this access. + + // The scalar expression of this access. + const SCEV *Scev = nullptr; + + // The size of the memory object. + uint64_t Size = 0; + + // The alignment of this access. + unsigned Align = 0; }; /// \brief A type for holding instructions and their stride descriptors. - typedef std::pair<Instruction *, StrideDescriptor> StrideEntry; + using StrideEntry = std::pair<Instruction *, StrideDescriptor>; /// \brief Create a new interleave group with the given instruction \p Instr, /// stride \p Stride and alignment \p Align. @@ -1179,7 +1133,6 @@ private: /// not necessary or is prevented because \p A and \p B may be dependent. bool canReorderMemAccessesForInterleavedGroups(StrideEntry *A, StrideEntry *B) const { - // Code motion for interleaved accesses can potentially hoist strided loads // and sink strided stores. The code below checks the legality of the // following two conditions: @@ -1244,7 +1197,7 @@ private: /// for example 'force', means a decision has been made. So, we need to be /// careful NOT to add them if the user hasn't specifically asked so. class LoopVectorizeHints { - enum HintKind { HK_WIDTH, HK_UNROLL, HK_FORCE }; + enum HintKind { HK_WIDTH, HK_UNROLL, HK_FORCE, HK_ISVECTORIZED }; /// Hint - associates name and validation with the hint value. struct Hint { @@ -1263,6 +1216,8 @@ class LoopVectorizeHints { return isPowerOf2_32(Val) && Val <= MaxInterleaveFactor; case HK_FORCE: return (Val <= 1); + case HK_ISVECTORIZED: + return (Val==0 || Val==1); } return false; } @@ -1270,16 +1225,21 @@ class LoopVectorizeHints { /// Vectorization width. Hint Width; + /// Vectorization interleave factor. Hint Interleave; + /// Vectorization forced Hint Force; + /// Already Vectorized + Hint IsVectorized; + /// Return the loop metadata prefix. static StringRef Prefix() { return "llvm.loop."; } /// True if there is any unsafe math in the loop. - bool PotentiallyUnsafe; + bool PotentiallyUnsafe = false; public: enum ForceKind { @@ -1294,7 +1254,7 @@ public: HK_WIDTH), Interleave("interleave.count", DisableInterleaving, HK_UNROLL), Force("vectorize.enable", FK_Undefined, HK_FORCE), - PotentiallyUnsafe(false), TheLoop(L), ORE(ORE) { + IsVectorized("isvectorized", 0, HK_ISVECTORIZED), TheLoop(L), ORE(ORE) { // Populate values with existing loop metadata. getHintsFromMetadata(); @@ -1302,14 +1262,19 @@ public: if (VectorizerParams::isInterleaveForced()) Interleave.Value = VectorizerParams::VectorizationInterleave; + if (IsVectorized.Value != 1) + // If the vectorization width and interleaving count are both 1 then + // consider the loop to have been already vectorized because there's + // nothing more that we can do. + IsVectorized.Value = Width.Value == 1 && Interleave.Value == 1; DEBUG(if (DisableInterleaving && Interleave.Value == 1) dbgs() << "LV: Interleaving disabled by the pass manager\n"); } /// Mark the loop L as already vectorized by setting the width to 1. void setAlreadyVectorized() { - Width.Value = Interleave.Value = 1; - Hint Hints[] = {Width, Interleave}; + IsVectorized.Value = 1; + Hint Hints[] = {IsVectorized}; writeHintsToMetadata(Hints); } @@ -1326,19 +1291,19 @@ public: return false; } - if (getWidth() == 1 && getInterleave() == 1) { - // FIXME: Add a separate metadata to indicate when the loop has already - // been vectorized instead of setting width and count to 1. 
+ if (getIsVectorized() == 1) { DEBUG(dbgs() << "LV: Not vectorizing: Disabled/already vectorized.\n"); // FIXME: Add interleave.disable metadata. This will allow // vectorize.disable to be used without disabling the pass and errors // to differentiate between disabled vectorization and a width of 1. - ORE.emit(OptimizationRemarkAnalysis(vectorizeAnalysisPassName(), + ORE.emit([&]() { + return OptimizationRemarkAnalysis(vectorizeAnalysisPassName(), "AllDisabled", L->getStartLoc(), L->getHeader()) << "loop not vectorized: vectorization and interleaving are " - "explicitly disabled, or vectorize width and interleave " - "count are both set to 1"); + "explicitly disabled, or the loop has already been " + "vectorized"; + }); return false; } @@ -1348,29 +1313,35 @@ public: /// Dumps all the hint information. void emitRemarkWithHints() const { using namespace ore; - if (Force.Value == LoopVectorizeHints::FK_Disabled) - ORE.emit(OptimizationRemarkMissed(LV_NAME, "MissedExplicitlyDisabled", + + ORE.emit([&]() { + if (Force.Value == LoopVectorizeHints::FK_Disabled) + return OptimizationRemarkMissed(LV_NAME, "MissedExplicitlyDisabled", TheLoop->getStartLoc(), TheLoop->getHeader()) - << "loop not vectorized: vectorization is explicitly disabled"); - else { - OptimizationRemarkMissed R(LV_NAME, "MissedDetails", - TheLoop->getStartLoc(), TheLoop->getHeader()); - R << "loop not vectorized"; - if (Force.Value == LoopVectorizeHints::FK_Enabled) { - R << " (Force=" << NV("Force", true); - if (Width.Value != 0) - R << ", Vector Width=" << NV("VectorWidth", Width.Value); - if (Interleave.Value != 0) - R << ", Interleave Count=" << NV("InterleaveCount", Interleave.Value); - R << ")"; + << "loop not vectorized: vectorization is explicitly disabled"; + else { + OptimizationRemarkMissed R(LV_NAME, "MissedDetails", + TheLoop->getStartLoc(), + TheLoop->getHeader()); + R << "loop not vectorized"; + if (Force.Value == LoopVectorizeHints::FK_Enabled) { + R << " (Force=" << NV("Force", true); + if (Width.Value != 0) + R << ", Vector Width=" << NV("VectorWidth", Width.Value); + if (Interleave.Value != 0) + R << ", Interleave Count=" + << NV("InterleaveCount", Interleave.Value); + R << ")"; + } + return R; } - ORE.emit(R); - } + }); } unsigned getWidth() const { return Width.Value; } unsigned getInterleave() const { return Interleave.Value; } + unsigned getIsVectorized() const { return IsVectorized.Value; } enum ForceKind getForce() const { return (ForceKind)Force.Value; } /// \brief If hints are provided that force vectorization, use the AlwaysPrint @@ -1454,7 +1425,7 @@ private: return; unsigned Val = C->getZExtValue(); - Hint *Hints[] = {&Width, &Interleave, &Force}; + Hint *Hints[] = {&Width, &Interleave, &Force, &IsVectorized}; for (auto H : Hints) { if (Name == H->Name) { if (H->validate(Val)) @@ -1489,7 +1460,7 @@ private: /// Sets current hints into loop metadata, keeping other values intact. void writeHintsToMetadata(ArrayRef<Hint> HintTypes) { - if (HintTypes.size() == 0) + if (HintTypes.empty()) return; // Reserve the first element to LoopID (see below). @@ -1525,6 +1496,8 @@ private: OptimizationRemarkEmitter &ORE; }; +} // end anonymous namespace + static void emitMissedWarning(Function *F, Loop *L, const LoopVectorizeHints &LH, OptimizationRemarkEmitter *ORE) { @@ -1546,6 +1519,8 @@ static void emitMissedWarning(Function *F, Loop *L, } } +namespace { + /// LoopVectorizationLegality checks if it is legal to vectorize a loop, and /// to what vectorization factor. 
/// This class does not look at the profitability of vectorization, only the @@ -1568,22 +1543,20 @@ public: std::function<const LoopAccessInfo &(Loop &)> *GetLAA, LoopInfo *LI, OptimizationRemarkEmitter *ORE, LoopVectorizationRequirements *R, LoopVectorizeHints *H) - : NumPredStores(0), TheLoop(L), PSE(PSE), TLI(TLI), TTI(TTI), DT(DT), - GetLAA(GetLAA), LAI(nullptr), ORE(ORE), InterleaveInfo(PSE, L, DT, LI), - PrimaryInduction(nullptr), WidestIndTy(nullptr), HasFunNoNaNAttr(false), - Requirements(R), Hints(H) {} + : TheLoop(L), PSE(PSE), TLI(TLI), TTI(TTI), DT(DT), GetLAA(GetLAA), + ORE(ORE), InterleaveInfo(PSE, L, DT, LI), Requirements(R), Hints(H) {} /// ReductionList contains the reduction descriptors for all /// of the reductions that were found in the loop. - typedef DenseMap<PHINode *, RecurrenceDescriptor> ReductionList; + using ReductionList = DenseMap<PHINode *, RecurrenceDescriptor>; /// InductionList saves induction variables and maps them to the /// induction descriptor. - typedef MapVector<PHINode *, InductionDescriptor> InductionList; + using InductionList = MapVector<PHINode *, InductionDescriptor>; /// RecurrenceSet contains the phi nodes that are recurrences other than /// inductions and reductions. - typedef SmallPtrSet<const PHINode *, 8> RecurrenceSet; + using RecurrenceSet = SmallPtrSet<const PHINode *, 8>; /// Returns true if it is legal to vectorize this loop. /// This does not mean that it is profitable to vectorize this @@ -1608,7 +1581,17 @@ public: /// Returns the widest induction type. Type *getWidestInductionType() { return WidestIndTy; } - /// Returns True if V is an induction variable in this loop. + /// Returns True if V is a Phi node of an induction variable in this loop. + bool isInductionPhi(const Value *V); + + /// Returns True if V is a cast that is part of an induction def-use chain, + /// and had been proven to be redundant under a runtime guard (in other + /// words, the cast has the same SCEV expression as the induction phi). + bool isCastedInductionVariable(const Value *V); + + /// Returns True if V can be considered as an induction variable in this + /// loop. V can be the induction phi, or some redundant cast in the def-use + /// chain of the inducion phi. bool isInductionVariable(const Value *V); /// Returns True if PN is a reduction variable in this loop. @@ -1629,6 +1612,8 @@ public: /// 0 - Stride is unknown or non-consecutive. /// 1 - Address is consecutive. /// -1 - Address is consecutive, and decreasing. + /// NOTE: This method must only be used before modifying the original scalar + /// loop. Do not use after invoking 'createVectorizedLoopSkeleton' (PR34965). int isConsecutivePtr(Value *Ptr); /// Returns true if the value V is uniform within the loop. @@ -1646,11 +1631,6 @@ public: return InterleaveInfo.isInterleaved(Instr); } - /// \brief Return the maximum interleave factor of all interleaved groups. - unsigned getMaxInterleaveFactor() const { - return InterleaveInfo.getMaxInterleaveFactor(); - } - /// \brief Get the interleaved access group that \p Instr belongs to. 
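The remark hunks above all follow the same conversion: ORE.emit(<remark>) becomes ORE.emit([&]() { return <remark>; }). A minimal model of why the callable form is preferred is sketched below; the types are illustrative stand-ins for the real OptimizationRemarkEmitter, whose point is that the remark (and any string formatting it performs) is only constructed when emission is actually enabled.

#include <iostream>
#include <string>

namespace sketch {

struct Remark {
  std::string Text;
};

class RemarkEmitter {
  bool Enabled;

public:
  explicit RemarkEmitter(bool Enabled) : Enabled(Enabled) {}

  // Eager form: the caller always pays for building the remark.
  void emit(const Remark &R) {
    if (Enabled)
      std::cout << R.Text << '\n';
  }

  // Lazy form: the remark is only built when emission is enabled.
  template <typename MakerT> void emit(MakerT Maker) {
    if (Enabled)
      emit(Maker());
  }
};

} // namespace sketch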
const InterleaveGroup *getInterleavedAccessGroup(Instruction *Instr) { return InterleaveInfo.getInterleaveGroup(Instr); @@ -1664,6 +1644,10 @@ public: unsigned getMaxSafeDepDistBytes() { return LAI->getMaxSafeDepDistBytes(); } + uint64_t getMaxSafeRegisterWidth() const { + return LAI->getDepChecker().getMaxSafeRegisterWidth(); + } + bool hasStride(Value *V) { return LAI->hasStride(V); } /// Returns true if the target machine supports masked store operation @@ -1671,21 +1655,25 @@ public: bool isLegalMaskedStore(Type *DataType, Value *Ptr) { return isConsecutivePtr(Ptr) && TTI->isLegalMaskedStore(DataType); } + /// Returns true if the target machine supports masked load operation /// for the given \p DataType and kind of access to \p Ptr. bool isLegalMaskedLoad(Type *DataType, Value *Ptr) { return isConsecutivePtr(Ptr) && TTI->isLegalMaskedLoad(DataType); } + /// Returns true if the target machine supports masked scatter operation /// for the given \p DataType. bool isLegalMaskedScatter(Type *DataType) { return TTI->isLegalMaskedScatter(DataType); } + /// Returns true if the target machine supports masked gather operation /// for the given \p DataType. bool isLegalMaskedGather(Type *DataType) { return TTI->isLegalMaskedGather(DataType); } + /// Returns true if the target machine can represent \p V as a masked gather /// or scatter operation. bool isLegalGatherOrScatter(Value *V) { @@ -1701,6 +1689,7 @@ public: /// Returns true if vector representation of the instruction \p I /// requires mask. bool isMaskRequired(const Instruction *I) { return (MaskedOp.count(I) != 0); } + unsigned getNumStores() const { return LAI->getNumStores(); } unsigned getNumLoads() const { return LAI->getNumLoads(); } unsigned getNumPredStores() const { return NumPredStores; } @@ -1766,27 +1755,34 @@ private: return LAI ? &LAI->getSymbolicStrides() : nullptr; } - unsigned NumPredStores; + unsigned NumPredStores = 0; /// The loop that we evaluate. Loop *TheLoop; + /// A wrapper around ScalarEvolution used to add runtime SCEV checks. /// Applies dynamic knowledge to simplify SCEV expressions in the context /// of existing SCEV assumptions. The analysis will also add a minimal set /// of new predicates if this is required to enable vectorization and /// unrolling. PredicatedScalarEvolution &PSE; + /// Target Library Info. TargetLibraryInfo *TLI; + /// Target Transform Info const TargetTransformInfo *TTI; + /// Dominator Tree. DominatorTree *DT; + // LoopAccess analysis. std::function<const LoopAccessInfo &(Loop &)> *GetLAA; + // And the loop-accesses info corresponding to this loop. This pointer is // null until canVectorizeMemory sets it up. - const LoopAccessInfo *LAI; + const LoopAccessInfo *LAI = nullptr; + /// Interface to emit optimization remarks. OptimizationRemarkEmitter *ORE; @@ -1798,27 +1794,38 @@ private: /// Holds the primary induction variable. This is the counter of the /// loop. - PHINode *PrimaryInduction; + PHINode *PrimaryInduction = nullptr; + /// Holds the reduction variables. ReductionList Reductions; + /// Holds all of the induction variables that we found in the loop. /// Notice that inductions don't need to start at zero and that induction /// variables can be pointers. InductionList Inductions; + + /// Holds all the casts that participate in the update chain of the induction + /// variables, and that have been proven to be redundant (possibly under a + /// runtime guard). These casts can be ignored when creating the vectorized + /// loop body. 
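getMaxSafeRegisterWidth() above exposes the widest vector register the loop's memory dependences can tolerate. A standalone sketch of how such a bound translates into a maximum vectorization factor follows; it assumes both quantities are expressed in bits, and the helper name is illustrative rather than part of the pass.

#include <cstdint>

namespace sketch {

// Largest power-of-2 VF such that VF * ElementBits <= MaxSafeRegisterWidthBits.
inline unsigned maxSafeVF(uint64_t MaxSafeRegisterWidthBits,
                          unsigned ElementBits) {
  unsigned VF = 1;
  while (static_cast<uint64_t>(VF) * 2 * ElementBits <= MaxSafeRegisterWidthBits)
    VF *= 2;
  return VF;
}

} // namespace sketch
// e.g. maxSafeVF(128, 32) == 4 and maxSafeVF(64, 32) == 2.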
+ SmallPtrSet<Instruction *, 4> InductionCastsToIgnore; + /// Holds the phi nodes that are first-order recurrences. RecurrenceSet FirstOrderRecurrences; + /// Holds instructions that need to sink past other instructions to handle /// first-order recurrences. DenseMap<Instruction *, Instruction *> SinkAfter; + /// Holds the widest induction type encountered. - Type *WidestIndTy; + Type *WidestIndTy = nullptr; /// Allowed outside users. This holds the induction and reduction /// vars which can be accessed from outside the loop. SmallPtrSet<Value *, 4> AllowedExit; /// Can we assume the absence of NaNs. - bool HasFunNoNaNAttr; + bool HasFunNoNaNAttr = false; /// Vectorization requirements that will go through late-evaluation. LoopVectorizationRequirements *Requirements; @@ -1856,9 +1863,13 @@ public: /// Information about vectorization costs struct VectorizationFactor { - unsigned Width; // Vector width with best cost - unsigned Cost; // Cost of the loop with that width + // Vector width with best cost + unsigned Width; + + // Cost of the loop with that width + unsigned Cost; }; + /// \return The most profitable vectorization factor and the cost of that VF. /// This method checks every power of two up to MaxVF. If UserVF is not ZERO /// then this vectorization factor will be selected if vectorization is @@ -1897,8 +1908,10 @@ public: struct RegisterUsage { /// Holds the number of loop invariant values that are used in the loop. unsigned LoopInvariantRegs; + /// Holds the maximum number of concurrent live intervals in the loop. unsigned MaxLocalUsers; + /// Holds the number of instructions in the loop. unsigned NumInstructions; }; @@ -1920,6 +1933,7 @@ public: /// \returns True if it is more profitable to scalarize instruction \p I for /// vectorization factor \p VF. bool isProfitableToScalarize(Instruction *I, unsigned VF) const { + assert(VF > 1 && "Profitable to scalarize relevant only for VF > 1."); auto Scalars = InstsToScalarize.find(VF); assert(Scalars != InstsToScalarize.end() && "VF not yet analyzed for scalarization profitability"); @@ -1954,7 +1968,8 @@ public: /// Decision that was taken during cost calculation for memory instruction. enum InstWidening { CM_Unknown, - CM_Widen, + CM_Widen, // For consecutive accesses with stride +1. + CM_Widen_Reverse, // For consecutive accesses with stride -1. CM_Interleave, CM_GatherScatter, CM_Scalarize @@ -2010,7 +2025,6 @@ public: /// is an induction variable. Such a truncate will be removed by adding a new /// induction variable with the destination type. bool isOptimizableIVTruncate(Instruction *I, unsigned VF) { - // If the instruction is not a truncate, return false. auto *Trunc = dyn_cast<TruncInst>(I); if (!Trunc) @@ -2030,13 +2044,29 @@ public: return false; // If the truncated value is not an induction variable, return false. - return Legal->isInductionVariable(Op); + return Legal->isInductionPhi(Op); + } + + /// Collects the instructions to scalarize for each predicated instruction in + /// the loop. + void collectInstsToScalarize(unsigned VF); + + /// Collect Uniform and Scalar values for the given \p VF. + /// The sets depend on CM decision for Load/Store instructions + /// that may be vectorized as interleave, gather-scatter or scalarized. + void collectUniformsAndScalars(unsigned VF) { + // Do the analysis once. + if (VF == 1 || Uniforms.count(VF)) + return; + setCostBasedWideningDecision(VF); + collectLoopUniforms(VF); + collectLoopScalars(VF); } private: /// \return An upper bound for the vectorization factor, larger than zero. 
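The widened InstWidening enum above now distinguishes CM_Widen (consecutive, stride +1) from CM_Widen_Reverse (consecutive, stride -1), which is what vectorizeMemoryInstruction later keys off instead of recomputing the stride. A simplified decision table is sketched below; the real cost model also weighs interleaving and per-VF costs, so this is illustrative only.

namespace sketch {

enum InstWidening {
  CM_Unknown,
  CM_Widen,         // consecutive access, stride +1
  CM_Widen_Reverse, // consecutive access, stride -1
  CM_Interleave,
  CM_GatherScatter,
  CM_Scalarize
};

// ConsecutiveStride is the result of isConsecutivePtr(): +1, -1, or 0.
inline InstWidening classifyMemAccess(int ConsecutiveStride,
                                      bool LegalGatherScatter) {
  if (ConsecutiveStride == 1)
    return CM_Widen;
  if (ConsecutiveStride == -1)
    return CM_Widen_Reverse;
  return LegalGatherScatter ? CM_GatherScatter : CM_Scalarize;
}

} // namespace sketch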
/// One is returned if vectorization should best be avoided due to cost. - unsigned computeFeasibleMaxVF(bool OptForSize); + unsigned computeFeasibleMaxVF(bool OptForSize, unsigned ConstTripCount); /// The vectorization cost is a combination of the cost itself and a boolean /// indicating whether any of the contributing operations will actually @@ -2045,7 +2075,7 @@ private: /// is /// false, then all operations will be scalarized (i.e. no vectorization has /// actually taken place). - typedef std::pair<unsigned, bool> VectorizationCostTy; + using VectorizationCostTy = std::pair<unsigned, bool>; /// Returns the expected execution cost. The unit of the cost does /// not matter because we use the 'cost' units to compare different @@ -2102,7 +2132,7 @@ private: /// A type representing the costs for instructions if they were to be /// scalarized rather than vectorized. The entries are Instruction-Cost /// pairs. - typedef DenseMap<Instruction *, unsigned> ScalarCostsTy; + using ScalarCostsTy = DenseMap<Instruction *, unsigned>; /// A set containing all BasicBlocks that are known to present after /// vectorization as a predicated block. @@ -2134,10 +2164,6 @@ private: int computePredInstDiscount(Instruction *PredInst, ScalarCostsTy &ScalarCosts, unsigned VF); - /// Collects the instructions to scalarize for each predicated instruction in - /// the loop. - void collectInstsToScalarize(unsigned VF); - /// Collect the instructions that are uniform after vectorization. An /// instruction is uniform if we represent it with a single scalar value in /// the vectorized loop corresponding to each vector iteration. Examples of @@ -2156,72 +2182,137 @@ private: /// iteration of the original scalar loop. void collectLoopScalars(unsigned VF); - /// Collect Uniform and Scalar values for the given \p VF. - /// The sets depend on CM decision for Load/Store instructions - /// that may be vectorized as interleave, gather-scatter or scalarized. - void collectUniformsAndScalars(unsigned VF) { - // Do the analysis once. - if (VF == 1 || Uniforms.count(VF)) - return; - setCostBasedWideningDecision(VF); - collectLoopUniforms(VF); - collectLoopScalars(VF); - } - /// Keeps cost model vectorization decision and cost for instructions. /// Right now it is used for memory instructions only. - typedef DenseMap<std::pair<Instruction *, unsigned>, - std::pair<InstWidening, unsigned>> - DecisionList; + using DecisionList = DenseMap<std::pair<Instruction *, unsigned>, + std::pair<InstWidening, unsigned>>; DecisionList WideningDecisions; public: /// The loop that we evaluate. Loop *TheLoop; + /// Predicated scalar evolution analysis. PredicatedScalarEvolution &PSE; + /// Loop Info analysis. LoopInfo *LI; + /// Vectorization legality. LoopVectorizationLegality *Legal; + /// Vector target information. const TargetTransformInfo &TTI; + /// Target Library Info. const TargetLibraryInfo *TLI; + /// Demanded bits analysis. DemandedBits *DB; + /// Assumption cache. AssumptionCache *AC; + /// Interface to emit optimization remarks. OptimizationRemarkEmitter *ORE; const Function *TheFunction; + /// Loop Vectorize Hint. const LoopVectorizeHints *Hints; + /// Values to ignore in the cost model. SmallPtrSet<const Value *, 16> ValuesToIgnore; + /// Values to ignore in the cost model when VF > 1. 
SmallPtrSet<const Value *, 16> VecValuesToIgnore; }; +} // end anonymous namespace + +namespace llvm { + +/// InnerLoopVectorizer vectorizes loops which contain only one basic /// LoopVectorizationPlanner - drives the vectorization process after having /// passed Legality checks. +/// The planner builds and optimizes the Vectorization Plans which record the +/// decisions how to vectorize the given loop. In particular, represent the +/// control-flow of the vectorized version, the replication of instructions that +/// are to be scalarized, and interleave access groups. class LoopVectorizationPlanner { + /// The loop that we evaluate. + Loop *OrigLoop; + + /// Loop Info analysis. + LoopInfo *LI; + + /// Target Library Info. + const TargetLibraryInfo *TLI; + + /// Target Transform Info. + const TargetTransformInfo *TTI; + + /// The legality analysis. + LoopVectorizationLegality *Legal; + + /// The profitablity analysis. + LoopVectorizationCostModel &CM; + + using VPlanPtr = std::unique_ptr<VPlan>; + + SmallVector<VPlanPtr, 4> VPlans; + + /// This class is used to enable the VPlan to invoke a method of ILV. This is + /// needed until the method is refactored out of ILV and becomes reusable. + struct VPCallbackILV : public VPCallback { + InnerLoopVectorizer &ILV; + + VPCallbackILV(InnerLoopVectorizer &ILV) : ILV(ILV) {} + + Value *getOrCreateVectorValues(Value *V, unsigned Part) override { + return ILV.getOrCreateVectorValue(V, Part); + } + }; + + /// A builder used to construct the current plan. + VPBuilder Builder; + + /// When we if-convert we need to create edge masks. We have to cache values + /// so that we don't end up with exponential recursion/IR. Note that + /// if-conversion currently takes place during VPlan-construction, so these + /// caches are only used at that stage. + using EdgeMaskCacheTy = + DenseMap<std::pair<BasicBlock *, BasicBlock *>, VPValue *>; + using BlockMaskCacheTy = DenseMap<BasicBlock *, VPValue *>; + EdgeMaskCacheTy EdgeMaskCache; + BlockMaskCacheTy BlockMaskCache; + + unsigned BestVF = 0; + unsigned BestUF = 0; + public: - LoopVectorizationPlanner(Loop *OrigLoop, LoopInfo *LI, + LoopVectorizationPlanner(Loop *L, LoopInfo *LI, const TargetLibraryInfo *TLI, + const TargetTransformInfo *TTI, LoopVectorizationLegality *Legal, LoopVectorizationCostModel &CM) - : OrigLoop(OrigLoop), LI(LI), Legal(Legal), CM(CM) {} - - ~LoopVectorizationPlanner() {} + : OrigLoop(L), LI(LI), TLI(TLI), TTI(TTI), Legal(Legal), CM(CM) {} /// Plan how to best vectorize, return the best VF and its cost. LoopVectorizationCostModel::VectorizationFactor plan(bool OptForSize, unsigned UserVF); - /// Generate the IR code for the vectorized loop. - void executePlan(InnerLoopVectorizer &ILV); + /// Finalize the best decision and dispose of all other VPlans. + void setBestPlan(unsigned VF, unsigned UF); + + /// Generate the IR code for the body of the vectorized loop according to the + /// best selected VPlan. + void executePlan(InnerLoopVectorizer &LB, DominatorTree *DT); + + void printPlans(raw_ostream &O) { + for (const auto &Plan : VPlans) + O << *Plan; + } protected: /// Collect the instructions from the original loop that would be trivially @@ -2229,20 +2320,102 @@ protected: void collectTriviallyDeadInstructions( SmallPtrSetImpl<Instruction *> &DeadInstructions); -private: - /// The loop that we evaluate. - Loop *OrigLoop; + /// A range of powers-of-2 vectorization factors with fixed start and + /// adjustable end. 
The range includes start and excludes end, e.g.,: + /// [1, 9) = {1, 2, 4, 8} + struct VFRange { + // A power of 2. + const unsigned Start; - /// Loop Info analysis. - LoopInfo *LI; + // Need not be a power of 2. If End <= Start range is empty. + unsigned End; + }; - /// The legality analysis. - LoopVectorizationLegality *Legal; + /// Test a \p Predicate on a \p Range of VF's. Return the value of applying + /// \p Predicate on Range.Start, possibly decreasing Range.End such that the + /// returned value holds for the entire \p Range. + bool getDecisionAndClampRange(const std::function<bool(unsigned)> &Predicate, + VFRange &Range); - /// The profitablity analysis. - LoopVectorizationCostModel &CM; + /// Build VPlans for power-of-2 VF's between \p MinVF and \p MaxVF inclusive, + /// according to the information gathered by Legal when it checked if it is + /// legal to vectorize the loop. + void buildVPlans(unsigned MinVF, unsigned MaxVF); + +private: + /// A helper function that computes the predicate of the block BB, assuming + /// that the header block of the loop is set to True. It returns the *entry* + /// mask for the block BB. + VPValue *createBlockInMask(BasicBlock *BB, VPlanPtr &Plan); + + /// A helper function that computes the predicate of the edge between SRC + /// and DST. + VPValue *createEdgeMask(BasicBlock *Src, BasicBlock *Dst, VPlanPtr &Plan); + + /// Check if \I belongs to an Interleave Group within the given VF \p Range, + /// \return true in the first returned value if so and false otherwise. + /// Build a new VPInterleaveGroup Recipe if \I is the primary member of an IG + /// for \p Range.Start, and provide it as the second returned value. + /// Note that if \I is an adjunct member of an IG for \p Range.Start, the + /// \return value is <true, nullptr>, as it is handled by another recipe. + /// \p Range.End may be decreased to ensure same decision from \p Range.Start + /// to \p Range.End. + VPInterleaveRecipe *tryToInterleaveMemory(Instruction *I, VFRange &Range); + + // Check if \I is a memory instruction to be widened for \p Range.Start and + // potentially masked. Such instructions are handled by a recipe that takes an + // additional VPInstruction for the mask. + VPWidenMemoryInstructionRecipe *tryToWidenMemory(Instruction *I, + VFRange &Range, + VPlanPtr &Plan); + + /// Check if an induction recipe should be constructed for \I within the given + /// VF \p Range. If so build and return it. If not, return null. \p Range.End + /// may be decreased to ensure same decision from \p Range.Start to + /// \p Range.End. + VPWidenIntOrFpInductionRecipe *tryToOptimizeInduction(Instruction *I, + VFRange &Range); + + /// Handle non-loop phi nodes. Currently all such phi nodes are turned into + /// a sequence of select instructions as the vectorizer currently performs + /// full if-conversion. + VPBlendRecipe *tryToBlend(Instruction *I, VPlanPtr &Plan); + + /// Check if \p I can be widened within the given VF \p Range. If \p I can be + /// widened for \p Range.Start, check if the last recipe of \p VPBB can be + /// extended to include \p I or else build a new VPWidenRecipe for it and + /// append it to \p VPBB. Return true if \p I can be widened for Range.Start, + /// false otherwise. Range.End may be decreased to ensure same decision from + /// \p Range.Start to \p Range.End. + bool tryToWiden(Instruction *I, VPBasicBlock *VPBB, VFRange &Range); + + /// Build a VPReplicationRecipe for \p I and enclose it within a Region if it + /// is predicated. 
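The getDecisionAndClampRange contract documented above (answer the predicate for Range.Start, then shrink Range.End so the answer holds for every VF in the range) can be realized as follows. This is a sketch consistent with that comment, not a quote of the pass's own definition.

#include <functional>

namespace sketch {

struct VFRange {
  const unsigned Start; // a power of 2
  unsigned End;         // exclusive; need not be a power of 2
};

inline bool getDecisionAndClampRange(
    const std::function<bool(unsigned)> &Predicate, VFRange &Range) {
  bool PredicateAtRangeStart = Predicate(Range.Start);
  // Clamp the range at the first power-of-2 VF whose answer differs, so a
  // single recipe can be reused for the whole remaining range.
  for (unsigned TmpVF = Range.Start * 2; TmpVF < Range.End; TmpVF *= 2)
    if (Predicate(TmpVF) != PredicateAtRangeStart) {
      Range.End = TmpVF;
      break;
    }
  return PredicateAtRangeStart;
}

} // namespace sketch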
\return \p VPBB augmented with this new recipe if \p I is + /// not predicated, otherwise \return a new VPBasicBlock that succeeds the new + /// Region. Update the packing decision of predicated instructions if they + /// feed \p I. Range.End may be decreased to ensure same recipe behavior from + /// \p Range.Start to \p Range.End. + VPBasicBlock *handleReplication( + Instruction *I, VFRange &Range, VPBasicBlock *VPBB, + DenseMap<Instruction *, VPReplicateRecipe *> &PredInst2Recipe, + VPlanPtr &Plan); + + /// Create a replicating region for instruction \p I that requires + /// predication. \p PredRecipe is a VPReplicateRecipe holding \p I. + VPRegionBlock *createReplicateRegion(Instruction *I, VPRecipeBase *PredRecipe, + VPlanPtr &Plan); + + /// Build a VPlan according to the information gathered by Legal. \return a + /// VPlan for vectorization factors \p Range.Start and up to \p Range.End + /// exclusive, possibly decreasing \p Range.End. + VPlanPtr buildVPlan(VFRange &Range, + const SmallPtrSetImpl<Value *> &NeedDef); }; +} // end namespace llvm + +namespace { + /// \brief This holds vectorization requirements that must be verified late in /// the process. The requirements are set by legalize and costmodel. Once /// vectorization has been determined to be possible and profitable the @@ -2257,8 +2430,7 @@ private: /// followed by a non-expert user. class LoopVectorizationRequirements { public: - LoopVectorizationRequirements(OptimizationRemarkEmitter &ORE) - : NumRuntimePointerChecks(0), UnsafeAlgebraInst(nullptr), ORE(ORE) {} + LoopVectorizationRequirements(OptimizationRemarkEmitter &ORE) : ORE(ORE) {} void addUnsafeAlgebraInst(Instruction *I) { // First unsafe algebra instruction. @@ -2272,12 +2444,14 @@ public: const char *PassName = Hints.vectorizeAnalysisPassName(); bool Failed = false; if (UnsafeAlgebraInst && !Hints.allowReordering()) { - ORE.emit( - OptimizationRemarkAnalysisFPCommute(PassName, "CantReorderFPOps", - UnsafeAlgebraInst->getDebugLoc(), - UnsafeAlgebraInst->getParent()) - << "loop not vectorized: cannot prove it is safe to reorder " - "floating-point operations"); + ORE.emit([&]() { + return OptimizationRemarkAnalysisFPCommute( + PassName, "CantReorderFPOps", + UnsafeAlgebraInst->getDebugLoc(), + UnsafeAlgebraInst->getParent()) + << "loop not vectorized: cannot prove it is safe to reorder " + "floating-point operations"; + }); Failed = true; } @@ -2288,11 +2462,13 @@ public: NumRuntimePointerChecks > VectorizerParams::RuntimeMemoryCheckThreshold; if ((ThresholdReached && !Hints.allowReordering()) || PragmaThresholdReached) { - ORE.emit(OptimizationRemarkAnalysisAliasing(PassName, "CantReorderMemOps", + ORE.emit([&]() { + return OptimizationRemarkAnalysisAliasing(PassName, "CantReorderMemOps", L->getStartLoc(), L->getHeader()) << "loop not vectorized: cannot prove it is safe to reorder " - "memory operations"); + "memory operations"; + }); DEBUG(dbgs() << "LV: Too many memory checks needed.\n"); Failed = true; } @@ -2301,13 +2477,15 @@ public: } private: - unsigned NumRuntimePointerChecks; - Instruction *UnsafeAlgebraInst; + unsigned NumRuntimePointerChecks = 0; + Instruction *UnsafeAlgebraInst = nullptr; /// Interface to emit optimization remarks. 
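The late-evaluation checks in LoopVectorizationRequirements above boil down to a small predicate over the gathered facts. A condensed model is sketched below; the parameter names stand in for the VectorizerParams thresholds, and the real code additionally emits the remarks shown in the hunks.

namespace sketch {

inline bool lateChecksFail(bool HasUnsafeFPReassoc, bool AllowReordering,
                           unsigned NumRuntimePointerChecks,
                           unsigned PragmaThreshold,
                           unsigned DefaultThreshold) {
  if (HasUnsafeFPReassoc && !AllowReordering)
    return true; // cannot prove FP reordering is safe
  if (NumRuntimePointerChecks > PragmaThreshold)
    return true; // too many runtime memory checks even for a forced loop
  if (NumRuntimePointerChecks > DefaultThreshold && !AllowReordering)
    return true; // too many runtime memory checks for an un-forced loop
  return false;
}

} // namespace sketch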
OptimizationRemarkEmitter &ORE; }; +} // end anonymous namespace + static void addAcyclicInnerLoop(Loop &L, SmallVectorImpl<Loop *> &V) { if (L.empty()) { if (!hasCyclesInLoopBody(L)) @@ -2318,11 +2496,15 @@ static void addAcyclicInnerLoop(Loop &L, SmallVectorImpl<Loop *> &V) { addAcyclicInnerLoop(*InnerL, V); } +namespace { + /// The LoopVectorize Pass. struct LoopVectorize : public FunctionPass { /// Pass identification, replacement for typeid static char ID; + LoopVectorizePass Impl; + explicit LoopVectorize(bool NoUnrolling = false, bool AlwaysVectorize = true) : FunctionPass(ID) { Impl.DisableUnrolling = NoUnrolling; @@ -2330,8 +2512,6 @@ struct LoopVectorize : public FunctionPass { initializeLoopVectorizePass(*PassRegistry::getPassRegistry()); } - LoopVectorizePass Impl; - bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; @@ -2450,6 +2630,7 @@ void InnerLoopVectorizer::createVectorIntOrFpInductionPHI( Instruction *LastInduction = VecInd; for (unsigned Part = 0; Part < UF; ++Part) { VectorLoopValueMap.setVectorValue(EntryVal, Part, LastInduction); + recordVectorLoopValueForInductionCast(II, LastInduction, Part); if (isa<TruncInst>(EntryVal)) addMetadata(LastInduction, EntryVal); LastInduction = cast<Instruction>(addFastMathFlag( @@ -2480,11 +2661,26 @@ bool InnerLoopVectorizer::needsScalarInduction(Instruction *IV) const { auto *I = cast<Instruction>(U); return (OrigLoop->contains(I) && shouldScalarizeInstruction(I)); }; - return any_of(IV->users(), isScalarInst); + return llvm::any_of(IV->users(), isScalarInst); } -void InnerLoopVectorizer::widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc) { +void InnerLoopVectorizer::recordVectorLoopValueForInductionCast( + const InductionDescriptor &ID, Value *VectorLoopVal, unsigned Part, + unsigned Lane) { + const SmallVectorImpl<Instruction *> &Casts = ID.getCastInsts(); + if (Casts.empty()) + return; + // Only the first Cast instruction in the Casts vector is of interest. + // The rest of the Casts (if exist) have no uses outside the + // induction update chain itself. + Instruction *CastInst = *Casts.begin(); + if (Lane < UINT_MAX) + VectorLoopValueMap.setScalarValue(CastInst, {Part, Lane}, VectorLoopVal); + else + VectorLoopValueMap.setVectorValue(CastInst, Part, VectorLoopVal); +} +void InnerLoopVectorizer::widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc) { assert((IV->getType()->isIntegerTy() || IV != OldInduction) && "Primary induction variable must have an integer type"); @@ -2564,6 +2760,7 @@ void InnerLoopVectorizer::widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc) { Value *EntryPart = getStepVector(Broadcasted, VF * Part, Step, ID.getInductionOpcode()); VectorLoopValueMap.setVectorValue(EntryVal, Part, EntryPart); + recordVectorLoopValueForInductionCast(ID, EntryPart, Part); if (Trunc) addMetadata(EntryPart, Trunc); } @@ -2622,7 +2819,7 @@ Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step, // Floating point operations had to be 'fast' to enable the induction. FastMathFlags Flags; - Flags.setUnsafeAlgebra(); + Flags.setFast(); Value *MulOp = Builder.CreateFMul(Cv, Step); if (isa<Instruction>(MulOp)) @@ -2638,7 +2835,6 @@ Value *InnerLoopVectorizer::getStepVector(Value *Val, int StartIdx, Value *Step, void InnerLoopVectorizer::buildScalarSteps(Value *ScalarIV, Value *Step, Value *EntryVal, const InductionDescriptor &ID) { - // We shouldn't have to build scalar steps if we aren't vectorizing. 
assert(VF > 1 && "VF should be greater than one"); @@ -2663,21 +2859,21 @@ void InnerLoopVectorizer::buildScalarSteps(Value *ScalarIV, Value *Step, // iteration. If EntryVal is uniform, we only need to generate the first // lane. Otherwise, we generate all VF values. unsigned Lanes = - Cost->isUniformAfterVectorization(cast<Instruction>(EntryVal), VF) ? 1 : VF; - + Cost->isUniformAfterVectorization(cast<Instruction>(EntryVal), VF) ? 1 + : VF; // Compute the scalar steps and save the results in VectorLoopValueMap. for (unsigned Part = 0; Part < UF; ++Part) { for (unsigned Lane = 0; Lane < Lanes; ++Lane) { auto *StartIdx = getSignedIntOrFpConstant(ScalarIVTy, VF * Part + Lane); auto *Mul = addFastMathFlag(Builder.CreateBinOp(MulOp, StartIdx, Step)); auto *Add = addFastMathFlag(Builder.CreateBinOp(AddOp, ScalarIV, Mul)); - VectorLoopValueMap.setScalarValue(EntryVal, Part, Lane, Add); + VectorLoopValueMap.setScalarValue(EntryVal, {Part, Lane}, Add); + recordVectorLoopValueForInductionCast(ID, Add, Part, Lane); } } } int LoopVectorizationLegality::isConsecutivePtr(Value *Ptr) { - const ValueToValueMap &Strides = getSymbolicStrides() ? *getSymbolicStrides() : ValueToValueMap(); @@ -2708,8 +2904,7 @@ Value *InnerLoopVectorizer::getOrCreateVectorValue(Value *V, unsigned Part) { // instead. If it has been scalarized, and we actually need the value in // vector form, we will construct the vector values on demand. if (VectorLoopValueMap.hasAnyScalarValue(V)) { - - Value *ScalarValue = VectorLoopValueMap.getScalarValue(V, Part, 0); + Value *ScalarValue = VectorLoopValueMap.getScalarValue(V, {Part, 0}); // If we've scalarized a value, that value should be an instruction. auto *I = cast<Instruction>(V); @@ -2726,8 +2921,8 @@ Value *InnerLoopVectorizer::getOrCreateVectorValue(Value *V, unsigned Part) { // of the Part unroll iteration. Otherwise, the last instruction is the one // we created for the last vector lane of the Part unroll iteration. unsigned LastLane = Cost->isUniformAfterVectorization(I, VF) ? 0 : VF - 1; - auto *LastInst = - cast<Instruction>(VectorLoopValueMap.getScalarValue(V, Part, LastLane)); + auto *LastInst = cast<Instruction>( + VectorLoopValueMap.getScalarValue(V, {Part, LastLane})); // Set the insert point after the last scalarized instruction. This ensures // the insertelement sequence will directly follow the scalar definitions. @@ -2744,14 +2939,15 @@ Value *InnerLoopVectorizer::getOrCreateVectorValue(Value *V, unsigned Part) { Value *VectorValue = nullptr; if (Cost->isUniformAfterVectorization(I, VF)) { VectorValue = getBroadcastInstrs(ScalarValue); + VectorLoopValueMap.setVectorValue(V, Part, VectorValue); } else { - VectorValue = UndefValue::get(VectorType::get(V->getType(), VF)); + // Initialize packing with insertelements to start from undef. 
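buildScalarSteps above materializes, for each unroll part and lane, the value ScalarIV + (VF * Part + Lane) * Step, producing only lane 0 when the value is uniform after vectorization. A standalone arithmetic model of that loop nest (integer induction only; the helper name is illustrative):

#include <cstdint>
#include <vector>

namespace sketch {

inline std::vector<std::vector<int64_t>>
scalarSteps(int64_t ScalarIV, int64_t Step, unsigned VF, unsigned UF,
            bool UniformAfterVectorization) {
  // Uniform values only need lane 0; otherwise all VF lanes are produced.
  unsigned Lanes = UniformAfterVectorization ? 1 : VF;
  std::vector<std::vector<int64_t>> Parts(UF);
  for (unsigned Part = 0; Part < UF; ++Part)
    for (unsigned Lane = 0; Lane < Lanes; ++Lane)
      Parts[Part].push_back(ScalarIV + int64_t(VF * Part + Lane) * Step);
  return Parts;
}

} // namespace sketch
// e.g. scalarSteps(0, 1, /*VF=*/4, /*UF=*/2, false) -> {{0,1,2,3}, {4,5,6,7}}.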
+ Value *Undef = UndefValue::get(VectorType::get(V->getType(), VF)); + VectorLoopValueMap.setVectorValue(V, Part, Undef); for (unsigned Lane = 0; Lane < VF; ++Lane) - VectorValue = Builder.CreateInsertElement( - VectorValue, getOrCreateScalarValue(V, Part, Lane), - Builder.getInt32(Lane)); + packScalarIntoVectorValue(V, {Part, Lane}); + VectorValue = VectorLoopValueMap.getVectorValue(V, Part); } - VectorLoopValueMap.setVectorValue(V, Part, VectorValue); Builder.restoreIP(OldIP); return VectorValue; } @@ -2763,28 +2959,29 @@ Value *InnerLoopVectorizer::getOrCreateVectorValue(Value *V, unsigned Part) { return B; } -Value *InnerLoopVectorizer::getOrCreateScalarValue(Value *V, unsigned Part, - unsigned Lane) { - +Value * +InnerLoopVectorizer::getOrCreateScalarValue(Value *V, + const VPIteration &Instance) { // If the value is not an instruction contained in the loop, it should // already be scalar. if (OrigLoop->isLoopInvariant(V)) return V; - assert(Lane > 0 ? !Cost->isUniformAfterVectorization(cast<Instruction>(V), VF) - : true && "Uniform values only have lane zero"); + assert(Instance.Lane > 0 + ? !Cost->isUniformAfterVectorization(cast<Instruction>(V), VF) + : true && "Uniform values only have lane zero"); // If the value from the original loop has not been vectorized, it is // represented by UF x VF scalar values in the new loop. Return the requested // scalar value. - if (VectorLoopValueMap.hasScalarValue(V, Part, Lane)) - return VectorLoopValueMap.getScalarValue(V, Part, Lane); + if (VectorLoopValueMap.hasScalarValue(V, Instance)) + return VectorLoopValueMap.getScalarValue(V, Instance); // If the value has not been scalarized, get its entry in VectorLoopValueMap // for the given unroll part. If this entry is not a vector type (i.e., the // vectorization factor is one), there is no need to generate an // extractelement instruction. - auto *U = getOrCreateVectorValue(V, Part); + auto *U = getOrCreateVectorValue(V, Instance.Part); if (!U->getType()->isVectorTy()) { assert(VF == 1 && "Value not scalarized has non-vector type"); return U; @@ -2793,7 +2990,20 @@ Value *InnerLoopVectorizer::getOrCreateScalarValue(Value *V, unsigned Part, // Otherwise, the value from the original loop has been vectorized and is // represented by UF vector values. Extract and return the requested scalar // value from the appropriate vector lane. - return Builder.CreateExtractElement(U, Builder.getInt32(Lane)); + return Builder.CreateExtractElement(U, Builder.getInt32(Instance.Lane)); +} + +void InnerLoopVectorizer::packScalarIntoVectorValue( + Value *V, const VPIteration &Instance) { + assert(V != Induction && "The new induction variable should not be used."); + assert(!V->getType()->isVectorTy() && "Can't pack a vector"); + assert(!V->getType()->isVoidTy() && "Type does not produce a value"); + + Value *ScalarInst = VectorLoopValueMap.getScalarValue(V, Instance); + Value *VectorValue = VectorLoopValueMap.getVectorValue(V, Instance.Part); + VectorValue = Builder.CreateInsertElement(VectorValue, ScalarInst, + Builder.getInt32(Instance.Lane)); + VectorLoopValueMap.resetVectorValue(V, Instance.Part, VectorValue); } Value *InnerLoopVectorizer::reverseVector(Value *Vec) { @@ -2843,6 +3053,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { if (Instr != Group->getInsertPos()) return; + const DataLayout &DL = Instr->getModule()->getDataLayout(); Value *Ptr = getPointerOperand(Instr); // Prepare for the vector type of the interleaved load/store. 
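The hunks above replace the (Part, Lane) argument pairs with a VPIteration key and add packScalarIntoVectorValue, which rebuilds a vector value lane by lane starting from undef. A container-level model of that bookkeeping is sketched below; the map types and class name are illustrative, with the per-part "vector" reduced to a VF-sized array.

#include <map>
#include <utility>
#include <vector>

namespace sketch {

struct VPIteration {
  unsigned Part;
  unsigned Lane;
};

template <typename T> class ValueMapModel {
  std::map<std::pair<unsigned, unsigned>, T> Scalars; // keyed by (Part, Lane)
  std::map<unsigned, std::vector<T>> Vectors;         // keyed by Part

public:
  void setScalarValue(const VPIteration &I, T V) {
    Scalars[{I.Part, I.Lane}] = V;
  }

  T getScalarValue(const VPIteration &I) const {
    return Scalars.at({I.Part, I.Lane});
  }

  // The IR version starts from an undef vector and inserts each lane; here
  // the vector is simply filled in element by element.
  void packScalarIntoVectorValue(const VPIteration &I, unsigned VF) {
    std::vector<T> &Vec = Vectors[I.Part];
    if (Vec.size() < VF)
      Vec.resize(VF);
    Vec[I.Lane] = getScalarValue(I);
  }

  const std::vector<T> &getVectorValue(unsigned Part) const {
    return Vectors.at(Part);
  }
};

} // namespace sketch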
@@ -2866,7 +3077,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { Index += (VF - 1) * Group->getFactor(); for (unsigned Part = 0; Part < UF; Part++) { - Value *NewPtr = getOrCreateScalarValue(Ptr, Part, 0); + Value *NewPtr = getOrCreateScalarValue(Ptr, {Part, 0}); // Notice current instruction could be any index. Need to adjust the address // to the member of index 0. @@ -2890,13 +3101,12 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { // Vectorize the interleaved load group. if (isa<LoadInst>(Instr)) { - // For each unroll part, create a wide load for the group. SmallVector<Value *, 2> NewLoads; for (unsigned Part = 0; Part < UF; Part++) { auto *NewLoad = Builder.CreateAlignedLoad( NewPtrs[Part], Group->getAlignment(), "wide.vec"); - addMetadata(NewLoad, Instr); + Group->addMetadata(NewLoad); NewLoads.push_back(NewLoad); } @@ -2917,7 +3127,7 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { // If this member has different type, cast the result type. if (Member->getType() != ScalarTy) { VectorType *OtherVTy = VectorType::get(Member->getType(), VF); - StridedVec = Builder.CreateBitOrPointerCast(StridedVec, OtherVTy); + StridedVec = createBitOrPointerCast(StridedVec, OtherVTy, DL); } if (Group->isReverse()) @@ -2946,9 +3156,10 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { if (Group->isReverse()) StoredVec = reverseVector(StoredVec); - // If this member has different type, cast it to an unified type. + // If this member has different type, cast it to a unified type. + if (StoredVec->getType() != SubVT) - StoredVec = Builder.CreateBitOrPointerCast(StoredVec, SubVT); + StoredVec = createBitOrPointerCast(StoredVec, SubVT, DL); StoredVecs.push_back(StoredVec); } @@ -2963,11 +3174,13 @@ void InnerLoopVectorizer::vectorizeInterleaveGroup(Instruction *Instr) { Instruction *NewStoreInstr = Builder.CreateAlignedStore(IVec, NewPtrs[Part], Group->getAlignment()); - addMetadata(NewStoreInstr, Instr); + + Group->addMetadata(NewStoreInstr); } } -void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { +void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr, + VectorParts *BlockInMask) { // Attempt to issue a wide load. LoadInst *LI = dyn_cast<LoadInst>(Instr); StoreInst *SI = dyn_cast<StoreInst>(Instr); @@ -2992,14 +3205,11 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { Alignment = DL.getABITypeAlignment(ScalarDataTy); unsigned AddressSpace = getMemInstAddressSpace(Instr); - // Scalarize the memory instruction if necessary. - if (Decision == LoopVectorizationCostModel::CM_Scalarize) - return scalarizeInstruction(Instr, Legal->isScalarWithPredication(Instr)); - // Determine if the pointer operand of the access is either consecutive or // reverse consecutive. - int ConsecutiveStride = Legal->isConsecutivePtr(Ptr); - bool Reverse = ConsecutiveStride < 0; + bool Reverse = (Decision == LoopVectorizationCostModel::CM_Widen_Reverse); + bool ConsecutiveStride = + Reverse || (Decision == LoopVectorizationCostModel::CM_Widen); bool CreateGatherScatter = (Decision == LoopVectorizationCostModel::CM_GatherScatter); @@ -3010,9 +3220,13 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { // Handle consecutive loads/stores. 
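In vectorizeInterleaveGroup above, each member of an interleave group is extracted from the wide load with a strided shuffle. The small helper below shows the lane indices such a mask selects for member Index of a group with factor Factor across VF iterations; it is illustrative, the pass builds the equivalent mask as IR constants.

#include <vector>

namespace sketch {

inline std::vector<unsigned> strideMask(unsigned Index, unsigned Factor,
                                        unsigned VF) {
  std::vector<unsigned> Mask;
  for (unsigned Lane = 0; Lane < VF; ++Lane)
    Mask.push_back(Index + Lane * Factor);
  return Mask;
}

} // namespace sketch
// e.g. strideMask(/*Index=*/1, /*Factor=*/3, /*VF=*/4) -> {1, 4, 7, 10},
// i.e. the second member of an A/B/C interleaving across four iterations.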
if (ConsecutiveStride) - Ptr = getOrCreateScalarValue(Ptr, 0, 0); + Ptr = getOrCreateScalarValue(Ptr, {0, 0}); + + VectorParts Mask; + bool isMaskRequired = BlockInMask; + if (isMaskRequired) + Mask = *BlockInMask; - VectorParts Mask = createBlockInMask(Instr->getParent()); // Handle Stores: if (SI) { assert(!Legal->isUniform(SI->getPointerOperand()) && @@ -3023,7 +3237,7 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { Instruction *NewSI = nullptr; Value *StoredVal = getOrCreateVectorValue(SI->getValueOperand(), Part); if (CreateGatherScatter) { - Value *MaskPart = Legal->isMaskRequired(SI) ? Mask[Part] : nullptr; + Value *MaskPart = isMaskRequired ? Mask[Part] : nullptr; Value *VectorGep = getOrCreateVectorValue(Ptr, Part); NewSI = Builder.CreateMaskedScatter(StoredVal, VectorGep, Alignment, MaskPart); @@ -3045,13 +3259,14 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(-Part * VF)); PartPtr = Builder.CreateGEP(nullptr, PartPtr, Builder.getInt32(1 - VF)); - Mask[Part] = reverseVector(Mask[Part]); + if (isMaskRequired) // Reverse of a null all-one mask is a null mask. + Mask[Part] = reverseVector(Mask[Part]); } Value *VecPtr = Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); - if (Legal->isMaskRequired(SI)) + if (isMaskRequired) NewSI = Builder.CreateMaskedStore(StoredVal, VecPtr, Alignment, Mask[Part]); else @@ -3068,7 +3283,7 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { for (unsigned Part = 0; Part < UF; ++Part) { Value *NewLI; if (CreateGatherScatter) { - Value *MaskPart = Legal->isMaskRequired(LI) ? Mask[Part] : nullptr; + Value *MaskPart = isMaskRequired ? Mask[Part] : nullptr; Value *VectorGep = getOrCreateVectorValue(Ptr, Part); NewLI = Builder.CreateMaskedGather(VectorGep, Alignment, MaskPart, nullptr, "wide.masked.gather"); @@ -3083,12 +3298,13 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { // wide load needs to start at the last vector element. PartPtr = Builder.CreateGEP(nullptr, Ptr, Builder.getInt32(-Part * VF)); PartPtr = Builder.CreateGEP(nullptr, PartPtr, Builder.getInt32(1 - VF)); - Mask[Part] = reverseVector(Mask[Part]); + if (isMaskRequired) // Reverse of a null all-one mask is a null mask. + Mask[Part] = reverseVector(Mask[Part]); } Value *VecPtr = Builder.CreateBitCast(PartPtr, DataTy->getPointerTo(AddressSpace)); - if (Legal->isMaskRequired(LI)) + if (isMaskRequired) NewLI = Builder.CreateMaskedLoad(VecPtr, Alignment, Mask[Part], UndefValue::get(DataTy), "wide.masked.load"); @@ -3105,71 +3321,41 @@ void InnerLoopVectorizer::vectorizeMemoryInstruction(Instruction *Instr) { } void InnerLoopVectorizer::scalarizeInstruction(Instruction *Instr, + const VPIteration &Instance, bool IfPredicateInstr) { assert(!Instr->getType()->isAggregateType() && "Can't handle vectors"); - DEBUG(dbgs() << "LV: Scalarizing" - << (IfPredicateInstr ? " and predicating:" : ":") << *Instr - << '\n'); - // Holds vector parameters or scalars, in case of uniform vals. - SmallVector<VectorParts, 4> Params; setDebugLocFromInst(Builder, Instr); // Does this instruction return a value ? bool IsVoidRetTy = Instr->getType()->isVoidTy(); - VectorParts Cond; - if (IfPredicateInstr) - Cond = createBlockInMask(Instr->getParent()); - - // Determine the number of scalars we need to generate for each unroll - // iteration. If the instruction is uniform, we only need to generate the - // first lane. 
Otherwise, we generate all VF values. - unsigned Lanes = Cost->isUniformAfterVectorization(Instr, VF) ? 1 : VF; - - // For each vector unroll 'part': - for (unsigned Part = 0; Part < UF; ++Part) { - // For each scalar that we create: - for (unsigned Lane = 0; Lane < Lanes; ++Lane) { + Instruction *Cloned = Instr->clone(); + if (!IsVoidRetTy) + Cloned->setName(Instr->getName() + ".cloned"); - // Start if-block. - Value *Cmp = nullptr; - if (IfPredicateInstr) { - Cmp = Cond[Part]; - if (Cmp->getType()->isVectorTy()) - Cmp = Builder.CreateExtractElement(Cmp, Builder.getInt32(Lane)); - Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Cmp, - ConstantInt::get(Cmp->getType(), 1)); - } - - Instruction *Cloned = Instr->clone(); - if (!IsVoidRetTy) - Cloned->setName(Instr->getName() + ".cloned"); - - // Replace the operands of the cloned instructions with their scalar - // equivalents in the new loop. - for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { - auto *NewOp = getOrCreateScalarValue(Instr->getOperand(op), Part, Lane); - Cloned->setOperand(op, NewOp); - } - addNewMetadata(Cloned, Instr); + // Replace the operands of the cloned instructions with their scalar + // equivalents in the new loop. + for (unsigned op = 0, e = Instr->getNumOperands(); op != e; ++op) { + auto *NewOp = getOrCreateScalarValue(Instr->getOperand(op), Instance); + Cloned->setOperand(op, NewOp); + } + addNewMetadata(Cloned, Instr); - // Place the cloned scalar in the new loop. - Builder.Insert(Cloned); + // Place the cloned scalar in the new loop. + Builder.Insert(Cloned); - // Add the cloned scalar to the scalar map entry. - VectorLoopValueMap.setScalarValue(Instr, Part, Lane, Cloned); + // Add the cloned scalar to the scalar map entry. + VectorLoopValueMap.setScalarValue(Instr, Instance, Cloned); - // If we just cloned a new assumption, add it the assumption cache. - if (auto *II = dyn_cast<IntrinsicInst>(Cloned)) - if (II->getIntrinsicID() == Intrinsic::assume) - AC->registerAssumption(II); + // If we just cloned a new assumption, add it the assumption cache. + if (auto *II = dyn_cast<IntrinsicInst>(Cloned)) + if (II->getIntrinsicID() == Intrinsic::assume) + AC->registerAssumption(II); - // End if-block. - if (IfPredicateInstr) - PredicatedInstructions.push_back(std::make_pair(Cloned, Cmp)); - } - } + // End if-block. + if (IfPredicateInstr) + PredicatedInstructions.push_back(Cloned); } PHINode *InnerLoopVectorizer::createInductionVariable(Loop *L, Value *Start, @@ -3281,6 +3467,36 @@ Value *InnerLoopVectorizer::getOrCreateVectorTripCount(Loop *L) { return VectorTripCount; } +Value *InnerLoopVectorizer::createBitOrPointerCast(Value *V, VectorType *DstVTy, + const DataLayout &DL) { + // Verify that V is a vector type with same number of elements as DstVTy. + unsigned VF = DstVTy->getNumElements(); + VectorType *SrcVecTy = cast<VectorType>(V->getType()); + assert((VF == SrcVecTy->getNumElements()) && "Vector dimensions do not match"); + Type *SrcElemTy = SrcVecTy->getElementType(); + Type *DstElemTy = DstVTy->getElementType(); + assert((DL.getTypeSizeInBits(SrcElemTy) == DL.getTypeSizeInBits(DstElemTy)) && + "Vector elements must have same size"); + + // Do a direct cast if element types are castable. + if (CastInst::isBitOrNoopPointerCastable(SrcElemTy, DstElemTy, DL)) { + return Builder.CreateBitOrPointerCast(V, DstVTy); + } + // V cannot be directly casted to desired vector type. + // May happen when V is a floating point vector but DstVTy is a vector of + // pointers or vice-versa. 
Handle this using a two-step bitcast using an + // intermediate Integer type for the bitcast i.e. Ptr <-> Int <-> Float. + assert((DstElemTy->isPointerTy() != SrcElemTy->isPointerTy()) && + "Only one type should be a pointer type"); + assert((DstElemTy->isFloatingPointTy() != SrcElemTy->isFloatingPointTy()) && + "Only one type should be a floating point type"); + Type *IntTy = + IntegerType::getIntNTy(V->getContext(), DL.getTypeSizeInBits(SrcElemTy)); + VectorType *VecIntTy = VectorType::get(IntTy, VF); + Value *CastVal = Builder.CreateBitOrPointerCast(V, VecIntTy); + return Builder.CreateBitOrPointerCast(CastVal, DstVTy); +} + void InnerLoopVectorizer::emitMinimumIterationCountCheck(Loop *L, BasicBlock *Bypass) { Value *Count = getOrCreateTripCount(L); @@ -3373,7 +3589,7 @@ void InnerLoopVectorizer::emitMemRuntimeChecks(Loop *L, BasicBlock *Bypass) { LVer->prepareNoAliasMetadata(); } -void InnerLoopVectorizer::createVectorizedLoopSkeleton() { +BasicBlock *InnerLoopVectorizer::createVectorizedLoopSkeleton() { /* In this function we generate a new loop. The new loop will contain the vectorized instructions while the old loop will continue to run the @@ -3435,7 +3651,7 @@ void InnerLoopVectorizer::createVectorizedLoopSkeleton() { MiddleBlock->splitBasicBlock(MiddleBlock->getTerminator(), "scalar.ph"); // Create and register the new vector loop. - Loop *Lp = new Loop(); + Loop *Lp = LI->AllocateLoop(); Loop *ParentLoop = OrigLoop->getParentLoop(); // Insert the new loop into the loop nest and register the new basic blocks @@ -3554,6 +3770,8 @@ void InnerLoopVectorizer::createVectorizedLoopSkeleton() { LoopVectorizeHints Hints(Lp, true, *ORE); Hints.setAlreadyVectorized(); + + return LoopVectorPreHeader; } // Fix up external users of the induction variable. At this point, we are @@ -3622,22 +3840,27 @@ void InnerLoopVectorizer::fixupIVUsers(PHINode *OrigPhi, } namespace { + struct CSEDenseMapInfo { static bool canHandle(const Instruction *I) { return isa<InsertElementInst>(I) || isa<ExtractElementInst>(I) || isa<ShuffleVectorInst>(I) || isa<GetElementPtrInst>(I); } + static inline Instruction *getEmptyKey() { return DenseMapInfo<Instruction *>::getEmptyKey(); } + static inline Instruction *getTombstoneKey() { return DenseMapInfo<Instruction *>::getTombstoneKey(); } + static unsigned getHashValue(const Instruction *I) { assert(canHandle(I) && "Unknown instruction!"); return hash_combine(I->getOpcode(), hash_combine_range(I->value_op_begin(), I->value_op_end())); } + static bool isEqual(const Instruction *LHS, const Instruction *RHS) { if (LHS == getEmptyKey() || RHS == getEmptyKey() || LHS == getTombstoneKey() || RHS == getTombstoneKey()) @@ -3645,7 +3868,8 @@ struct CSEDenseMapInfo { return LHS->isIdenticalTo(RHS); } }; -} + +} // end anonymous namespace ///\brief Perform cse of induction variable instructions. static void cse(BasicBlock *BB) { @@ -3777,7 +4001,6 @@ void InnerLoopVectorizer::truncateToMinimalBitwidths() { // For every instruction `I` in MinBWs, truncate the operands, create a // truncated version of `I` and reextend its result. InstCombine runs // later and will remove any ext/trunc pairs. 
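createBitOrPointerCast above falls back to a two-step cast when one operand is a vector of pointers and the other a vector of floating-point values. A hedged usage sketch of the resulting chain for <2 x double> to <2 x i8*> follows; it assumes a 64-bit pointer size and an IRBuilder already positioned inside a function, and the helper name is illustrative.

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"

static llvm::Value *castDoubleVecToPtrVec(llvm::IRBuilder<> &Builder,
                                          llvm::Value *V /* <2 x double> */) {
  using namespace llvm;
  LLVMContext &Ctx = Builder.getContext();
  // Intermediate integer vector with the same element size (64 bits).
  VectorType *VecI64Ty = VectorType::get(Type::getInt64Ty(Ctx), 2);
  VectorType *VecPtrTy = VectorType::get(Type::getInt8PtrTy(Ctx), 2);
  Value *AsInt = Builder.CreateBitOrPointerCast(V, VecI64Ty); // bitcast
  return Builder.CreateBitOrPointerCast(AsInt, VecPtrTy);     // inttoptr
}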
- // SmallPtrSet<Value *, 4> Erased; for (const auto &KV : Cost->getMinimalBitwidths()) { // If the value wasn't vectorized, we must maintain the original scalar @@ -3927,7 +4150,8 @@ void InnerLoopVectorizer::fixVectorizedLoop() { IVEndValues[Entry.first], LoopMiddleBlock); fixLCSSAPHIs(); - predicateInstructions(); + for (Instruction *PI : PredicatedInstructions) + sinkScalarOperands(&*PI); // Remove redundant induction instructions. cse(LoopVectorBody); @@ -3953,7 +4177,6 @@ void InnerLoopVectorizer::fixCrossIterationPHIs() { } void InnerLoopVectorizer::fixFirstOrderRecurrence(PHINode *Phi) { - // This is the second phase of vectorizing first-order recurrences. An // overview of the transformation is described below. Suppose we have the // following loop. @@ -4211,7 +4434,8 @@ void InnerLoopVectorizer::fixReduction(PHINode *Phi) { // entire expression in the smaller type. if (VF > 1 && Phi->getType() != RdxDesc.getRecurrenceType()) { Type *RdxVecTy = VectorType::get(RdxDesc.getRecurrenceType(), VF); - Builder.SetInsertPoint(LoopVectorBody->getTerminator()); + Builder.SetInsertPoint( + LI->getLoopFor(LoopVectorBody)->getLoopLatch()->getTerminator()); VectorParts RdxParts(UF); for (unsigned Part = 0; Part < UF; ++Part) { RdxParts[Part] = VectorLoopValueMap.getVectorValue(LoopExitInst, Part); @@ -4317,7 +4541,6 @@ void InnerLoopVectorizer::fixLCSSAPHIs() { } void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) { - // The basic block and loop containing the predicated instruction. auto *PredBB = PredInst->getParent(); auto *VectorLoop = LI->getLoopFor(PredBB); @@ -4346,7 +4569,6 @@ void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) { // through the worklist doesn't sink a single instruction. bool Changed; do { - // Add the instructions that need to be reanalyzed to the worklist, and // reset the changed indicator. Worklist.insert(InstsToReanalyze.begin(), InstsToReanalyze.end()); @@ -4365,7 +4587,7 @@ void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) { // It's legal to sink the instruction if all its uses occur in the // predicated block. Otherwise, there's nothing to do yet, and we may // need to reanalyze the instruction. - if (!all_of(I->uses(), isBlockOfUsePredicated)) { + if (!llvm::all_of(I->uses(), isBlockOfUsePredicated)) { InstsToReanalyze.push_back(I); continue; } @@ -4382,200 +4604,11 @@ void InnerLoopVectorizer::sinkScalarOperands(Instruction *PredInst) { } while (Changed); } -void InnerLoopVectorizer::predicateInstructions() { - - // For each instruction I marked for predication on value C, split I into its - // own basic block to form an if-then construct over C. Since I may be fed by - // an extractelement instruction or other scalar operand, we try to - // iteratively sink its scalar operands into the predicated block. If I feeds - // an insertelement instruction, we try to move this instruction into the - // predicated block as well. For non-void types, a phi node will be created - // for the resulting value (either vector or scalar). - // - // So for some predicated instruction, e.g. the conditional sdiv in: - // - // for.body: - // ... - // %add = add nsw i32 %mul, %0 - // %cmp5 = icmp sgt i32 %2, 7 - // br i1 %cmp5, label %if.then, label %if.end - // - // if.then: - // %div = sdiv i32 %0, %1 - // br label %if.end - // - // if.end: - // %x.0 = phi i32 [ %div, %if.then ], [ %add, %for.body ] - // - // the sdiv at this point is scalarized and if-converted using a select. 
- // The inactive elements in the vector are not used, but the predicated - // instruction is still executed for all vector elements, essentially: - // - // vector.body: - // ... - // %17 = add nsw <2 x i32> %16, %wide.load - // %29 = extractelement <2 x i32> %wide.load, i32 0 - // %30 = extractelement <2 x i32> %wide.load51, i32 0 - // %31 = sdiv i32 %29, %30 - // %32 = insertelement <2 x i32> undef, i32 %31, i32 0 - // %35 = extractelement <2 x i32> %wide.load, i32 1 - // %36 = extractelement <2 x i32> %wide.load51, i32 1 - // %37 = sdiv i32 %35, %36 - // %38 = insertelement <2 x i32> %32, i32 %37, i32 1 - // %predphi = select <2 x i1> %26, <2 x i32> %38, <2 x i32> %17 - // - // Predication will now re-introduce the original control flow to avoid false - // side-effects by the sdiv instructions on the inactive elements, yielding - // (after cleanup): - // - // vector.body: - // ... - // %5 = add nsw <2 x i32> %4, %wide.load - // %8 = icmp sgt <2 x i32> %wide.load52, <i32 7, i32 7> - // %9 = extractelement <2 x i1> %8, i32 0 - // br i1 %9, label %pred.sdiv.if, label %pred.sdiv.continue - // - // pred.sdiv.if: - // %10 = extractelement <2 x i32> %wide.load, i32 0 - // %11 = extractelement <2 x i32> %wide.load51, i32 0 - // %12 = sdiv i32 %10, %11 - // %13 = insertelement <2 x i32> undef, i32 %12, i32 0 - // br label %pred.sdiv.continue - // - // pred.sdiv.continue: - // %14 = phi <2 x i32> [ undef, %vector.body ], [ %13, %pred.sdiv.if ] - // %15 = extractelement <2 x i1> %8, i32 1 - // br i1 %15, label %pred.sdiv.if54, label %pred.sdiv.continue55 - // - // pred.sdiv.if54: - // %16 = extractelement <2 x i32> %wide.load, i32 1 - // %17 = extractelement <2 x i32> %wide.load51, i32 1 - // %18 = sdiv i32 %16, %17 - // %19 = insertelement <2 x i32> %14, i32 %18, i32 1 - // br label %pred.sdiv.continue55 - // - // pred.sdiv.continue55: - // %20 = phi <2 x i32> [ %14, %pred.sdiv.continue ], [ %19, %pred.sdiv.if54 ] - // %predphi = select <2 x i1> %8, <2 x i32> %20, <2 x i32> %5 - - for (auto KV : PredicatedInstructions) { - BasicBlock::iterator I(KV.first); - BasicBlock *Head = I->getParent(); - auto *T = SplitBlockAndInsertIfThen(KV.second, &*I, /*Unreachable=*/false, - /*BranchWeights=*/nullptr, DT, LI); - I->moveBefore(T); - sinkScalarOperands(&*I); - - BasicBlock *PredicatedBlock = I->getParent(); - Twine BBNamePrefix = Twine("pred.") + I->getOpcodeName(); - PredicatedBlock->setName(BBNamePrefix + ".if"); - PredicatedBlock->getSingleSuccessor()->setName(BBNamePrefix + ".continue"); - - // If the instruction is non-void create a Phi node at reconvergence point. - if (!I->getType()->isVoidTy()) { - Value *IncomingTrue = nullptr; - Value *IncomingFalse = nullptr; - - if (I->hasOneUse() && isa<InsertElementInst>(*I->user_begin())) { - // If the predicated instruction is feeding an insert-element, move it - // into the Then block; Phi node will be created for the vector. - InsertElementInst *IEI = cast<InsertElementInst>(*I->user_begin()); - IEI->moveBefore(T); - IncomingTrue = IEI; // the new vector with the inserted element. - IncomingFalse = IEI->getOperand(0); // the unmodified vector - } else { - // Phi node will be created for the scalar predicated instruction. 
- IncomingTrue = &*I; - IncomingFalse = UndefValue::get(I->getType()); - } - - BasicBlock *PostDom = I->getParent()->getSingleSuccessor(); - assert(PostDom && "Then block has multiple successors"); - PHINode *Phi = - PHINode::Create(IncomingTrue->getType(), 2, "", &PostDom->front()); - IncomingTrue->replaceAllUsesWith(Phi); - Phi->addIncoming(IncomingFalse, Head); - Phi->addIncoming(IncomingTrue, I->getParent()); - } - } - - DEBUG(DT->verifyDomTree()); -} - -InnerLoopVectorizer::VectorParts -InnerLoopVectorizer::createEdgeMask(BasicBlock *Src, BasicBlock *Dst) { - assert(is_contained(predecessors(Dst), Src) && "Invalid edge"); - - // Look for cached value. - std::pair<BasicBlock *, BasicBlock *> Edge(Src, Dst); - EdgeMaskCacheTy::iterator ECEntryIt = EdgeMaskCache.find(Edge); - if (ECEntryIt != EdgeMaskCache.end()) - return ECEntryIt->second; - - VectorParts SrcMask = createBlockInMask(Src); - - // The terminator has to be a branch inst! - BranchInst *BI = dyn_cast<BranchInst>(Src->getTerminator()); - assert(BI && "Unexpected terminator found"); - - if (BI->isConditional()) { - - VectorParts EdgeMask(UF); - for (unsigned Part = 0; Part < UF; ++Part) { - auto *EdgeMaskPart = getOrCreateVectorValue(BI->getCondition(), Part); - if (BI->getSuccessor(0) != Dst) - EdgeMaskPart = Builder.CreateNot(EdgeMaskPart); - - EdgeMaskPart = Builder.CreateAnd(EdgeMaskPart, SrcMask[Part]); - EdgeMask[Part] = EdgeMaskPart; - } - - EdgeMaskCache[Edge] = EdgeMask; - return EdgeMask; - } - - EdgeMaskCache[Edge] = SrcMask; - return SrcMask; -} - -InnerLoopVectorizer::VectorParts -InnerLoopVectorizer::createBlockInMask(BasicBlock *BB) { - assert(OrigLoop->contains(BB) && "Block is not a part of a loop"); - - // Look for cached value. - BlockMaskCacheTy::iterator BCEntryIt = BlockMaskCache.find(BB); - if (BCEntryIt != BlockMaskCache.end()) - return BCEntryIt->second; - - VectorParts BlockMask(UF); - - // Loop incoming mask is all-one. - if (OrigLoop->getHeader() == BB) { - Value *C = ConstantInt::get(IntegerType::getInt1Ty(BB->getContext()), 1); - for (unsigned Part = 0; Part < UF; ++Part) - BlockMask[Part] = getOrCreateVectorValue(C, Part); - BlockMaskCache[BB] = BlockMask; - return BlockMask; - } - - // This is the block mask. We OR all incoming edges, and with zero. - Value *Zero = ConstantInt::get(IntegerType::getInt1Ty(BB->getContext()), 0); - for (unsigned Part = 0; Part < UF; ++Part) - BlockMask[Part] = getOrCreateVectorValue(Zero, Part); - - // For each pred: - for (pred_iterator It = pred_begin(BB), E = pred_end(BB); It != E; ++It) { - VectorParts EM = createEdgeMask(*It, BB); - for (unsigned Part = 0; Part < UF; ++Part) - BlockMask[Part] = Builder.CreateOr(BlockMask[Part], EM[Part]); - } - - BlockMaskCache[BB] = BlockMask; - return BlockMask; -} - void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF, unsigned VF) { + assert(PN->getParent() == OrigLoop->getHeader() && + "Non-header phis should have been handled elsewhere"); + PHINode *P = cast<PHINode>(PN); // In order to support recurrences we need to be able to vectorize Phi nodes. // Phi nodes have cycles, so we need to vectorize them in two stages. This is @@ -4594,43 +4627,6 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF, } setDebugLocFromInst(Builder, P); - // Check for PHI nodes that are lowered to vector selects. 
- if (P->getParent() != OrigLoop->getHeader()) { - // We know that all PHIs in non-header blocks are converted into - // selects, so we don't have to worry about the insertion order and we - // can just use the builder. - // At this point we generate the predication tree. There may be - // duplications since this is a simple recursive scan, but future - // optimizations will clean it up. - - unsigned NumIncoming = P->getNumIncomingValues(); - - // Generate a sequence of selects of the form: - // SELECT(Mask3, In3, - // SELECT(Mask2, In2, - // ( ...))) - VectorParts Entry(UF); - for (unsigned In = 0; In < NumIncoming; In++) { - VectorParts Cond = - createEdgeMask(P->getIncomingBlock(In), P->getParent()); - - for (unsigned Part = 0; Part < UF; ++Part) { - Value *In0 = getOrCreateVectorValue(P->getIncomingValue(In), Part); - // We might have single edge PHIs (blocks) - use an identity - // 'select' for the first PHI operand. - if (In == 0) - Entry[Part] = Builder.CreateSelect(Cond[Part], In0, In0); - else - // Select between the current value and the previous incoming edge - // based on the incoming mask. - Entry[Part] = Builder.CreateSelect(Cond[Part], In0, Entry[Part], - "predphi"); - } - } - for (unsigned Part = 0; Part < UF; ++Part) - VectorLoopValueMap.setVectorValue(P, Part, Entry[Part]); - return; - } // This PHINode must be an induction variable. // Make sure that we know about it. @@ -4646,7 +4642,7 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF, llvm_unreachable("Unknown induction"); case InductionDescriptor::IK_IntInduction: case InductionDescriptor::IK_FpInduction: - return widenIntOrFpInduction(P); + llvm_unreachable("Integer/fp induction is handled elsewhere."); case InductionDescriptor::IK_PtrInduction: { // Handle the pointer induction variable case. assert(P->getType()->isPointerTy() && "Unexpected type."); @@ -4665,7 +4661,7 @@ void InnerLoopVectorizer::widenPHIInstruction(Instruction *PN, unsigned UF, Value *GlobalIdx = Builder.CreateAdd(PtrInd, Idx); Value *SclrGep = II.transform(Builder, GlobalIdx, PSE.getSE(), DL); SclrGep->setName("next.gep"); - VectorLoopValueMap.setScalarValue(P, Part, Lane, SclrGep); + VectorLoopValueMap.setScalarValue(P, {Part, Lane}, SclrGep); } } return; @@ -4691,25 +4687,11 @@ static bool mayDivideByZero(Instruction &I) { return !CInt || CInt->isZero(); } -void InnerLoopVectorizer::vectorizeInstruction(Instruction &I) { - // Scalarize instructions that should remain scalar after vectorization. - if (VF > 1 && - !(isa<BranchInst>(&I) || isa<PHINode>(&I) || isa<DbgInfoIntrinsic>(&I)) && - shouldScalarizeInstruction(&I)) { - scalarizeInstruction(&I, Legal->isScalarWithPredication(&I)); - return; - } - +void InnerLoopVectorizer::widenInstruction(Instruction &I) { switch (I.getOpcode()) { case Instruction::Br: - // Nothing to do for PHIs and BR, since we already took care of the - // loop control flow instructions. - break; - case Instruction::PHI: { - // Vectorize PHINodes. - widenPHIInstruction(&I, UF, VF); - break; - } // End of PHI. + case Instruction::PHI: + llvm_unreachable("This instruction is handled by a different recipe."); case Instruction::GetElementPtr: { // Construct a vector GEP by widening the operands of the scalar GEP as // necessary. We mark the vector GEP 'inbounds' if appropriate. A GEP @@ -4746,7 +4728,6 @@ void InnerLoopVectorizer::vectorizeInstruction(Instruction &I) { // values in the vector mapping with initVector, as we do for other // instructions. 
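// [Editorial sketch, not part of this diff] The select chain emitted by the
// removed block above, modelled on scalars: incoming values are folded as
// SELECT(Mask_n, In_n, SELECT(Mask_{n-1}, In_{n-1}, ...)), with the first
// incoming value serving as the identity operand. Simplified stand-in types,
// not LLVM API.
#include <utility>
#include <vector>

int blendIncoming(const std::vector<std::pair<bool, int>> &MaskedIncoming) {
  // MaskedIncoming[i] = {edge mask of predecessor i, incoming value i}.
  int Blend = MaskedIncoming.front().second;   // identity select for In == 0
  for (size_t In = 1; In < MaskedIncoming.size(); ++In)
    Blend = MaskedIncoming[In].first ? MaskedIncoming[In].second : Blend;
  return Blend;                                // the "predphi" result
}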
for (unsigned Part = 0; Part < UF; ++Part) { - // The pointer operand of the new GEP. If it's loop-invariant, we // won't broadcast it. auto *Ptr = @@ -4782,13 +4763,6 @@ void InnerLoopVectorizer::vectorizeInstruction(Instruction &I) { case Instruction::SDiv: case Instruction::SRem: case Instruction::URem: - // Scalarize with predication if this instruction may divide by zero and - // block execution is conditional, otherwise fallthrough. - if (Legal->isScalarWithPredication(&I)) { - scalarizeInstruction(&I, true); - break; - } - LLVM_FALLTHROUGH; case Instruction::Add: case Instruction::FAdd: case Instruction::Sub: @@ -4836,7 +4810,7 @@ void InnerLoopVectorizer::vectorizeInstruction(Instruction &I) { // We have to take the 'vectorized' value and pick the first lane. // Instcombine will make this a no-op. - auto *ScalarCond = getOrCreateScalarValue(I.getOperand(0), 0, 0); + auto *ScalarCond = getOrCreateScalarValue(I.getOperand(0), {0, 0}); for (unsigned Part = 0; Part < UF; ++Part) { Value *Cond = getOrCreateVectorValue(I.getOperand(0), Part); @@ -4862,8 +4836,10 @@ void InnerLoopVectorizer::vectorizeInstruction(Instruction &I) { Value *B = getOrCreateVectorValue(Cmp->getOperand(1), Part); Value *C = nullptr; if (FCmp) { + // Propagate fast math flags. + IRBuilder<>::FastMathFlagGuard FMFG(Builder); + Builder.setFastMathFlags(Cmp->getFastMathFlags()); C = Builder.CreateFCmp(Cmp->getPredicate(), A, B); - cast<FCmpInst>(C)->copyFastMathFlags(Cmp); } else { C = Builder.CreateICmp(Cmp->getPredicate(), A, B); } @@ -4874,10 +4850,6 @@ void InnerLoopVectorizer::vectorizeInstruction(Instruction &I) { break; } - case Instruction::Store: - case Instruction::Load: - vectorizeMemoryInstruction(&I); - break; case Instruction::ZExt: case Instruction::SExt: case Instruction::FPToUI: @@ -4893,16 +4865,6 @@ void InnerLoopVectorizer::vectorizeInstruction(Instruction &I) { auto *CI = dyn_cast<CastInst>(&I); setDebugLocFromInst(Builder, CI); - // Optimize the special case where the source is a constant integer - // induction variable. Notice that we can only optimize the 'trunc' case - // because (a) FP conversions lose precision, (b) sext/zext may wrap, and - // (c) other casts depend on pointer size. - if (Cost->isOptimizableIVTruncate(CI, VF)) { - widenIntOrFpInduction(cast<PHINode>(CI->getOperand(0)), - cast<TruncInst>(CI)); - break; - } - /// Vectorize casts. Type *DestTy = (VF == 1) ? CI->getType() : VectorType::get(CI->getType(), VF); @@ -4933,11 +4895,7 @@ void InnerLoopVectorizer::vectorizeInstruction(Instruction &I) { Tys.push_back(ToVectorTy(ArgOperand->getType(), VF)); Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); - if (ID && (ID == Intrinsic::assume || ID == Intrinsic::lifetime_end || - ID == Intrinsic::lifetime_start)) { - scalarizeInstruction(&I); - break; - } + // The flag shows whether we use Intrinsic or a usual Call for vectorized // version of the instruction. // Is it beneficial to perform intrinsic call compared to lib call? 
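// [Editorial sketch, not part of this diff] The FCmp hunk above replaces
// "create, then copyFastMathFlags onto the result" with an RAII guard that
// installs the flags on the builder for one scope and restores them on exit.
// A generic model of that save/restore pattern, with invented names:
struct FlagScope {
  unsigned &Flags;                  // the builder's current flag word
  unsigned Saved;
  explicit FlagScope(unsigned &F) : Flags(F), Saved(F) {}
  ~FlagScope() { Flags = Saved; }   // restored even on early return
};

void emitComparisons(unsigned &BuilderFlags, unsigned CmpFlags) {
  FlagScope Guard(BuilderFlags);    // like IRBuilder<>::FastMathFlagGuard
  BuilderFlags = CmpFlags;          // Builder.setFastMathFlags(...)
  // ... every instruction created in this scope inherits CmpFlags ...
}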
@@ -4945,10 +4903,8 @@ void InnerLoopVectorizer::vectorizeInstruction(Instruction &I) { unsigned CallCost = getVectorCallCost(CI, VF, *TTI, TLI, NeedToScalarize); bool UseVectorIntrinsic = ID && getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost; - if (!UseVectorIntrinsic && NeedToScalarize) { - scalarizeInstruction(&I); - break; - } + assert((UseVectorIntrinsic || !NeedToScalarize) && + "Instruction should be scalarized elsewhere."); for (unsigned Part = 0; Part < UF; ++Part) { SmallVector<Value *, 4> Args; @@ -4998,9 +4954,9 @@ void InnerLoopVectorizer::vectorizeInstruction(Instruction &I) { } default: - // All other instructions are unsupported. Scalarize them. - scalarizeInstruction(&I); - break; + // This instruction is not vectorized by simple widening. + DEBUG(dbgs() << "LV: Found an unhandled instruction: " << I); + llvm_unreachable("Unhandled instruction!"); } // end of switch. } @@ -5012,14 +4968,11 @@ void InnerLoopVectorizer::updateAnalysis() { assert(DT->properlyDominates(LoopBypassBlocks.front(), LoopExitBlock) && "Entry does not dominate exit."); - DT->addNewBlock(LI->getLoopFor(LoopVectorBody)->getHeader(), - LoopVectorPreHeader); DT->addNewBlock(LoopMiddleBlock, LI->getLoopFor(LoopVectorBody)->getLoopLatch()); DT->addNewBlock(LoopScalarPreHeader, LoopBypassBlocks[0]); DT->changeImmediateDominator(LoopScalarBody, LoopScalarPreHeader); DT->changeImmediateDominator(LoopExitBlock, LoopBypassBlocks[0]); - DEBUG(DT->verifyDomTree()); } @@ -5094,12 +5047,15 @@ bool LoopVectorizationLegality::canVectorize() { // Store the result and return it at the end instead of exiting early, in case // allowExtraAnalysis is used to report multiple reasons for not vectorizing. bool Result = true; + + bool DoExtraAnalysis = ORE->allowExtraAnalysis(DEBUG_TYPE); + if (DoExtraAnalysis) // We must have a loop in canonical form. Loops with indirectbr in them cannot // be canonicalized. 
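// [Editorial sketch, not part of this diff] The DoExtraAnalysis rewrite in
// this hunk factors out one policy: when extra analysis is allowed, remember
// that a legality check failed and keep scanning so every missed-optimization
// remark gets reported; otherwise bail out on the first failure. A reduced
// model with made-up checks:
#include <functional>
#include <vector>

bool canDoTransform(const std::vector<std::function<bool()>> &Checks,
                    bool DoExtraAnalysis) {
  bool Result = true;
  for (const auto &Check : Checks) {
    if (Check())
      continue;            // this legality check passed
    if (DoExtraAnalysis)
      Result = false;      // record the failure, keep going for more remarks
    else
      return false;        // report only the first reason
  }
  return Result;
}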
if (!TheLoop->getLoopPreheader()) { ORE->emit(createMissedAnalysis("CFGNotUnderstood") << "loop control flow is not understood by vectorizer"); - if (ORE->allowExtraAnalysis()) + if (DoExtraAnalysis) Result = false; else return false; @@ -5112,7 +5068,7 @@ bool LoopVectorizationLegality::canVectorize() { if (!TheLoop->empty()) { ORE->emit(createMissedAnalysis("NotInnermostLoop") << "loop is not the innermost loop"); - if (ORE->allowExtraAnalysis()) + if (DoExtraAnalysis) Result = false; else return false; @@ -5122,7 +5078,7 @@ bool LoopVectorizationLegality::canVectorize() { if (TheLoop->getNumBackEdges() != 1) { ORE->emit(createMissedAnalysis("CFGNotUnderstood") << "loop control flow is not understood by vectorizer"); - if (ORE->allowExtraAnalysis()) + if (DoExtraAnalysis) Result = false; else return false; @@ -5132,7 +5088,7 @@ bool LoopVectorizationLegality::canVectorize() { if (!TheLoop->getExitingBlock()) { ORE->emit(createMissedAnalysis("CFGNotUnderstood") << "loop control flow is not understood by vectorizer"); - if (ORE->allowExtraAnalysis()) + if (DoExtraAnalysis) Result = false; else return false; @@ -5144,7 +5100,7 @@ bool LoopVectorizationLegality::canVectorize() { if (TheLoop->getExitingBlock() != TheLoop->getLoopLatch()) { ORE->emit(createMissedAnalysis("CFGNotUnderstood") << "loop control flow is not understood by vectorizer"); - if (ORE->allowExtraAnalysis()) + if (DoExtraAnalysis) Result = false; else return false; @@ -5158,7 +5114,7 @@ bool LoopVectorizationLegality::canVectorize() { unsigned NumBlocks = TheLoop->getNumBlocks(); if (NumBlocks != 1 && !canVectorizeWithIfConvert()) { DEBUG(dbgs() << "LV: Can't if-convert the loop.\n"); - if (ORE->allowExtraAnalysis()) + if (DoExtraAnalysis) Result = false; else return false; @@ -5167,7 +5123,7 @@ bool LoopVectorizationLegality::canVectorize() { // Check if we can vectorize the instructions and CFG in this loop. if (!canVectorizeInstrs()) { DEBUG(dbgs() << "LV: Can't vectorize the instructions or CFG\n"); - if (ORE->allowExtraAnalysis()) + if (DoExtraAnalysis) Result = false; else return false; @@ -5176,7 +5132,7 @@ bool LoopVectorizationLegality::canVectorize() { // Go over each instruction and look at memory deps. if (!canVectorizeMemory()) { DEBUG(dbgs() << "LV: Can't vectorize due to memory conflicts\n"); - if (ORE->allowExtraAnalysis()) + if (DoExtraAnalysis) Result = false; else return false; @@ -5207,7 +5163,7 @@ bool LoopVectorizationLegality::canVectorize() { << "Too many SCEV assumptions need to be made and checked " << "at runtime"); DEBUG(dbgs() << "LV: Too many SCEV checks needed.\n"); - if (ORE->allowExtraAnalysis()) + if (DoExtraAnalysis) Result = false; else return false; @@ -5263,6 +5219,15 @@ void LoopVectorizationLegality::addInductionPhi( PHINode *Phi, const InductionDescriptor &ID, SmallPtrSetImpl<Value *> &AllowedExit) { Inductions[Phi] = ID; + + // In case this induction also comes with casts that we know we can ignore + // in the vectorized loop body, record them here. All casts could be recorded + // here for ignoring, but suffices to record only the first (as it is the + // only one that may bw used outside the cast sequence). 
+ const SmallVectorImpl<Instruction *> &Casts = ID.getCastInsts(); + if (!Casts.empty()) + InductionCastsToIgnore.insert(*Casts.begin()); + Type *PhiTy = Phi->getType(); const DataLayout &DL = Phi->getModule()->getDataLayout(); @@ -5300,7 +5265,6 @@ void LoopVectorizationLegality::addInductionPhi( } DEBUG(dbgs() << "LV: Found an induction variable.\n"); - return; } bool LoopVectorizationLegality::canVectorizeInstrs() { @@ -5439,7 +5403,7 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { // operations, shuffles, or casts, as they don't change precision or // semantics. } else if (I.getType()->isFloatingPointTy() && (CI || I.isBinaryOp()) && - !I.hasUnsafeAlgebra()) { + !I.isFast()) { DEBUG(dbgs() << "LV: Found FP op with unsafe algebra.\n"); Hints->setPotentiallyUnsafe(); } @@ -5451,7 +5415,6 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { << "value cannot be used outside the loop"); return false; } - } // next instr. } @@ -5474,7 +5437,6 @@ bool LoopVectorizationLegality::canVectorizeInstrs() { } void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { - // We should not collect Scalars more than once per VF. Right now, this // function is called from collectUniformsAndScalars(), which already does // this check. Collecting Scalars for VF=1 does not make any sense. @@ -5517,7 +5479,6 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { // place the pointer in ScalarPtrs. Otherwise, the pointer is placed in // PossibleNonScalarPtrs. auto evaluatePtrUse = [&](Instruction *MemAccess, Value *Ptr) { - // We only care about bitcast and getelementptr instructions contained in // the loop. if (!isLoopVaryingBitCastOrGEP(Ptr)) @@ -5532,7 +5493,7 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { // If the use of the pointer will be a scalar use, and all users of the // pointer are memory accesses, place the pointer in ScalarPtrs. Otherwise, // place the pointer in PossibleNonScalarPtrs. - if (isScalarUse(MemAccess, Ptr) && all_of(I->users(), [&](User *U) { + if (isScalarUse(MemAccess, Ptr) && llvm::all_of(I->users(), [&](User *U) { return isa<LoadInst>(U) || isa<StoreInst>(U); })) ScalarPtrs.insert(I); @@ -5604,7 +5565,7 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { if (!isLoopVaryingBitCastOrGEP(Dst->getOperand(0))) continue; auto *Src = cast<Instruction>(Dst->getOperand(0)); - if (all_of(Src->users(), [&](User *U) -> bool { + if (llvm::all_of(Src->users(), [&](User *U) -> bool { auto *J = cast<Instruction>(U); return !TheLoop->contains(J) || Worklist.count(J) || ((isa<LoadInst>(J) || isa<StoreInst>(J)) && @@ -5631,7 +5592,7 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { // Determine if all users of the induction variable are scalar after // vectorization. - auto ScalarInd = all_of(Ind->users(), [&](User *U) -> bool { + auto ScalarInd = llvm::all_of(Ind->users(), [&](User *U) -> bool { auto *I = cast<Instruction>(U); return I == IndUpdate || !TheLoop->contains(I) || Worklist.count(I); }); @@ -5640,10 +5601,11 @@ void LoopVectorizationCostModel::collectLoopScalars(unsigned VF) { // Determine if all users of the induction variable update instruction are // scalar after vectorization. 
- auto ScalarIndUpdate = all_of(IndUpdate->users(), [&](User *U) -> bool { - auto *I = cast<Instruction>(U); - return I == Ind || !TheLoop->contains(I) || Worklist.count(I); - }); + auto ScalarIndUpdate = + llvm::all_of(IndUpdate->users(), [&](User *U) -> bool { + auto *I = cast<Instruction>(U); + return I == Ind || !TheLoop->contains(I) || Worklist.count(I); + }); if (!ScalarIndUpdate) continue; @@ -5703,7 +5665,6 @@ bool LoopVectorizationLegality::memoryInstructionCanBeWidened(Instruction *I, } void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { - // We should not collect Uniforms more than once per VF. Right now, // this function is called from collectUniformsAndScalars(), which // already does this check. Collecting Uniforms for VF=1 does not make any @@ -5754,6 +5715,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { "Widening decision should be ready at this moment"); return (WideningDecision == CM_Widen || + WideningDecision == CM_Widen_Reverse || WideningDecision == CM_Interleave); }; // Iterate over the instructions in the loop, and collect all @@ -5766,7 +5728,6 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { // the getelementptr won't remain uniform. for (auto *BB : TheLoop->blocks()) for (auto &I : *BB) { - // If there's no pointer operand, there's nothing to do. auto *Ptr = dyn_cast_or_null<Instruction>(getPointerOperand(&I)); if (!Ptr) @@ -5774,9 +5735,10 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { // True if all users of Ptr are memory accesses that have Ptr as their // pointer operand. - auto UsersAreMemAccesses = all_of(Ptr->users(), [&](User *U) -> bool { - return getPointerOperand(U) == Ptr; - }); + auto UsersAreMemAccesses = + llvm::all_of(Ptr->users(), [&](User *U) -> bool { + return getPointerOperand(U) == Ptr; + }); // Ensure the memory instruction will not be scalarized or used by // gather/scatter, making its pointer operand non-uniform. If the pointer @@ -5812,7 +5774,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { if (isOutOfScope(OV)) continue; auto *OI = cast<Instruction>(OV); - if (all_of(OI->users(), [&](User *U) -> bool { + if (llvm::all_of(OI->users(), [&](User *U) -> bool { auto *J = cast<Instruction>(U); return !TheLoop->contains(J) || Worklist.count(J) || (OI == getPointerOperand(J) && isUniformDecision(J, VF)); @@ -5841,7 +5803,7 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { // Determine if all users of the induction variable are uniform after // vectorization. - auto UniformInd = all_of(Ind->users(), [&](User *U) -> bool { + auto UniformInd = llvm::all_of(Ind->users(), [&](User *U) -> bool { auto *I = cast<Instruction>(U); return I == IndUpdate || !TheLoop->contains(I) || Worklist.count(I) || isVectorizedMemAccessUse(I, Ind); @@ -5851,11 +5813,12 @@ void LoopVectorizationCostModel::collectLoopUniforms(unsigned VF) { // Determine if all users of the induction variable update instruction are // uniform after vectorization. 
- auto UniformIndUpdate = all_of(IndUpdate->users(), [&](User *U) -> bool { - auto *I = cast<Instruction>(U); - return I == Ind || !TheLoop->contains(I) || Worklist.count(I) || - isVectorizedMemAccessUse(I, IndUpdate); - }); + auto UniformIndUpdate = + llvm::all_of(IndUpdate->users(), [&](User *U) -> bool { + auto *I = cast<Instruction>(U); + return I == Ind || !TheLoop->contains(I) || Worklist.count(I) || + isVectorizedMemAccessUse(I, IndUpdate); + }); if (!UniformIndUpdate) continue; @@ -5874,9 +5837,10 @@ bool LoopVectorizationLegality::canVectorizeMemory() { InterleaveInfo.setLAI(LAI); const OptimizationRemarkAnalysis *LAR = LAI->getReport(); if (LAR) { - OptimizationRemarkAnalysis VR(Hints->vectorizeAnalysisPassName(), - "loop not vectorized: ", *LAR); - ORE->emit(VR); + ORE->emit([&]() { + return OptimizationRemarkAnalysis(Hints->vectorizeAnalysisPassName(), + "loop not vectorized: ", *LAR); + }); } if (!LAI->canVectorizeMemory()) return false; @@ -5894,7 +5858,7 @@ bool LoopVectorizationLegality::canVectorizeMemory() { return true; } -bool LoopVectorizationLegality::isInductionVariable(const Value *V) { +bool LoopVectorizationLegality::isInductionPhi(const Value *V) { Value *In0 = const_cast<Value *>(V); PHINode *PN = dyn_cast_or_null<PHINode>(In0); if (!PN) @@ -5903,6 +5867,15 @@ bool LoopVectorizationLegality::isInductionVariable(const Value *V) { return Inductions.count(PN); } +bool LoopVectorizationLegality::isCastedInductionVariable(const Value *V) { + auto *Inst = dyn_cast<Instruction>(V); + return (Inst && InductionCastsToIgnore.count(Inst)); +} + +bool LoopVectorizationLegality::isInductionVariable(const Value *V) { + return isInductionPhi(V) || isCastedInductionVariable(V); +} + bool LoopVectorizationLegality::isFirstOrderRecurrence(const PHINode *Phi) { return FirstOrderRecurrences.count(Phi); } @@ -5972,7 +5945,6 @@ bool LoopVectorizationLegality::blockCanBePredicated( void InterleavedAccessInfo::collectConstStrideAccesses( MapVector<Instruction *, StrideDescriptor> &AccessStrideInfo, const ValueToValueMap &Strides) { - auto &DL = TheLoop->getHeader()->getModule()->getDataLayout(); // Since it's desired that the load/store instructions be maintained in @@ -6126,7 +6098,6 @@ void InterleavedAccessInfo::analyzeInterleaving( // but not with (4). If we did, the dependent access (3) would be within // the boundaries of the (2, 4) group. if (!canReorderMemAccessesForInterleavedGroups(&*AI, &*BI)) { - // If a dependence exists and A is already in a group, we know that A // must be a store since A precedes B and WAR dependences are allowed. // Thus, A would be sunk below B. We release A's group to prevent this @@ -6205,9 +6176,11 @@ void InterleavedAccessInfo::analyzeInterleaving( // Remove interleaved store groups with gaps. for (InterleaveGroup *Group : StoreGroups) - if (Group->getNumMembers() != Group->getFactor()) + if (Group->getNumMembers() != Group->getFactor()) { + DEBUG(dbgs() << "LV: Invalidate candidate interleaved store group due " + "to gaps.\n"); releaseGroup(Group); - + } // Remove interleaved groups with gaps (currently only loads) whose memory // accesses may wrap around. We have to revisit the getPtrStride analysis, // this time with ShouldCheckWrap=true, since collectConstStrideAccesses does @@ -6222,9 +6195,7 @@ void InterleavedAccessInfo::analyzeInterleaving( // This means that we can forcefully peel the loop in order to only have to // check the first pointer for no-wrap. 
When we'll change to use Assume=true // we'll only need at most one runtime check per interleaved group. - // for (InterleaveGroup *Group : LoadGroups) { - // Case 1: A full group. Can Skip the checks; For full groups, if the wide // load would wrap around the address space we would do a memory access at // nullptr even without the transformation. @@ -6260,6 +6231,8 @@ void InterleavedAccessInfo::analyzeInterleaving( // to look for a member at index factor - 1, since every group must have // a member at index zero. if (Group->isReverse()) { + DEBUG(dbgs() << "LV: Invalidate candidate interleaved group due to " + "a reverse access with gaps.\n"); releaseGroup(Group); continue; } @@ -6277,8 +6250,21 @@ Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) { return None; } + if (Legal->getRuntimePointerChecking()->Need && TTI.hasBranchDivergence()) { + // TODO: It may by useful to do since it's still likely to be dynamically + // uniform if the target can skip. + DEBUG(dbgs() << "LV: Not inserting runtime ptr check for divergent target"); + + ORE->emit( + createMissedAnalysis("CantVersionLoopWithDivergentTarget") + << "runtime pointer checks needed. Not enabled for divergent target"); + + return None; + } + + unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); if (!OptForSize) // Remaining checks deal with scalar loop when OptForSize. - return computeFeasibleMaxVF(OptForSize); + return computeFeasibleMaxVF(OptForSize, TC); if (Legal->getRuntimePointerChecking()->Need) { ORE->emit(createMissedAnalysis("CantVersionLoopWithOptForSize") @@ -6291,7 +6277,6 @@ Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) { } // If we optimize the program for size, avoid creating the tail loop. - unsigned TC = PSE.getSE()->getSmallConstantTripCount(TheLoop); DEBUG(dbgs() << "LV: Found trip count: " << TC << '\n'); // If we don't know the precise trip count, don't try to vectorize. @@ -6303,7 +6288,7 @@ Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) { return None; } - unsigned MaxVF = computeFeasibleMaxVF(OptForSize); + unsigned MaxVF = computeFeasibleMaxVF(OptForSize, TC); if (TC % MaxVF != 0) { // If the trip count that we found modulo the vectorization factor is not @@ -6324,46 +6309,52 @@ Optional<unsigned> LoopVectorizationCostModel::computeMaxVF(bool OptForSize) { return MaxVF; } -unsigned LoopVectorizationCostModel::computeFeasibleMaxVF(bool OptForSize) { +unsigned +LoopVectorizationCostModel::computeFeasibleMaxVF(bool OptForSize, + unsigned ConstTripCount) { MinBWs = computeMinimumValueSizes(TheLoop->getBlocks(), *DB, &TTI); unsigned SmallestType, WidestType; std::tie(SmallestType, WidestType) = getSmallestAndWidestTypes(); unsigned WidestRegister = TTI.getRegisterBitWidth(true); - unsigned MaxSafeDepDist = -1U; - // Get the maximum safe dependence distance in bits computed by LAA. If the - // loop contains any interleaved accesses, we divide the dependence distance - // by the maximum interleave factor of all interleaved groups. Note that - // although the division ensures correctness, this is a fairly conservative - // computation because the maximum distance computed by LAA may not involve - // any of the interleaved accesses. - if (Legal->getMaxSafeDepDistBytes() != -1U) - MaxSafeDepDist = - Legal->getMaxSafeDepDistBytes() * 8 / Legal->getMaxInterleaveFactor(); + // Get the maximum safe dependence distance in bits computed by LAA. 
+ // It is computed by MaxVF * sizeOf(type) * 8, where type is taken from + // the memory accesses that is most restrictive (involved in the smallest + // dependence distance). + unsigned MaxSafeRegisterWidth = Legal->getMaxSafeRegisterWidth(); + + WidestRegister = std::min(WidestRegister, MaxSafeRegisterWidth); - WidestRegister = - ((WidestRegister < MaxSafeDepDist) ? WidestRegister : MaxSafeDepDist); unsigned MaxVectorSize = WidestRegister / WidestType; DEBUG(dbgs() << "LV: The Smallest and Widest types: " << SmallestType << " / " << WidestType << " bits.\n"); - DEBUG(dbgs() << "LV: The Widest register is: " << WidestRegister + DEBUG(dbgs() << "LV: The Widest register safe to use is: " << WidestRegister << " bits.\n"); + assert(MaxVectorSize <= 64 && "Did not expect to pack so many elements" + " into one vector!"); if (MaxVectorSize == 0) { DEBUG(dbgs() << "LV: The target has no vector registers.\n"); MaxVectorSize = 1; + return MaxVectorSize; + } else if (ConstTripCount && ConstTripCount < MaxVectorSize && + isPowerOf2_32(ConstTripCount)) { + // We need to clamp the VF to be the ConstTripCount. There is no point in + // choosing a higher viable VF as done in the loop below. + DEBUG(dbgs() << "LV: Clamping the MaxVF to the constant trip count: " + << ConstTripCount << "\n"); + MaxVectorSize = ConstTripCount; + return MaxVectorSize; } - assert(MaxVectorSize <= 64 && "Did not expect to pack so many elements" - " into one vector!"); - unsigned MaxVF = MaxVectorSize; if (MaximizeBandwidth && !OptForSize) { - // Collect all viable vectorization factors. + // Collect all viable vectorization factors larger than the default MaxVF + // (i.e. MaxVectorSize). SmallVector<unsigned, 8> VFs; unsigned NewMaxVectorSize = WidestRegister / SmallestType; - for (unsigned VS = MaxVectorSize; VS <= NewMaxVectorSize; VS *= 2) + for (unsigned VS = MaxVectorSize * 2; VS <= NewMaxVectorSize; VS *= 2) VFs.push_back(VS); // For each VF calculate its register usage. @@ -6485,7 +6476,6 @@ LoopVectorizationCostModel::getSmallestAndWidestTypes() { unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, unsigned VF, unsigned LoopCost) { - // -- The interleave heuristics -- // We interleave the loop in order to expose ILP and reduce the loop overhead. // There are many micro-architectural considerations that we can't predict @@ -6573,7 +6563,7 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, // Interleave if we vectorized this loop and there is a reduction that could // benefit from interleaving. - if (VF > 1 && Legal->getReductionVars()->size()) { + if (VF > 1 && !Legal->getReductionVars()->empty()) { DEBUG(dbgs() << "LV: Interleaving because of reductions.\n"); return IC; } @@ -6604,7 +6594,7 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, // by this point), we can increase the critical path length if the loop // we're interleaving is inside another loop. Limit, by default to 2, so the // critical path only gets increased by one reduction operation. 
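// [Editorial sketch, not part of this diff] The computeFeasibleMaxVF changes
// above, reduced to plain arithmetic: the widest register is capped by the
// maximum safe dependence width from LAA, divided by the widest element type,
// and then clamped to a small power-of-two constant trip count so the loop
// needs no scalar tail. Widths are in bits; the trip count is in iterations.
#include <algorithm>

static bool isPowerOf2(unsigned X) { return X && !(X & (X - 1)); }

unsigned feasibleMaxVF(unsigned WidestRegister, unsigned MaxSafeRegisterWidth,
                       unsigned WidestType, unsigned ConstTripCount) {
  WidestRegister = std::min(WidestRegister, MaxSafeRegisterWidth);
  unsigned MaxVectorSize = WidestRegister / WidestType;
  if (MaxVectorSize == 0)
    return 1;              // the target has no vector registers
  if (ConstTripCount && ConstTripCount < MaxVectorSize &&
      isPowerOf2(ConstTripCount))
    return ConstTripCount; // a wider VF could never be filled
  return MaxVectorSize;
}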
- if (Legal->getReductionVars()->size() && TheLoop->getLoopDepth() > 1) { + if (!Legal->getReductionVars()->empty() && TheLoop->getLoopDepth() > 1) { unsigned F = static_cast<unsigned>(MaxNestedScalarReductionIC); SmallIC = std::min(SmallIC, F); StoresIC = std::min(StoresIC, F); @@ -6623,7 +6613,7 @@ unsigned LoopVectorizationCostModel::selectInterleaveCount(bool OptForSize, // Interleave if this is a large loop (small loops are already dealt with by // this point) that could benefit from interleaving. - bool HasReductions = (Legal->getReductionVars()->size() > 0); + bool HasReductions = !Legal->getReductionVars()->empty(); if (TTI.enableAggressiveInterleaving(HasReductions)) { DEBUG(dbgs() << "LV: Interleaving to expose ILP.\n"); return IC; @@ -6661,7 +6651,8 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { // Each 'key' in the map opens a new interval. The values // of the map are the index of the 'last seen' usage of the // instruction that is the key. - typedef DenseMap<Instruction *, unsigned> IntervalMap; + using IntervalMap = DenseMap<Instruction *, unsigned>; + // Maps instruction to its index. DenseMap<unsigned, Instruction *> IdxToInstr; // Marks the end of each interval. @@ -6700,7 +6691,7 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { } // Saves the list of intervals that end with the index in 'key'. - typedef SmallVector<Instruction *, 2> InstrList; + using InstrList = SmallVector<Instruction *, 2>; DenseMap<unsigned, InstrList> TransposeEnds; // Transpose the EndPoints to a list of values that end at each index. @@ -6795,7 +6786,6 @@ LoopVectorizationCostModel::calculateRegisterUsage(ArrayRef<unsigned> VFs) { } void LoopVectorizationCostModel::collectInstsToScalarize(unsigned VF) { - // If we aren't vectorizing the loop, or if we've already collected the // instructions to scalarize, there's nothing to do. Collection may already // have occurred if we have a user-selected VF and are now computing the @@ -6829,7 +6819,6 @@ void LoopVectorizationCostModel::collectInstsToScalarize(unsigned VF) { int LoopVectorizationCostModel::computePredInstDiscount( Instruction *PredInst, DenseMap<Instruction *, unsigned> &ScalarCosts, unsigned VF) { - assert(!isUniformAfterVectorization(PredInst, VF) && "Instruction marked uniform-after-vectorization will be predicated"); @@ -6844,7 +6833,6 @@ int LoopVectorizationCostModel::computePredInstDiscount( // Returns true if the given instruction can be scalarized. auto canBeScalarized = [&](Instruction *I) -> bool { - // We only attempt to scalarize instructions forming a single-use chain // from the original predicated block that would otherwise be vectorized. // Although not strictly necessary, we give up on instructions we know will @@ -6947,13 +6935,6 @@ LoopVectorizationCostModel::VectorizationCostTy LoopVectorizationCostModel::expectedCost(unsigned VF) { VectorizationCostTy Cost; - // Collect Uniform and Scalar instructions after vectorization with VF. - collectUniformsAndScalars(VF); - - // Collect the instructions (and their associated costs) that will be more - // profitable to scalarize. - collectInstsToScalarize(VF); - // For each block. for (BasicBlock *BB : TheLoop->blocks()) { VectorizationCostTy BlockCost; @@ -6965,7 +6946,8 @@ LoopVectorizationCostModel::expectedCost(unsigned VF) { continue; // Skip ignored values. 
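// [Editorial sketch, not part of this diff] The interval bookkeeping used by
// calculateRegisterUsage above, modelled on plain indices: each value opens
// an interval at its definition and closes it at its last use, and the
// register estimate is the maximum number of intervals open at once.
#include <algorithm>
#include <utility>
#include <vector>

unsigned maxOpenIntervals(
    const std::vector<std::pair<unsigned, unsigned>> &DefToLastUse) {
  unsigned MaxIdx = 0;
  for (const auto &Interval : DefToLastUse)
    MaxIdx = std::max(MaxIdx, Interval.second);
  unsigned MaxOpen = 0;
  for (unsigned Idx = 0; Idx <= MaxIdx; ++Idx) {
    unsigned Open = 0;
    for (const auto &Interval : DefToLastUse)
      if (Interval.first <= Idx && Idx <= Interval.second)
        ++Open;            // [def, last-use] is live at Idx
    MaxOpen = std::max(MaxOpen, Open);
  }
  return MaxOpen;
}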
- if (ValuesToIgnore.count(&I)) + if (ValuesToIgnore.count(&I) || + (VF > 1 && VecValuesToIgnore.count(&I))) continue; VectorizationCostTy C = getInstructionCost(&I, VF); @@ -7004,14 +6986,16 @@ LoopVectorizationCostModel::expectedCost(unsigned VF) { static const SCEV *getAddressAccessSCEV( Value *Ptr, LoopVectorizationLegality *Legal, - ScalarEvolution *SE, + PredicatedScalarEvolution &PSE, const Loop *TheLoop) { + auto *Gep = dyn_cast<GetElementPtrInst>(Ptr); if (!Gep) return nullptr; // We are looking for a gep with all loop invariant indices except for one // which should be an induction variable. + auto SE = PSE.getSE(); unsigned NumOperands = Gep->getNumOperands(); for (unsigned i = 1; i < NumOperands; ++i) { Value *Opd = Gep->getOperand(i); @@ -7021,7 +7005,7 @@ static const SCEV *getAddressAccessSCEV( } // Now we know we have a GEP ptr, %inv, %ind, %inv. return the Ptr SCEV. - return SE->getSCEV(Ptr); + return PSE.getSCEV(Ptr); } static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) { @@ -7041,7 +7025,7 @@ unsigned LoopVectorizationCostModel::getMemInstScalarizationCost(Instruction *I, // Figure out whether the access is strided and get the stride value // if it's known in compile time - const SCEV *PtrSCEV = getAddressAccessSCEV(Ptr, Legal, SE, TheLoop); + const SCEV *PtrSCEV = getAddressAccessSCEV(Ptr, Legal, PSE, TheLoop); // Get the cost of the scalar memory instruction and address computation. unsigned Cost = VF * TTI.getAddressComputationCost(PtrTy, SE, PtrSCEV); @@ -7145,7 +7129,6 @@ unsigned LoopVectorizationCostModel::getInterleaveGroupCost(Instruction *I, unsigned LoopVectorizationCostModel::getMemoryInstructionCost(Instruction *I, unsigned VF) { - // Calculate scalar cost only. Vectorization cost should be ready at this // moment. if (VF == 1) { @@ -7202,12 +7185,17 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { // We assume that widening is the best solution when possible. if (Legal->memoryInstructionCanBeWidened(&I, VF)) { unsigned Cost = getConsecutiveMemOpCost(&I, VF); - setWideningDecision(&I, VF, CM_Widen, Cost); + int ConsecutiveStride = Legal->isConsecutivePtr(getPointerOperand(&I)); + assert((ConsecutiveStride == 1 || ConsecutiveStride == -1) && + "Expected consecutive stride."); + InstWidening Decision = + ConsecutiveStride == 1 ? CM_Widen : CM_Widen_Reverse; + setWideningDecision(&I, VF, Decision, Cost); continue; } // Choose between Interleaving, Gather/Scatter or Scalarization. - unsigned InterleaveCost = UINT_MAX; + unsigned InterleaveCost = std::numeric_limits<unsigned>::max(); unsigned NumAccesses = 1; if (Legal->isAccessInterleaved(&I)) { auto Group = Legal->getInterleavedAccessGroup(&I); @@ -7224,7 +7212,7 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { unsigned GatherScatterCost = Legal->isLegalGatherOrScatter(&I) ? 
getGatherScatterCost(&I, VF) * NumAccesses - : UINT_MAX; + : std::numeric_limits<unsigned>::max(); unsigned ScalarizationCost = getMemInstScalarizationCost(&I, VF) * NumAccesses; @@ -7282,7 +7270,7 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { for (auto &Op : I->operands()) if (auto *InstOp = dyn_cast<Instruction>(Op)) if ((InstOp->getParent() == I->getParent()) && !isa<PHINode>(InstOp) && - AddrDefs.insert(InstOp).second == true) + AddrDefs.insert(InstOp).second) Worklist.push_back(InstOp); } @@ -7292,7 +7280,8 @@ void LoopVectorizationCostModel::setCostBasedWideningDecision(unsigned VF) { // by cost functions, but since this involves the task of finding out // if the loaded register is involved in an address computation, it is // instead changed here when we know this is the case. - if (getWideningDecision(I, VF) == CM_Widen) + InstWidening Decision = getWideningDecision(I, VF); + if (Decision == CM_Widen || Decision == CM_Widen_Reverse) // Scalarize a widened load of address. setWideningDecision(I, VF, CM_Scalarize, (VF * getMemoryInstructionCost(I, 1))); @@ -7551,7 +7540,9 @@ unsigned LoopVectorizationCostModel::getInstructionCost(Instruction *I, } char LoopVectorize::ID = 0; + static const char lv_name[] = "Loop Vectorization"; + INITIALIZE_PASS_BEGIN(LoopVectorize, LV_NAME, lv_name, false, false) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass) @@ -7568,13 +7559,14 @@ INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(LoopVectorize, LV_NAME, lv_name, false, false) namespace llvm { + Pass *createLoopVectorizePass(bool NoUnrolling, bool AlwaysVectorize) { return new LoopVectorize(NoUnrolling, AlwaysVectorize); } -} -bool LoopVectorizationCostModel::isConsecutiveLoadOrStore(Instruction *Inst) { +} // end namespace llvm +bool LoopVectorizationCostModel::isConsecutiveLoadOrStore(Instruction *Inst) { // Check if the pointer operand of a load or store instruction is // consecutive. if (auto *Ptr = getPointerOperand(Inst)) @@ -7593,11 +7585,17 @@ void LoopVectorizationCostModel::collectValuesToIgnore() { SmallPtrSetImpl<Instruction *> &Casts = RedDes.getCastInsts(); VecValuesToIgnore.insert(Casts.begin(), Casts.end()); } + // Ignore type-casting instructions we identified during induction + // detection. + for (auto &Induction : *Legal->getInductionVars()) { + InductionDescriptor &IndDes = Induction.second; + const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts(); + VecValuesToIgnore.insert(Casts.begin(), Casts.end()); + } } LoopVectorizationCostModel::VectorizationFactor LoopVectorizationPlanner::plan(bool OptForSize, unsigned UserVF) { - // Width 1 means no vectorize, cost 0 means uncomputed cost. const LoopVectorizationCostModel::VectorizationFactor NoVectorization = {1U, 0U}; @@ -7611,11 +7609,26 @@ LoopVectorizationPlanner::plan(bool OptForSize, unsigned UserVF) { // Collect the instructions (and their associated costs) that will be more // profitable to scalarize. CM.selectUserVectorizationFactor(UserVF); + buildVPlans(UserVF, UserVF); + DEBUG(printPlans(dbgs())); return {UserVF, 0}; } unsigned MaxVF = MaybeMaxVF.getValue(); assert(MaxVF != 0 && "MaxVF is zero."); + + for (unsigned VF = 1; VF <= MaxVF; VF *= 2) { + // Collect Uniform and Scalar instructions after vectorization with VF. + CM.collectUniformsAndScalars(VF); + + // Collect the instructions (and their associated costs) that will be more + // profitable to scalarize. 
+ if (VF > 1) + CM.collectInstsToScalarize(VF); + } + + buildVPlans(1, MaxVF); + DEBUG(printPlans(dbgs())); if (MaxVF == 1) return NoVectorization; @@ -7623,11 +7636,28 @@ LoopVectorizationPlanner::plan(bool OptForSize, unsigned UserVF) { return CM.selectVectorizationFactor(MaxVF); } -void LoopVectorizationPlanner::executePlan(InnerLoopVectorizer &ILV) { +void LoopVectorizationPlanner::setBestPlan(unsigned VF, unsigned UF) { + DEBUG(dbgs() << "Setting best plan to VF=" << VF << ", UF=" << UF << '\n'); + BestVF = VF; + BestUF = UF; + + erase_if(VPlans, [VF](const VPlanPtr &Plan) { + return !Plan->hasVF(VF); + }); + assert(VPlans.size() == 1 && "Best VF has not a single VPlan."); +} + +void LoopVectorizationPlanner::executePlan(InnerLoopVectorizer &ILV, + DominatorTree *DT) { // Perform the actual loop transformation. // 1. Create a new empty loop. Unlink the old loop and connect the new one. - ILV.createVectorizedLoopSkeleton(); + VPCallbackILV CallbackILV(ILV); + + VPTransformState State{BestVF, BestUF, LI, + DT, ILV.Builder, ILV.VectorLoopValueMap, + &ILV, CallbackILV}; + State.CFG.PrevBB = ILV.createVectorizedLoopSkeleton(); //===------------------------------------------------===// // @@ -7638,36 +7668,8 @@ void LoopVectorizationPlanner::executePlan(InnerLoopVectorizer &ILV) { //===------------------------------------------------===// // 2. Copy and widen instructions from the old loop into the new loop. - - // Move instructions to handle first-order recurrences. - DenseMap<Instruction *, Instruction *> SinkAfter = Legal->getSinkAfter(); - for (auto &Entry : SinkAfter) { - Entry.first->removeFromParent(); - Entry.first->insertAfter(Entry.second); - DEBUG(dbgs() << "Sinking" << *Entry.first << " after" << *Entry.second - << " to vectorize a 1st order recurrence.\n"); - } - - // Collect instructions from the original loop that will become trivially dead - // in the vectorized loop. We don't need to vectorize these instructions. For - // example, original induction update instructions can become dead because we - // separately emit induction "steps" when generating code for the new loop. - // Similarly, we create a new latch condition when setting up the structure - // of the new loop, so the old one can become dead. - SmallPtrSet<Instruction *, 4> DeadInstructions; - collectTriviallyDeadInstructions(DeadInstructions); - - // Scan the loop in a topological order to ensure that defs are vectorized - // before users. - LoopBlocksDFS DFS(OrigLoop); - DFS.perform(LI); - - // Vectorize all instructions in the original loop that will not become - // trivially dead when vectorized. - for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) - for (Instruction &I : *BB) - if (!DeadInstructions.count(&I)) - ILV.vectorizeInstruction(I); + assert(VPlans.size() == 1 && "Not a single VPlan to execute."); + VPlans.front()->execute(&State); // 3. Fix the vectorized code: take care of header phi's, live-outs, // predication, updating analyses. 
@@ -7691,18 +7693,23 @@ void LoopVectorizationPlanner::collectTriviallyDeadInstructions( for (auto &Induction : *Legal->getInductionVars()) { PHINode *Ind = Induction.first; auto *IndUpdate = cast<Instruction>(Ind->getIncomingValueForBlock(Latch)); - if (all_of(IndUpdate->users(), [&](User *U) -> bool { + if (llvm::all_of(IndUpdate->users(), [&](User *U) -> bool { return U == Ind || DeadInstructions.count(cast<Instruction>(U)); })) DeadInstructions.insert(IndUpdate); - } -} - -void InnerLoopUnroller::vectorizeMemoryInstruction(Instruction *Instr) { - auto *SI = dyn_cast<StoreInst>(Instr); - bool IfPredicateInstr = (SI && Legal->blockNeedsPredication(SI->getParent())); - return scalarizeInstruction(Instr, IfPredicateInstr); + // We record as "Dead" also the type-casting instructions we had identified + // during induction analysis. We don't need any handling for them in the + // vectorized loop because we have proven that, under a proper runtime + // test guarding the vectorized loop, the value of the phi, and the casted + // value of the phi, are the same. The last instruction in this casting chain + // will get its scalar/vector/widened def from the scalar/vector/widened def + // of the respective phi node. Any other casts in the induction def-use chain + // have no other uses outside the phi update chain, and will be ignored. + InductionDescriptor &IndDes = Induction.second; + const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts(); + DeadInstructions.insert(Casts.begin(), Casts.end()); + } } Value *InnerLoopUnroller::reverseVector(Value *Vec) { return Vec; } @@ -7760,6 +7767,722 @@ static void AddRuntimeUnrollDisableMetaData(Loop *L) { } } +bool LoopVectorizationPlanner::getDecisionAndClampRange( + const std::function<bool(unsigned)> &Predicate, VFRange &Range) { + assert(Range.End > Range.Start && "Trying to test an empty VF range."); + bool PredicateAtRangeStart = Predicate(Range.Start); + + for (unsigned TmpVF = Range.Start * 2; TmpVF < Range.End; TmpVF *= 2) + if (Predicate(TmpVF) != PredicateAtRangeStart) { + Range.End = TmpVF; + break; + } + + return PredicateAtRangeStart; +} + +/// Build VPlans for the full range of feasible VF's = {\p MinVF, 2 * \p MinVF, +/// 4 * \p MinVF, ..., \p MaxVF} by repeatedly building a VPlan for a sub-range +/// of VF's starting at a given VF and extending it as much as possible. Each +/// vectorization decision can potentially shorten this sub-range during +/// buildVPlan(). +void LoopVectorizationPlanner::buildVPlans(unsigned MinVF, unsigned MaxVF) { + + // Collect conditions feeding internal conditional branches; they need to be + // represented in VPlan for it to model masking. + SmallPtrSet<Value *, 1> NeedDef; + + auto *Latch = OrigLoop->getLoopLatch(); + for (BasicBlock *BB : OrigLoop->blocks()) { + if (BB == Latch) + continue; + BranchInst *Branch = dyn_cast<BranchInst>(BB->getTerminator()); + if (Branch && Branch->isConditional()) + NeedDef.insert(Branch->getCondition()); + } + + for (unsigned VF = MinVF; VF < MaxVF + 1;) { + VFRange SubRange = {VF, MaxVF + 1}; + VPlans.push_back(buildVPlan(SubRange, NeedDef)); + VF = SubRange.End; + } +} + +VPValue *LoopVectorizationPlanner::createEdgeMask(BasicBlock *Src, + BasicBlock *Dst, + VPlanPtr &Plan) { + assert(is_contained(predecessors(Dst), Src) && "Invalid edge"); + + // Look for cached value. 
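// [Editorial sketch, not part of this diff] getDecisionAndClampRange above
// evaluates a per-VF predicate at Range.Start and shrinks Range.End to the
// first power-of-two VF where the answer flips, so a single VPlan covers a
// sub-range of VFs with uniform decisions; buildVPlans then starts the next
// sub-range where the previous one stopped. A standalone version:
#include <functional>

struct SimpleVFRange { unsigned Start, End; }; // half-open, powers of two

bool decideAndClamp(const std::function<bool(unsigned)> &Predicate,
                    SimpleVFRange &Range) {
  bool AtStart = Predicate(Range.Start);
  for (unsigned VF = Range.Start * 2; VF < Range.End; VF *= 2)
    if (Predicate(VF) != AtStart) {
      Range.End = VF;      // decisions change here; split off a new sub-range
      break;
    }
  return AtStart;
}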
+ std::pair<BasicBlock *, BasicBlock *> Edge(Src, Dst); + EdgeMaskCacheTy::iterator ECEntryIt = EdgeMaskCache.find(Edge); + if (ECEntryIt != EdgeMaskCache.end()) + return ECEntryIt->second; + + VPValue *SrcMask = createBlockInMask(Src, Plan); + + // The terminator has to be a branch inst! + BranchInst *BI = dyn_cast<BranchInst>(Src->getTerminator()); + assert(BI && "Unexpected terminator found"); + + if (!BI->isConditional()) + return EdgeMaskCache[Edge] = SrcMask; + + VPValue *EdgeMask = Plan->getVPValue(BI->getCondition()); + assert(EdgeMask && "No Edge Mask found for condition"); + + if (BI->getSuccessor(0) != Dst) + EdgeMask = Builder.createNot(EdgeMask); + + if (SrcMask) // Otherwise block in-mask is all-one, no need to AND. + EdgeMask = Builder.createAnd(EdgeMask, SrcMask); + + return EdgeMaskCache[Edge] = EdgeMask; +} + +VPValue *LoopVectorizationPlanner::createBlockInMask(BasicBlock *BB, + VPlanPtr &Plan) { + assert(OrigLoop->contains(BB) && "Block is not a part of a loop"); + + // Look for cached value. + BlockMaskCacheTy::iterator BCEntryIt = BlockMaskCache.find(BB); + if (BCEntryIt != BlockMaskCache.end()) + return BCEntryIt->second; + + // All-one mask is modelled as no-mask following the convention for masked + // load/store/gather/scatter. Initialize BlockMask to no-mask. + VPValue *BlockMask = nullptr; + + // Loop incoming mask is all-one. + if (OrigLoop->getHeader() == BB) + return BlockMaskCache[BB] = BlockMask; + + // This is the block mask. We OR all incoming edges. + for (auto *Predecessor : predecessors(BB)) { + VPValue *EdgeMask = createEdgeMask(Predecessor, BB, Plan); + if (!EdgeMask) // Mask of predecessor is all-one so mask of block is too. + return BlockMaskCache[BB] = EdgeMask; + + if (!BlockMask) { // BlockMask has its initialized nullptr value. + BlockMask = EdgeMask; + continue; + } + + BlockMask = Builder.createOr(BlockMask, EdgeMask); + } + + return BlockMaskCache[BB] = BlockMask; +} + +VPInterleaveRecipe * +LoopVectorizationPlanner::tryToInterleaveMemory(Instruction *I, + VFRange &Range) { + const InterleaveGroup *IG = Legal->getInterleavedAccessGroup(I); + if (!IG) + return nullptr; + + // Now check if IG is relevant for VF's in the given range. + auto isIGMember = [&](Instruction *I) -> std::function<bool(unsigned)> { + return [=](unsigned VF) -> bool { + return (VF >= 2 && // Query is illegal for VF == 1 + CM.getWideningDecision(I, VF) == + LoopVectorizationCostModel::CM_Interleave); + }; + }; + if (!getDecisionAndClampRange(isIGMember(I), Range)) + return nullptr; + + // I is a member of an InterleaveGroup for VF's in the (possibly trimmed) + // range. If it's the primary member of the IG construct a VPInterleaveRecipe. + // Otherwise, it's an adjunct member of the IG, do not construct any Recipe. 
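// [Editorial sketch, not part of this diff] The VPlan mask builders above use
// a null VPValue to mean "all lanes active", so masked loads/stores and AND/OR
// recipes are only created when a real mask exists. A reduced model of that
// folding, using a nullable pointer and invented callbacks:
struct Mask;                // opaque mask value; nullptr means all-one

Mask *foldOr(Mask *A, Mask *B, Mask *(*createOr)(Mask *, Mask *)) {
  if (!A || !B)
    return nullptr;         // OR with an all-one mask is still all-one
  return createOr(A, B);    // otherwise materialize a real OR
}

Mask *foldAnd(Mask *EdgeCond, Mask *BlockIn,
              Mask *(*createAnd)(Mask *, Mask *)) {
  if (!BlockIn)
    return EdgeCond;        // block-in mask is all-one: the condition suffices
  return createAnd(EdgeCond, BlockIn);
}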
+ assert(I == IG->getInsertPos() && + "Generating a recipe for an adjunct member of an interleave group"); + + return new VPInterleaveRecipe(IG); +} + +VPWidenMemoryInstructionRecipe * +LoopVectorizationPlanner::tryToWidenMemory(Instruction *I, VFRange &Range, + VPlanPtr &Plan) { + if (!isa<LoadInst>(I) && !isa<StoreInst>(I)) + return nullptr; + + auto willWiden = [&](unsigned VF) -> bool { + if (VF == 1) + return false; + if (CM.isScalarAfterVectorization(I, VF) || + CM.isProfitableToScalarize(I, VF)) + return false; + LoopVectorizationCostModel::InstWidening Decision = + CM.getWideningDecision(I, VF); + assert(Decision != LoopVectorizationCostModel::CM_Unknown && + "CM decision should be taken at this point."); + assert(Decision != LoopVectorizationCostModel::CM_Interleave && + "Interleave memory opportunity should be caught earlier."); + return Decision != LoopVectorizationCostModel::CM_Scalarize; + }; + + if (!getDecisionAndClampRange(willWiden, Range)) + return nullptr; + + VPValue *Mask = nullptr; + if (Legal->isMaskRequired(I)) + Mask = createBlockInMask(I->getParent(), Plan); + + return new VPWidenMemoryInstructionRecipe(*I, Mask); +} + +VPWidenIntOrFpInductionRecipe * +LoopVectorizationPlanner::tryToOptimizeInduction(Instruction *I, + VFRange &Range) { + if (PHINode *Phi = dyn_cast<PHINode>(I)) { + // Check if this is an integer or fp induction. If so, build the recipe that + // produces its scalar and vector values. + InductionDescriptor II = Legal->getInductionVars()->lookup(Phi); + if (II.getKind() == InductionDescriptor::IK_IntInduction || + II.getKind() == InductionDescriptor::IK_FpInduction) + return new VPWidenIntOrFpInductionRecipe(Phi); + + return nullptr; + } + + // Optimize the special case where the source is a constant integer + // induction variable. Notice that we can only optimize the 'trunc' case + // because (a) FP conversions lose precision, (b) sext/zext may wrap, and + // (c) other casts depend on pointer size. + + // Determine whether \p K is a truncation based on an induction variable that + // can be optimized. + auto isOptimizableIVTruncate = + [&](Instruction *K) -> std::function<bool(unsigned)> { + return + [=](unsigned VF) -> bool { return CM.isOptimizableIVTruncate(K, VF); }; + }; + + if (isa<TruncInst>(I) && + getDecisionAndClampRange(isOptimizableIVTruncate(I), Range)) + return new VPWidenIntOrFpInductionRecipe(cast<PHINode>(I->getOperand(0)), + cast<TruncInst>(I)); + return nullptr; +} + +VPBlendRecipe * +LoopVectorizationPlanner::tryToBlend(Instruction *I, VPlanPtr &Plan) { + PHINode *Phi = dyn_cast<PHINode>(I); + if (!Phi || Phi->getParent() == OrigLoop->getHeader()) + return nullptr; + + // We know that all PHIs in non-header blocks are converted into selects, so + // we don't have to worry about the insertion order and we can just use the + // builder. At this point we generate the predication tree. There may be + // duplications since this is a simple recursive scan, but future + // optimizations will clean it up. 
+ + SmallVector<VPValue *, 2> Masks; + unsigned NumIncoming = Phi->getNumIncomingValues(); + for (unsigned In = 0; In < NumIncoming; In++) { + VPValue *EdgeMask = + createEdgeMask(Phi->getIncomingBlock(In), Phi->getParent(), Plan); + assert((EdgeMask || NumIncoming == 1) && + "Multiple predecessors with one having a full mask"); + if (EdgeMask) + Masks.push_back(EdgeMask); + } + return new VPBlendRecipe(Phi, Masks); +} + +bool LoopVectorizationPlanner::tryToWiden(Instruction *I, VPBasicBlock *VPBB, + VFRange &Range) { + if (Legal->isScalarWithPredication(I)) + return false; + + auto IsVectorizableOpcode = [](unsigned Opcode) { + switch (Opcode) { + case Instruction::Add: + case Instruction::And: + case Instruction::AShr: + case Instruction::BitCast: + case Instruction::Br: + case Instruction::Call: + case Instruction::FAdd: + case Instruction::FCmp: + case Instruction::FDiv: + case Instruction::FMul: + case Instruction::FPExt: + case Instruction::FPToSI: + case Instruction::FPToUI: + case Instruction::FPTrunc: + case Instruction::FRem: + case Instruction::FSub: + case Instruction::GetElementPtr: + case Instruction::ICmp: + case Instruction::IntToPtr: + case Instruction::Load: + case Instruction::LShr: + case Instruction::Mul: + case Instruction::Or: + case Instruction::PHI: + case Instruction::PtrToInt: + case Instruction::SDiv: + case Instruction::Select: + case Instruction::SExt: + case Instruction::Shl: + case Instruction::SIToFP: + case Instruction::SRem: + case Instruction::Store: + case Instruction::Sub: + case Instruction::Trunc: + case Instruction::UDiv: + case Instruction::UIToFP: + case Instruction::URem: + case Instruction::Xor: + case Instruction::ZExt: + return true; + } + return false; + }; + + if (!IsVectorizableOpcode(I->getOpcode())) + return false; + + if (CallInst *CI = dyn_cast<CallInst>(I)) { + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); + if (ID && (ID == Intrinsic::assume || ID == Intrinsic::lifetime_end || + ID == Intrinsic::lifetime_start || ID == Intrinsic::sideeffect)) + return false; + } + + auto willWiden = [&](unsigned VF) -> bool { + if (!isa<PHINode>(I) && (CM.isScalarAfterVectorization(I, VF) || + CM.isProfitableToScalarize(I, VF))) + return false; + if (CallInst *CI = dyn_cast<CallInst>(I)) { + Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); + // The following case may be scalarized depending on the VF. + // The flag shows whether we use Intrinsic or a usual Call for vectorized + // version of the instruction. + // Is it beneficial to perform intrinsic call compared to lib call? + bool NeedToScalarize; + unsigned CallCost = getVectorCallCost(CI, VF, *TTI, TLI, NeedToScalarize); + bool UseVectorIntrinsic = + ID && getVectorIntrinsicCost(CI, VF, *TTI, TLI) <= CallCost; + return UseVectorIntrinsic || !NeedToScalarize; + } + if (isa<LoadInst>(I) || isa<StoreInst>(I)) { + assert(CM.getWideningDecision(I, VF) == + LoopVectorizationCostModel::CM_Scalarize && + "Memory widening decisions should have been taken care by now"); + return false; + } + return true; + }; + + if (!getDecisionAndClampRange(willWiden, Range)) + return false; + + // Success: widen this instruction. We optimize the common case where + // consecutive instructions can be represented by a single recipe. 
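// [Editorial sketch, not part of this diff] The appendInstruction fast path
// referenced above coalesces a run of consecutive widenable instructions into
// a single VPWidenRecipe rather than one recipe per instruction. A reduced
// model of that grouping, using integer positions in place of instructions:
#include <vector>

void appendWiden(std::vector<std::vector<int>> &Recipes, int InstrPos) {
  // Extend the last recipe when InstrPos directly follows its last member
  // (the "consecutive instructions" condition); otherwise open a new one.
  if (!Recipes.empty() && !Recipes.back().empty() &&
      Recipes.back().back() + 1 == InstrPos) {
    Recipes.back().push_back(InstrPos);
    return;
  }
  Recipes.push_back({InstrPos});
}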
+ if (!VPBB->empty()) { + VPWidenRecipe *LastWidenRecipe = dyn_cast<VPWidenRecipe>(&VPBB->back()); + if (LastWidenRecipe && LastWidenRecipe->appendInstruction(I)) + return true; + } + + VPBB->appendRecipe(new VPWidenRecipe(I)); + return true; +} + +VPBasicBlock *LoopVectorizationPlanner::handleReplication( + Instruction *I, VFRange &Range, VPBasicBlock *VPBB, + DenseMap<Instruction *, VPReplicateRecipe *> &PredInst2Recipe, + VPlanPtr &Plan) { + bool IsUniform = getDecisionAndClampRange( + [&](unsigned VF) { return CM.isUniformAfterVectorization(I, VF); }, + Range); + + bool IsPredicated = Legal->isScalarWithPredication(I); + auto *Recipe = new VPReplicateRecipe(I, IsUniform, IsPredicated); + + // Find if I uses a predicated instruction. If so, it will use its scalar + // value. Avoid hoisting the insert-element which packs the scalar value into + // a vector value, as that happens iff all users use the vector value. + for (auto &Op : I->operands()) + if (auto *PredInst = dyn_cast<Instruction>(Op)) + if (PredInst2Recipe.find(PredInst) != PredInst2Recipe.end()) + PredInst2Recipe[PredInst]->setAlsoPack(false); + + // Finalize the recipe for Instr, first if it is not predicated. + if (!IsPredicated) { + DEBUG(dbgs() << "LV: Scalarizing:" << *I << "\n"); + VPBB->appendRecipe(Recipe); + return VPBB; + } + DEBUG(dbgs() << "LV: Scalarizing and predicating:" << *I << "\n"); + assert(VPBB->getSuccessors().empty() && + "VPBB has successors when handling predicated replication."); + // Record predicated instructions for above packing optimizations. + PredInst2Recipe[I] = Recipe; + VPBlockBase *Region = + VPBB->setOneSuccessor(createReplicateRegion(I, Recipe, Plan)); + return cast<VPBasicBlock>(Region->setOneSuccessor(new VPBasicBlock())); +} + +VPRegionBlock * +LoopVectorizationPlanner::createReplicateRegion(Instruction *Instr, + VPRecipeBase *PredRecipe, + VPlanPtr &Plan) { + // Instructions marked for predication are replicated and placed under an + // if-then construct to prevent side-effects. + + // Generate recipes to compute the block mask for this region. + VPValue *BlockInMask = createBlockInMask(Instr->getParent(), Plan); + + // Build the triangular if-then region. + std::string RegionName = (Twine("pred.") + Instr->getOpcodeName()).str(); + assert(Instr->getParent() && "Predicated instruction not in any basic block"); + auto *BOMRecipe = new VPBranchOnMaskRecipe(BlockInMask); + auto *Entry = new VPBasicBlock(Twine(RegionName) + ".entry", BOMRecipe); + auto *PHIRecipe = + Instr->getType()->isVoidTy() ? nullptr : new VPPredInstPHIRecipe(Instr); + auto *Exit = new VPBasicBlock(Twine(RegionName) + ".continue", PHIRecipe); + auto *Pred = new VPBasicBlock(Twine(RegionName) + ".if", PredRecipe); + VPRegionBlock *Region = new VPRegionBlock(Entry, Exit, RegionName, true); + + // Note: first set Entry as region entry and then connect successors starting + // from it in order, to propagate the "parent" of each VPBasicBlock. + Entry->setTwoSuccessors(Pred, Exit); + Pred->setOneSuccessor(Exit); + + return Region; +} + +LoopVectorizationPlanner::VPlanPtr +LoopVectorizationPlanner::buildVPlan(VFRange &Range, + const SmallPtrSetImpl<Value *> &NeedDef) { + EdgeMaskCache.clear(); + BlockMaskCache.clear(); + DenseMap<Instruction *, Instruction *> &SinkAfter = Legal->getSinkAfter(); + DenseMap<Instruction *, Instruction *> SinkAfterInverse; + + // Collect instructions from the original loop that will become trivially dead + // in the vectorized loop. We don't need to vectorize these instructions. 
For + // example, original induction update instructions can become dead because we + // separately emit induction "steps" when generating code for the new loop. + // Similarly, we create a new latch condition when setting up the structure + // of the new loop, so the old one can become dead. + SmallPtrSet<Instruction *, 4> DeadInstructions; + collectTriviallyDeadInstructions(DeadInstructions); + + // Hold a mapping from predicated instructions to their recipes, in order to + // fix their AlsoPack behavior if a user is determined to replicate and use a + // scalar instead of vector value. + DenseMap<Instruction *, VPReplicateRecipe *> PredInst2Recipe; + + // Create a dummy pre-entry VPBasicBlock to start building the VPlan. + VPBasicBlock *VPBB = new VPBasicBlock("Pre-Entry"); + auto Plan = llvm::make_unique<VPlan>(VPBB); + + // Represent values that will have defs inside VPlan. + for (Value *V : NeedDef) + Plan->addVPValue(V); + + // Scan the body of the loop in a topological order to visit each basic block + // after having visited its predecessor basic blocks. + LoopBlocksDFS DFS(OrigLoop); + DFS.perform(LI); + + for (BasicBlock *BB : make_range(DFS.beginRPO(), DFS.endRPO())) { + // Relevant instructions from basic block BB will be grouped into VPRecipe + // ingredients and fill a new VPBasicBlock. + unsigned VPBBsForBB = 0; + auto *FirstVPBBForBB = new VPBasicBlock(BB->getName()); + VPBB->setOneSuccessor(FirstVPBBForBB); + VPBB = FirstVPBBForBB; + Builder.setInsertPoint(VPBB); + + std::vector<Instruction *> Ingredients; + + // Organize the ingredients to vectorize from current basic block in the + // right order. + for (Instruction &I : *BB) { + Instruction *Instr = &I; + + // First filter out irrelevant instructions, to ensure no recipes are + // built for them. + if (isa<BranchInst>(Instr) || isa<DbgInfoIntrinsic>(Instr) || + DeadInstructions.count(Instr)) + continue; + + // I is a member of an InterleaveGroup for Range.Start. If it's an adjunct + // member of the IG, do not construct any Recipe for it. + const InterleaveGroup *IG = Legal->getInterleavedAccessGroup(Instr); + if (IG && Instr != IG->getInsertPos() && + Range.Start >= 2 && // Query is illegal for VF == 1 + CM.getWideningDecision(Instr, Range.Start) == + LoopVectorizationCostModel::CM_Interleave) { + if (SinkAfterInverse.count(Instr)) + Ingredients.push_back(SinkAfterInverse.find(Instr)->second); + continue; + } + + // Move instructions to handle first-order recurrences, step 1: avoid + // handling this instruction until after we've handled the instruction it + // should follow. + auto SAIt = SinkAfter.find(Instr); + if (SAIt != SinkAfter.end()) { + DEBUG(dbgs() << "Sinking" << *SAIt->first << " after" << *SAIt->second + << " to vectorize a 1st order recurrence.\n"); + SinkAfterInverse[SAIt->second] = Instr; + continue; + } + + Ingredients.push_back(Instr); + + // Move instructions to handle first-order recurrences, step 2: push the + // instruction to be sunk at its insertion point. + auto SAInvIt = SinkAfterInverse.find(Instr); + if (SAInvIt != SinkAfterInverse.end()) + Ingredients.push_back(SAInvIt->second); + } + + // Introduce each ingredient into VPlan. + for (Instruction *Instr : Ingredients) { + VPRecipeBase *Recipe = nullptr; + + // Check if Instr should belong to an interleave memory recipe, or already + // does. In the latter case Instr is irrelevant. 
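// [Editorial sketch, not part of this diff] The SinkAfter handling in the
// ingredient loop above defers an instruction that must sink after another
// and re-emits it immediately after its anchor. A reduced model over integer
// ids, where SinkAfter maps an instruction to the one it must follow:
#include <map>
#include <vector>

std::vector<int> orderWithSinkAfter(const std::vector<int> &Instrs,
                                    const std::map<int, int> &SinkAfter) {
  std::map<int, int> SinkAfterInverse; // anchor -> deferred instruction
  std::vector<int> Ingredients;
  for (int I : Instrs) {
    auto SA = SinkAfter.find(I);
    if (SA != SinkAfter.end()) {
      SinkAfterInverse[SA->second] = I; // hold I back until its anchor
      continue;
    }
    Ingredients.push_back(I);
    auto Inv = SinkAfterInverse.find(I);
    if (Inv != SinkAfterInverse.end())
      Ingredients.push_back(Inv->second); // sink the deferred instruction here
  }
  return Ingredients;
}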
+ if ((Recipe = tryToInterleaveMemory(Instr, Range))) { + VPBB->appendRecipe(Recipe); + continue; + } + + // Check if Instr is a memory operation that should be widened. + if ((Recipe = tryToWidenMemory(Instr, Range, Plan))) { + VPBB->appendRecipe(Recipe); + continue; + } + + // Check if Instr should form some PHI recipe. + if ((Recipe = tryToOptimizeInduction(Instr, Range))) { + VPBB->appendRecipe(Recipe); + continue; + } + if ((Recipe = tryToBlend(Instr, Plan))) { + VPBB->appendRecipe(Recipe); + continue; + } + if (PHINode *Phi = dyn_cast<PHINode>(Instr)) { + VPBB->appendRecipe(new VPWidenPHIRecipe(Phi)); + continue; + } + + // Check if Instr is to be widened by a general VPWidenRecipe, after + // having first checked for specific widening recipes that deal with + // Interleave Groups, Inductions and Phi nodes. + if (tryToWiden(Instr, VPBB, Range)) + continue; + + // Otherwise, if all widening options failed, Instruction is to be + // replicated. This may create a successor for VPBB. + VPBasicBlock *NextVPBB = + handleReplication(Instr, Range, VPBB, PredInst2Recipe, Plan); + if (NextVPBB != VPBB) { + VPBB = NextVPBB; + VPBB->setName(BB->hasName() ? BB->getName() + "." + Twine(VPBBsForBB++) + : ""); + } + } + } + + // Discard empty dummy pre-entry VPBasicBlock. Note that other VPBasicBlocks + // may also be empty, such as the last one VPBB, reflecting original + // basic-blocks with no recipes. + VPBasicBlock *PreEntry = cast<VPBasicBlock>(Plan->getEntry()); + assert(PreEntry->empty() && "Expecting empty pre-entry block."); + VPBlockBase *Entry = Plan->setEntry(PreEntry->getSingleSuccessor()); + PreEntry->disconnectSuccessor(Entry); + delete PreEntry; + + std::string PlanName; + raw_string_ostream RSO(PlanName); + unsigned VF = Range.Start; + Plan->addVF(VF); + RSO << "Initial VPlan for VF={" << VF; + for (VF *= 2; VF < Range.End; VF *= 2) { + Plan->addVF(VF); + RSO << "," << VF; + } + RSO << "},UF>=1"; + RSO.flush(); + Plan->setName(PlanName); + + return Plan; +} + +void VPInterleaveRecipe::print(raw_ostream &O, const Twine &Indent) const { + O << " +\n" + << Indent << "\"INTERLEAVE-GROUP with factor " << IG->getFactor() << " at "; + IG->getInsertPos()->printAsOperand(O, false); + O << "\\l\""; + for (unsigned i = 0; i < IG->getFactor(); ++i) + if (Instruction *I = IG->getMember(i)) + O << " +\n" + << Indent << "\" " << VPlanIngredient(I) << " " << i << "\\l\""; +} + +void VPWidenRecipe::execute(VPTransformState &State) { + for (auto &Instr : make_range(Begin, End)) + State.ILV->widenInstruction(Instr); +} + +void VPWidenIntOrFpInductionRecipe::execute(VPTransformState &State) { + assert(!State.Instance && "Int or FP induction being replicated."); + State.ILV->widenIntOrFpInduction(IV, Trunc); +} + +void VPWidenPHIRecipe::execute(VPTransformState &State) { + State.ILV->widenPHIInstruction(Phi, State.UF, State.VF); +} + +void VPBlendRecipe::execute(VPTransformState &State) { + State.ILV->setDebugLocFromInst(State.Builder, Phi); + // We know that all PHIs in non-header blocks are converted into + // selects, so we don't have to worry about the insertion order and we + // can just use the builder. + // At this point we generate the predication tree. There may be + // duplications since this is a simple recursive scan, but future + // optimizations will clean it up. 
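(The comment above describes the predication tree; before the loop below builds it over vector values, here is a scalar picture of the same nested-select fold. This is an illustrative sketch in plain C++ with stand-in types, not the VPBlendRecipe code itself.)

#include <cassert>
#include <cstddef>
#include <vector>

// Fold N incoming (mask, value) pairs into the nested-select form
// SELECT(Mask[N-1], In[N-1], SELECT(Mask[N-2], In[N-2], ... In[0])).
int blendWithSelects(const std::vector<int> &Incoming,
                     const std::vector<bool> &Mask) {
  assert(!Incoming.empty() && Incoming.size() == Mask.size());
  int Result = Incoming[0]; // identity "select" for the first operand
  for (std::size_t In = 1; In < Incoming.size(); ++In)
    Result = Mask[In] ? Incoming[In] : Result;
  return Result;
}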
+ + unsigned NumIncoming = Phi->getNumIncomingValues(); + + assert((User || NumIncoming == 1) && + "Multiple predecessors with predecessors having a full mask"); + // Generate a sequence of selects of the form: + // SELECT(Mask3, In3, + // SELECT(Mask2, In2, + // ( ...))) + InnerLoopVectorizer::VectorParts Entry(State.UF); + for (unsigned In = 0; In < NumIncoming; ++In) { + for (unsigned Part = 0; Part < State.UF; ++Part) { + // We might have single edge PHIs (blocks) - use an identity + // 'select' for the first PHI operand. + Value *In0 = + State.ILV->getOrCreateVectorValue(Phi->getIncomingValue(In), Part); + if (In == 0) + Entry[Part] = In0; // Initialize with the first incoming value. + else { + // Select between the current value and the previous incoming edge + // based on the incoming mask. + Value *Cond = State.get(User->getOperand(In), Part); + Entry[Part] = + State.Builder.CreateSelect(Cond, In0, Entry[Part], "predphi"); + } + } + } + for (unsigned Part = 0; Part < State.UF; ++Part) + State.ValueMap.setVectorValue(Phi, Part, Entry[Part]); +} + +void VPInterleaveRecipe::execute(VPTransformState &State) { + assert(!State.Instance && "Interleave group being replicated."); + State.ILV->vectorizeInterleaveGroup(IG->getInsertPos()); +} + +void VPReplicateRecipe::execute(VPTransformState &State) { + if (State.Instance) { // Generate a single instance. + State.ILV->scalarizeInstruction(Ingredient, *State.Instance, IsPredicated); + // Insert scalar instance packing it into a vector. + if (AlsoPack && State.VF > 1) { + // If we're constructing lane 0, initialize to start from undef. + if (State.Instance->Lane == 0) { + Value *Undef = + UndefValue::get(VectorType::get(Ingredient->getType(), State.VF)); + State.ValueMap.setVectorValue(Ingredient, State.Instance->Part, Undef); + } + State.ILV->packScalarIntoVectorValue(Ingredient, *State.Instance); + } + return; + } + + // Generate scalar instances for all VF lanes of all UF parts, unless the + // instruction is uniform inwhich case generate only the first lane for each + // of the UF parts. + unsigned EndLane = IsUniform ? 1 : State.VF; + for (unsigned Part = 0; Part < State.UF; ++Part) + for (unsigned Lane = 0; Lane < EndLane; ++Lane) + State.ILV->scalarizeInstruction(Ingredient, {Part, Lane}, IsPredicated); +} + +void VPBranchOnMaskRecipe::execute(VPTransformState &State) { + assert(State.Instance && "Branch on Mask works only on single instance."); + + unsigned Part = State.Instance->Part; + unsigned Lane = State.Instance->Lane; + + Value *ConditionBit = nullptr; + if (!User) // Block in mask is all-one. + ConditionBit = State.Builder.getTrue(); + else { + VPValue *BlockInMask = User->getOperand(0); + ConditionBit = State.get(BlockInMask, Part); + if (ConditionBit->getType()->isVectorTy()) + ConditionBit = State.Builder.CreateExtractElement( + ConditionBit, State.Builder.getInt32(Lane)); + } + + // Replace the temporary unreachable terminator with a new conditional branch, + // whose two destinations will be set later when they are created. 
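(The two execute() implementations above enumerate one scalar instance per (unroll part, vector lane), and a predicated instance only runs under its lane's bit of the block-in mask. The following is a rough standalone sketch of that enumeration with plain containers; it uses an if in place of the conditional branch the recipe actually emits just below, and all names are illustrative.)

#include <functional>
#include <vector>

// Enumerate scalar instances for UF parts x VF lanes (only lane 0 per part
// when the value is uniform) and run the predicated ones.
void forEachPredicatedInstance(
    unsigned UF, unsigned VF, bool IsUniform,
    const std::vector<std::vector<bool>> &MaskPerPart, // [Part][Lane]
    const std::function<void(unsigned Part, unsigned Lane)> &ScalarizeAt) {
  unsigned EndLane = IsUniform ? 1 : VF;
  for (unsigned Part = 0; Part < UF; ++Part)
    for (unsigned Lane = 0; Lane < EndLane; ++Lane)
      if (MaskPerPart[Part][Lane]) // the instance's condition bit
        ScalarizeAt(Part, Lane);
}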
+ auto *CurrentTerminator = State.CFG.PrevBB->getTerminator(); + assert(isa<UnreachableInst>(CurrentTerminator) && + "Expected to replace unreachable terminator with conditional branch."); + auto *CondBr = BranchInst::Create(State.CFG.PrevBB, nullptr, ConditionBit); + CondBr->setSuccessor(0, nullptr); + ReplaceInstWithInst(CurrentTerminator, CondBr); +} + +void VPPredInstPHIRecipe::execute(VPTransformState &State) { + assert(State.Instance && "Predicated instruction PHI works per instance."); + Instruction *ScalarPredInst = cast<Instruction>( + State.ValueMap.getScalarValue(PredInst, *State.Instance)); + BasicBlock *PredicatedBB = ScalarPredInst->getParent(); + BasicBlock *PredicatingBB = PredicatedBB->getSinglePredecessor(); + assert(PredicatingBB && "Predicated block has no single predecessor."); + + // By current pack/unpack logic we need to generate only a single phi node: if + // a vector value for the predicated instruction exists at this point it means + // the instruction has vector users only, and a phi for the vector value is + // needed. In this case the recipe of the predicated instruction is marked to + // also do that packing, thereby "hoisting" the insert-element sequence. + // Otherwise, a phi node for the scalar value is needed. + unsigned Part = State.Instance->Part; + if (State.ValueMap.hasVectorValue(PredInst, Part)) { + Value *VectorValue = State.ValueMap.getVectorValue(PredInst, Part); + InsertElementInst *IEI = cast<InsertElementInst>(VectorValue); + PHINode *VPhi = State.Builder.CreatePHI(IEI->getType(), 2); + VPhi->addIncoming(IEI->getOperand(0), PredicatingBB); // Unmodified vector. + VPhi->addIncoming(IEI, PredicatedBB); // New vector with inserted element. + State.ValueMap.resetVectorValue(PredInst, Part, VPhi); // Update cache. + } else { + Type *PredInstType = PredInst->getType(); + PHINode *Phi = State.Builder.CreatePHI(PredInstType, 2); + Phi->addIncoming(UndefValue::get(ScalarPredInst->getType()), PredicatingBB); + Phi->addIncoming(ScalarPredInst, PredicatedBB); + State.ValueMap.resetScalarValue(PredInst, *State.Instance, Phi); + } +} + +void VPWidenMemoryInstructionRecipe::execute(VPTransformState &State) { + if (!User) + return State.ILV->vectorizeMemoryInstruction(&Instr); + + // Last (and currently only) operand is a mask. + InnerLoopVectorizer::VectorParts MaskValues(State.UF); + VPValue *Mask = User->getOperand(User->getNumOperands() - 1); + for (unsigned Part = 0; Part < State.UF; ++Part) + MaskValues[Part] = State.get(Mask, Part); + State.ILV->vectorizeMemoryInstruction(&Instr, &MaskValues); +} + bool LoopVectorizePass::processLoop(Loop *L) { assert(L->empty() && "Only process inner loops."); @@ -7878,7 +8601,7 @@ bool LoopVectorizePass::processLoop(Loop *L) { CM.collectValuesToIgnore(); // Use the planner for vectorization. - LoopVectorizationPlanner LVP(L, LI, &LVL, CM); + LoopVectorizationPlanner LVP(L, LI, TLI, TTI, &LVL, CM); // Get user vectorization factor. unsigned UserVF = Hints.getWidth(); @@ -7941,48 +8664,61 @@ bool LoopVectorizePass::processLoop(Loop *L) { const char *VAPassName = Hints.vectorizeAnalysisPassName(); if (!VectorizeLoop && !InterleaveLoop) { // Do not vectorize or interleaving the loop. 
- ORE->emit(OptimizationRemarkMissed(VAPassName, VecDiagMsg.first, - L->getStartLoc(), L->getHeader()) - << VecDiagMsg.second); - ORE->emit(OptimizationRemarkMissed(LV_NAME, IntDiagMsg.first, - L->getStartLoc(), L->getHeader()) - << IntDiagMsg.second); + ORE->emit([&]() { + return OptimizationRemarkMissed(VAPassName, VecDiagMsg.first, + L->getStartLoc(), L->getHeader()) + << VecDiagMsg.second; + }); + ORE->emit([&]() { + return OptimizationRemarkMissed(LV_NAME, IntDiagMsg.first, + L->getStartLoc(), L->getHeader()) + << IntDiagMsg.second; + }); return false; } else if (!VectorizeLoop && InterleaveLoop) { DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); - ORE->emit(OptimizationRemarkAnalysis(VAPassName, VecDiagMsg.first, - L->getStartLoc(), L->getHeader()) - << VecDiagMsg.second); + ORE->emit([&]() { + return OptimizationRemarkAnalysis(VAPassName, VecDiagMsg.first, + L->getStartLoc(), L->getHeader()) + << VecDiagMsg.second; + }); } else if (VectorizeLoop && !InterleaveLoop) { DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width << ") in " << DebugLocStr << '\n'); - ORE->emit(OptimizationRemarkAnalysis(LV_NAME, IntDiagMsg.first, - L->getStartLoc(), L->getHeader()) - << IntDiagMsg.second); + ORE->emit([&]() { + return OptimizationRemarkAnalysis(LV_NAME, IntDiagMsg.first, + L->getStartLoc(), L->getHeader()) + << IntDiagMsg.second; + }); } else if (VectorizeLoop && InterleaveLoop) { DEBUG(dbgs() << "LV: Found a vectorizable loop (" << VF.Width << ") in " << DebugLocStr << '\n'); DEBUG(dbgs() << "LV: Interleave Count is " << IC << '\n'); } + LVP.setBestPlan(VF.Width, IC); + using namespace ore; + if (!VectorizeLoop) { assert(IC > 1 && "interleave count should not be 1 or 0"); // If we decided that it is not legal to vectorize the loop, then // interleave it. InnerLoopUnroller Unroller(L, PSE, LI, DT, TLI, TTI, AC, ORE, IC, &LVL, &CM); - LVP.executePlan(Unroller); + LVP.executePlan(Unroller, DT); - ORE->emit(OptimizationRemark(LV_NAME, "Interleaved", L->getStartLoc(), - L->getHeader()) - << "interleaved loop (interleaved count: " - << NV("InterleaveCount", IC) << ")"); + ORE->emit([&]() { + return OptimizationRemark(LV_NAME, "Interleaved", L->getStartLoc(), + L->getHeader()) + << "interleaved loop (interleaved count: " + << NV("InterleaveCount", IC) << ")"; + }); } else { // If we decided that it is *legal* to vectorize the loop, then do it. InnerLoopVectorizer LB(L, PSE, LI, DT, TLI, TTI, AC, ORE, VF.Width, IC, &LVL, &CM); - LVP.executePlan(LB); + LVP.executePlan(LB, DT); ++LoopsVectorized; // Add metadata to disable runtime unrolling a scalar loop when there are @@ -7992,11 +8728,13 @@ bool LoopVectorizePass::processLoop(Loop *L) { AddRuntimeUnrollDisableMetaData(L); // Report the vectorization decision. - ORE->emit(OptimizationRemark(LV_NAME, "Vectorized", L->getStartLoc(), - L->getHeader()) - << "vectorized loop (vectorization width: " - << NV("VectorizationFactor", VF.Width) - << ", interleaved count: " << NV("InterleaveCount", IC) << ")"); + ORE->emit([&]() { + return OptimizationRemark(LV_NAME, "Vectorized", L->getStartLoc(), + L->getHeader()) + << "vectorized loop (vectorization width: " + << NV("VectorizationFactor", VF.Width) + << ", interleaved count: " << NV("InterleaveCount", IC) << ")"; + }); } // Mark the loop as already vectorized to avoid vectorizing again. 
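(The remark calls above are now wrapped in lambdas so the potentially costly remark object is only constructed when remark emission is actually enabled. A minimal standalone illustration of that pattern follows; it is plain C++, not the OptimizationRemarkEmitter API, and the names are invented.)

#include <functional>
#include <iostream>
#include <string>

struct LazyEmitter {
  bool Enabled = false;
  // The message is only built when someone will actually see it.
  void emit(const std::function<std::string()> &MakeRemark) {
    if (!Enabled)
      return;
    std::cout << MakeRemark() << '\n';
  }
};
// Usage: Emitter.emit([&] { return "vectorized loop (width: " +
//                                  std::to_string(Width) + ")"; });
// (Width is a stand-in for whatever the remark reports.)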
@@ -8012,7 +8750,6 @@ bool LoopVectorizePass::runImpl( DemandedBits &DB_, AliasAnalysis &AA_, AssumptionCache &AC_, std::function<const LoopAccessInfo &(Loop &)> &GetLAA_, OptimizationRemarkEmitter &ORE_) { - SE = &SE_; LI = &LI_; TTI = &TTI_; @@ -8068,10 +8805,8 @@ bool LoopVectorizePass::runImpl( // Process each loop nest in the function. return Changed; - } - PreservedAnalyses LoopVectorizePass::run(Function &F, FunctionAnalysisManager &AM) { auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); @@ -8088,7 +8823,7 @@ PreservedAnalyses LoopVectorizePass::run(Function &F, auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); std::function<const LoopAccessInfo &(Loop &)> GetLAA = [&](Loop &L) -> const LoopAccessInfo & { - LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI}; + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI, nullptr}; return LAM.getResult<LoopAccessAnalysis>(L, AR); }; bool Changed = diff --git a/lib/Transforms/Vectorize/SLPVectorizer.cpp b/lib/Transforms/Vectorize/SLPVectorizer.cpp index dcbcab459a6b..76ba62f5d596 100644 --- a/lib/Transforms/Vectorize/SLPVectorizer.cpp +++ b/lib/Transforms/Vectorize/SLPVectorizer.cpp @@ -6,6 +6,7 @@ // License. See LICENSE.TXT for details. // //===----------------------------------------------------------------------===// +// // This pass implements the Bottom Up SLP vectorizer. It detects consecutive // stores that can be put together into vector-stores. Next, it attempts to // construct vectorizable tree using the use-def chains. If a profitable tree @@ -15,39 +16,89 @@ // "Loop-Aware SLP in GCC" by Ira Rosen, Dorit Nuzman, Ayal Zaks. // //===----------------------------------------------------------------------===// + #include "llvm/Transforms/Vectorize/SLPVectorizer.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/None.h" #include "llvm/ADT/Optional.h" #include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Statistic.h" +#include "llvm/ADT/iterator.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" #include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DemandedBits.h" #include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" #include "llvm/Analysis/ValueTracking.h" #include "llvm/Analysis/VectorUtils.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" #include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/NoFolder.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include 
"llvm/IR/PatternMatch.h" #include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" #include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" #include "llvm/IR/Verifier.h" #include "llvm/Pass.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/DOTGraphTraits.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/GraphWriter.h" #include "llvm/Support/KnownBits.h" +#include "llvm/Support/MathExtras.h" #include "llvm/Support/raw_ostream.h" #include "llvm/Transforms/Utils/LoopUtils.h" #include "llvm/Transforms/Vectorize.h" #include <algorithm> +#include <cassert> +#include <cstdint> +#include <iterator> #include <memory> +#include <set> +#include <string> +#include <tuple> +#include <utility> +#include <vector> using namespace llvm; +using namespace llvm::PatternMatch; using namespace slpvectorizer; #define SV_NAME "slp-vectorizer" @@ -156,6 +207,119 @@ static bool isSplat(ArrayRef<Value *> VL) { return true; } +/// Checks if the vector of instructions can be represented as a shuffle, like: +/// %x0 = extractelement <4 x i8> %x, i32 0 +/// %x3 = extractelement <4 x i8> %x, i32 3 +/// %y1 = extractelement <4 x i8> %y, i32 1 +/// %y2 = extractelement <4 x i8> %y, i32 2 +/// %x0x0 = mul i8 %x0, %x0 +/// %x3x3 = mul i8 %x3, %x3 +/// %y1y1 = mul i8 %y1, %y1 +/// %y2y2 = mul i8 %y2, %y2 +/// %ins1 = insertelement <4 x i8> undef, i8 %x0x0, i32 0 +/// %ins2 = insertelement <4 x i8> %ins1, i8 %x3x3, i32 1 +/// %ins3 = insertelement <4 x i8> %ins2, i8 %y1y1, i32 2 +/// %ins4 = insertelement <4 x i8> %ins3, i8 %y2y2, i32 3 +/// ret <4 x i8> %ins4 +/// can be transformed into: +/// %1 = shufflevector <4 x i8> %x, <4 x i8> %y, <4 x i32> <i32 0, i32 3, i32 5, +/// i32 6> +/// %2 = mul <4 x i8> %1, %1 +/// ret <4 x i8> %2 +/// We convert this initially to something like: +/// %x0 = extractelement <4 x i8> %x, i32 0 +/// %x3 = extractelement <4 x i8> %x, i32 3 +/// %y1 = extractelement <4 x i8> %y, i32 1 +/// %y2 = extractelement <4 x i8> %y, i32 2 +/// %1 = insertelement <4 x i8> undef, i8 %x0, i32 0 +/// %2 = insertelement <4 x i8> %1, i8 %x3, i32 1 +/// %3 = insertelement <4 x i8> %2, i8 %y1, i32 2 +/// %4 = insertelement <4 x i8> %3, i8 %y2, i32 3 +/// %5 = mul <4 x i8> %4, %4 +/// %6 = extractelement <4 x i8> %5, i32 0 +/// %ins1 = insertelement <4 x i8> undef, i8 %6, i32 0 +/// %7 = extractelement <4 x i8> %5, i32 1 +/// %ins2 = insertelement <4 x i8> %ins1, i8 %7, i32 1 +/// %8 = extractelement <4 x i8> %5, i32 2 +/// %ins3 = insertelement <4 x i8> %ins2, i8 %8, i32 2 +/// %9 = extractelement <4 x i8> %5, i32 3 +/// %ins4 = insertelement <4 x i8> %ins3, i8 %9, i32 3 +/// ret <4 x i8> %ins4 +/// InstCombiner transforms this into a shuffle and vector mul +static Optional<TargetTransformInfo::ShuffleKind> +isShuffle(ArrayRef<Value *> VL) { + auto *EI0 = cast<ExtractElementInst>(VL[0]); + unsigned Size = EI0->getVectorOperandType()->getVectorNumElements(); + Value *Vec1 = nullptr; + Value *Vec2 = nullptr; + enum ShuffleMode {Unknown, FirstAlternate, SecondAlternate, Permute}; + ShuffleMode CommonShuffleMode = Unknown; + for (unsigned I = 0, E = VL.size(); I < E; ++I) { + auto *EI = cast<ExtractElementInst>(VL[I]); + auto *Vec = EI->getVectorOperand(); + // All vector operands must have the same number of vector elements. 
+ if (Vec->getType()->getVectorNumElements() != Size) + return None; + auto *Idx = dyn_cast<ConstantInt>(EI->getIndexOperand()); + if (!Idx) + return None; + // Undefined behavior if Idx is negative or >= Size. + if (Idx->getValue().uge(Size)) + continue; + unsigned IntIdx = Idx->getValue().getZExtValue(); + // We can extractelement from undef vector. + if (isa<UndefValue>(Vec)) + continue; + // For correct shuffling we have to have at most 2 different vector operands + // in all extractelement instructions. + if (Vec1 && Vec2 && Vec != Vec1 && Vec != Vec2) + return None; + if (CommonShuffleMode == Permute) + continue; + // If the extract index is not the same as the operation number, it is a + // permutation. + if (IntIdx != I) { + CommonShuffleMode = Permute; + continue; + } + // Check the shuffle mode for the current operation. + if (!Vec1) + Vec1 = Vec; + else if (Vec != Vec1) + Vec2 = Vec; + // Example: shufflevector A, B, <0,5,2,7> + // I is odd and IntIdx for A == I - FirstAlternate shuffle. + // I is even and IntIdx for B == I - FirstAlternate shuffle. + // Example: shufflevector A, B, <4,1,6,3> + // I is even and IntIdx for A == I - SecondAlternate shuffle. + // I is odd and IntIdx for B == I - SecondAlternate shuffle. + const bool IIsEven = I & 1; + const bool CurrVecIsA = Vec == Vec1; + const bool IIsOdd = !IIsEven; + const bool CurrVecIsB = !CurrVecIsA; + ShuffleMode CurrentShuffleMode = + ((IIsOdd && CurrVecIsA) || (IIsEven && CurrVecIsB)) ? FirstAlternate + : SecondAlternate; + // Common mode is not set or the same as the shuffle mode of the current + // operation - alternate. + if (CommonShuffleMode == Unknown) + CommonShuffleMode = CurrentShuffleMode; + // Common shuffle mode is not the same as the shuffle mode of the current + // operation - permutation. + if (CommonShuffleMode != CurrentShuffleMode) + CommonShuffleMode = Permute; + } + // If we're not crossing lanes in different vectors, consider it as blending. + if ((CommonShuffleMode == FirstAlternate || + CommonShuffleMode == SecondAlternate) && + Vec2) + return TargetTransformInfo::SK_Alternate; + // If Vec2 was never used, we have a permutation of a single vector, otherwise + // we have permutation of 2 vectors. + return Vec2 ? TargetTransformInfo::SK_PermuteTwoSrc + : TargetTransformInfo::SK_PermuteSingleSrc; +} + ///\returns Opcode that can be clubbed with \p Op to create an alternate /// sequence which can later be merged as a ShuffleVector instruction. static unsigned getAltOpcode(unsigned Op) { @@ -173,50 +337,107 @@ static unsigned getAltOpcode(unsigned Op) { } } -/// true if the \p Value is odd, false otherwise. static bool isOdd(unsigned Value) { return Value & 1; } -///\returns bool representing if Opcode \p Op can be part -/// of an alternate sequence which can later be merged as -/// a ShuffleVector instruction. -static bool canCombineAsAltInst(unsigned Op) { - return Op == Instruction::FAdd || Op == Instruction::FSub || - Op == Instruction::Sub || Op == Instruction::Add; +static bool sameOpcodeOrAlt(unsigned Opcode, unsigned AltOpcode, + unsigned CheckedOpcode) { + return Opcode == CheckedOpcode || AltOpcode == CheckedOpcode; } -/// \returns ShuffleVector instruction if instructions in \p VL have -/// alternate fadd,fsub / fsub,fadd/add,sub/sub,add sequence. -/// (i.e. e.g. opcodes of fadd,fsub,fadd,fsub...) 
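(A hedged standalone sketch of the decision isShuffle() above makes, expressed on plain (source, index) pairs instead of extractelement instructions; it omits the undef and out-of-range index handling of the real code, and the names are invented.)

#include <vector>

enum class Kind { Alternate, PermuteSingleSrc, PermuteTwoSrc, NotAShuffle };

struct Extract {
  int Source;     // which input vector the lane is read from
  unsigned Index; // which element of that vector
};

Kind classifyShuffle(const std::vector<Extract> &Lanes) {
  int Src1 = -1, Src2 = -1;
  bool LanePreserving = true;   // every output lane I reads element I
  bool ParityConsistent = true; // even/odd lanes split cleanly across sources
  int ExpectedParity = -1;
  for (unsigned I = 0; I < Lanes.size(); ++I) {
    int S = Lanes[I].Source;
    if (Src1 == -1 || S == Src1)
      Src1 = S;
    else if (Src2 == -1 || S == Src2)
      Src2 = S;
    else
      return Kind::NotAShuffle; // more than two source vectors
    if (Lanes[I].Index != I) {
      LanePreserving = false; // crossing lanes: at best a permutation
      continue;
    }
    int Parity = static_cast<int>(I & 1) ^ (S == Src1 ? 1 : 0);
    if (ExpectedParity == -1)
      ExpectedParity = Parity;
    else if (Parity != ExpectedParity)
      ParityConsistent = false;
  }
  if (LanePreserving && ParityConsistent && Src2 != -1)
    return Kind::Alternate; // e.g. <A0,B1,A2,B3>: a lane-preserving blend
  return Src2 != -1 ? Kind::PermuteTwoSrc : Kind::PermuteSingleSrc;
}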
-static unsigned isAltInst(ArrayRef<Value *> VL) { - Instruction *I0 = dyn_cast<Instruction>(VL[0]); - unsigned Opcode = I0->getOpcode(); - unsigned AltOpcode = getAltOpcode(Opcode); - for (int i = 1, e = VL.size(); i < e; i++) { - Instruction *I = dyn_cast<Instruction>(VL[i]); - if (!I || I->getOpcode() != (isOdd(i) ? AltOpcode : Opcode)) - return 0; - } - return Instruction::ShuffleVector; +/// Chooses the correct key for scheduling data. If \p Op has the same (or +/// alternate) opcode as \p OpValue, the key is \p Op. Otherwise the key is \p +/// OpValue. +static Value *isOneOf(Value *OpValue, Value *Op) { + auto *I = dyn_cast<Instruction>(Op); + if (!I) + return OpValue; + auto *OpInst = cast<Instruction>(OpValue); + unsigned OpInstOpcode = OpInst->getOpcode(); + unsigned IOpcode = I->getOpcode(); + if (sameOpcodeOrAlt(OpInstOpcode, getAltOpcode(OpInstOpcode), IOpcode)) + return Op; + return OpValue; } -/// \returns The opcode if all of the Instructions in \p VL have the same -/// opcode, or zero. -static unsigned getSameOpcode(ArrayRef<Value *> VL) { - Instruction *I0 = dyn_cast<Instruction>(VL[0]); +namespace { + +/// Contains data for the instructions going to be vectorized. +struct RawInstructionsData { + /// Main Opcode of the instructions going to be vectorized. + unsigned Opcode = 0; + + /// The list of instructions have some instructions with alternate opcodes. + bool HasAltOpcodes = false; +}; + +} // end anonymous namespace + +/// Checks the list of the vectorized instructions \p VL and returns info about +/// this list. +static RawInstructionsData getMainOpcode(ArrayRef<Value *> VL) { + auto *I0 = dyn_cast<Instruction>(VL[0]); if (!I0) - return 0; + return {}; + RawInstructionsData Res; unsigned Opcode = I0->getOpcode(); - for (int i = 1, e = VL.size(); i < e; i++) { - Instruction *I = dyn_cast<Instruction>(VL[i]); - if (!I || Opcode != I->getOpcode()) { - if (canCombineAsAltInst(Opcode) && i == 1) - return isAltInst(VL); - return 0; + // Walk through the list of the vectorized instructions + // in order to check its structure described by RawInstructionsData. + for (unsigned Cnt = 0, E = VL.size(); Cnt != E; ++Cnt) { + auto *I = dyn_cast<Instruction>(VL[Cnt]); + if (!I) + return {}; + if (Opcode != I->getOpcode()) + Res.HasAltOpcodes = true; + } + Res.Opcode = Opcode; + return Res; +} + +namespace { + +/// Main data required for vectorization of instructions. +struct InstructionsState { + /// The very first instruction in the list with the main opcode. + Value *OpValue = nullptr; + + /// The main opcode for the list of instructions. + unsigned Opcode = 0; + + /// Some of the instructions in the list have alternate opcodes. + bool IsAltShuffle = false; + + InstructionsState() = default; + InstructionsState(Value *OpValue, unsigned Opcode, bool IsAltShuffle) + : OpValue(OpValue), Opcode(Opcode), IsAltShuffle(IsAltShuffle) {} +}; + +} // end anonymous namespace + +/// \returns analysis of the Instructions in \p VL described in +/// InstructionsState, the Opcode that we suppose the whole list +/// could be vectorized even if its structure is diverse. 
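(The function that follows implements this analysis; as a companion, here is a hedged, self-contained sketch of the rule it encodes, using plain enums in place of llvm::Instruction opcodes: a bundle is usable as-is when every element carries the main opcode, or as an alternate shuffle when even positions carry the main opcode and odd positions its add/sub or fadd/fsub counterpart.)

#include <cassert>
#include <vector>

enum Op { Add, Sub, FAdd, FSub, Mul, Other };

static Op altOpcode(Op O) {
  switch (O) {
  case Add:  return Sub;
  case Sub:  return Add;
  case FAdd: return FSub;
  case FSub: return FAdd;
  default:   return O; // no alternate form
  }
}

struct OpState {
  Op MainOp;
  bool IsAltShuffle; // odd positions use the alternate opcode
  bool Valid;        // false: the bundle has to be gathered
};

static OpState analyzeOpcodes(const std::vector<Op> &Bundle) {
  assert(!Bundle.empty());
  Op Main = Bundle[0];
  bool HasAlt = false;
  for (Op O : Bundle)
    if (O != Main)
      HasAlt = true;
  if (!HasAlt)
    return {Main, false, true};
  Op Alt = altOpcode(Main);
  for (unsigned I = 0; I < Bundle.size(); ++I)
    if (Bundle[I] != ((I & 1) ? Alt : Main))
      return {Main, false, false}; // neither uniform nor alternating
  return {Main, true, true};
}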
+static InstructionsState getSameOpcode(ArrayRef<Value *> VL) { + auto Res = getMainOpcode(VL); + unsigned Opcode = Res.Opcode; + if (!Res.HasAltOpcodes) + return InstructionsState(VL[0], Opcode, false); + auto *OpInst = cast<Instruction>(VL[0]); + unsigned AltOpcode = getAltOpcode(Opcode); + // Examine each element in the list instructions VL to determine + // if some operations there could be considered as an alternative + // (for example as subtraction relates to addition operation). + for (int Cnt = 0, E = VL.size(); Cnt < E; Cnt++) { + auto *I = cast<Instruction>(VL[Cnt]); + unsigned InstOpcode = I->getOpcode(); + if ((Res.HasAltOpcodes && + InstOpcode != (isOdd(Cnt) ? AltOpcode : Opcode)) || + (!Res.HasAltOpcodes && InstOpcode != Opcode)) { + return InstructionsState(OpInst, 0, false); } } - return Opcode; + return InstructionsState(OpInst, Opcode, Res.HasAltOpcodes); } /// \returns true if all of the values in \p VL have the same type or false @@ -247,7 +468,6 @@ static bool matchExtractIndex(Instruction *E, unsigned Idx, unsigned Opcode) { /// possible scalar operand in vectorized instruction. static bool InTreeUserNeedToExtract(Value *Scalar, Instruction *UserInst, TargetLibraryInfo *TLI) { - unsigned Opcode = UserInst->getOpcode(); switch (Opcode) { case Instruction::Load: { @@ -292,24 +512,25 @@ static bool isSimple(Instruction *I) { } namespace llvm { + namespace slpvectorizer { + /// Bottom Up SLP Vectorizer. class BoUpSLP { public: - typedef SmallVector<Value *, 8> ValueList; - typedef SmallVector<Instruction *, 16> InstrList; - typedef SmallPtrSet<Value *, 16> ValueSet; - typedef SmallVector<StoreInst *, 8> StoreList; - typedef MapVector<Value *, SmallVector<Instruction *, 2>> - ExtraValueToDebugLocsMap; + using ValueList = SmallVector<Value *, 8>; + using InstrList = SmallVector<Instruction *, 16>; + using ValueSet = SmallPtrSet<Value *, 16>; + using StoreList = SmallVector<StoreInst *, 8>; + using ExtraValueToDebugLocsMap = + MapVector<Value *, SmallVector<Instruction *, 2>>; BoUpSLP(Function *Func, ScalarEvolution *Se, TargetTransformInfo *Tti, TargetLibraryInfo *TLi, AliasAnalysis *Aa, LoopInfo *Li, DominatorTree *Dt, AssumptionCache *AC, DemandedBits *DB, const DataLayout *DL, OptimizationRemarkEmitter *ORE) - : NumLoadsWantToKeepOrder(0), NumLoadsWantToChangeOrder(0), F(Func), - SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), AC(AC), DB(DB), - DL(DL), ORE(ORE), Builder(Se->getContext()) { + : F(Func), SE(Se), TTI(Tti), TLI(TLi), AA(Aa), LI(Li), DT(Dt), AC(AC), + DB(DB), DL(DL), ORE(ORE), Builder(Se->getContext()) { CodeMetrics::collectEphemeralValues(F, AC, EphValues); // Use the vector register size specified by the target unless overridden // by a command-line option. @@ -331,6 +552,7 @@ public: /// \brief Vectorize the tree that starts with the elements in \p VL. /// Returns the vectorized root. Value *vectorizeTree(); + /// Vectorize the tree but with the list of externally used values \p /// ExternallyUsedValues. Values in this MapVector can be replaced but the /// generated extractvalue instructions. @@ -348,6 +570,7 @@ public: /// the purpose of scheduling and extraction in the \p UserIgnoreLst. 
void buildTree(ArrayRef<Value *> Roots, ArrayRef<Value *> UserIgnoreLst = None); + /// Construct a vectorizable tree that starts at \p Roots, ignoring users for /// the purpose of scheduling and extraction in the \p UserIgnoreLst taking /// into account (anf updating it, if required) list of externally used @@ -374,7 +597,7 @@ public: unsigned getTreeSize() const { return VectorizableTree.size(); } /// \brief Perform LICM and CSE on the newly generated gather sequences. - void optimizeGatherSequence(); + void optimizeGatherSequence(Function &F); /// \returns true if it is beneficial to reverse the vector order. bool shouldReorder() const { @@ -416,21 +639,30 @@ public: private: struct TreeEntry; + /// Checks if all users of \p I are the part of the vectorization tree. + bool areAllUsersVectorized(Instruction *I) const; + /// \returns the cost of the vectorizable entry. int getEntryCost(TreeEntry *E); /// This is the recursive part of buildTree. - void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth, int); + void buildTree_rec(ArrayRef<Value *> Roots, unsigned Depth, int UserIndx = -1, + int OpdNum = 0); /// \returns True if the ExtractElement/ExtractValue instructions in VL can /// be vectorized to use the original vector (or aggregate "bitcast" to a vector). - bool canReuseExtract(ArrayRef<Value *> VL, unsigned Opcode) const; + bool canReuseExtract(ArrayRef<Value *> VL, Value *OpValue) const; - /// Vectorize a single entry in the tree. - Value *vectorizeTree(TreeEntry *E); + /// Vectorize a single entry in the tree.\p OpdNum indicate the ordinality of + /// operand corrsponding to this tree entry \p E for the user tree entry + /// indicated by \p UserIndx. + // In other words, "E == TreeEntry[UserIndx].getOperand(OpdNum)". + Value *vectorizeTree(TreeEntry *E, int OpdNum = 0, int UserIndx = -1); - /// Vectorize a single entry in the tree, starting in \p VL. - Value *vectorizeTree(ArrayRef<Value *> VL); + /// Vectorize a single entry in the tree, starting in \p VL.\p OpdNum indicate + /// the ordinality of operand corrsponding to the \p VL of scalar values for the + /// user indicated by \p UserIndx this \p VL feeds into. + Value *vectorizeTree(ArrayRef<Value *> VL, int OpdNum = 0, int UserIndx = -1); /// \returns the pointer to the vectorized value if \p VL is already /// vectorized, or NULL. They may happen in cycles. @@ -447,7 +679,7 @@ private: /// \brief Set the Builder insert point to one after the last instruction in /// the bundle - void setInsertPointAfterBundle(ArrayRef<Value *> VL); + void setInsertPointAfterBundle(ArrayRef<Value *> VL, Value *OpValue); /// \returns a vector from a collection of scalars in \p VL. Value *Gather(ArrayRef<Value *> VL, VectorType *Ty); @@ -458,18 +690,17 @@ private: /// \reorder commutative operands in alt shuffle if they result in /// vectorized code. - void reorderAltShuffleOperands(ArrayRef<Value *> VL, + void reorderAltShuffleOperands(unsigned Opcode, ArrayRef<Value *> VL, SmallVectorImpl<Value *> &Left, SmallVectorImpl<Value *> &Right); + /// \reorder commutative operands to get better probability of /// generating vectorized code. 
- void reorderInputsAccordingToOpcode(ArrayRef<Value *> VL, + void reorderInputsAccordingToOpcode(unsigned Opcode, ArrayRef<Value *> VL, SmallVectorImpl<Value *> &Left, SmallVectorImpl<Value *> &Right); struct TreeEntry { - TreeEntry(std::vector<TreeEntry> &Container) - : Scalars(), VectorizedValue(nullptr), NeedToGather(0), - Container(Container) {} + TreeEntry(std::vector<TreeEntry> &Container) : Container(Container) {} /// \returns true if the scalars in VL are equal to this entry. bool isSame(ArrayRef<Value *> VL) const { @@ -477,14 +708,32 @@ private: return std::equal(VL.begin(), VL.end(), Scalars.begin()); } + /// \returns true if the scalars in VL are found in this tree entry. + bool isFoundJumbled(ArrayRef<Value *> VL, const DataLayout &DL, + ScalarEvolution &SE) const { + assert(VL.size() == Scalars.size() && "Invalid size"); + SmallVector<Value *, 8> List; + if (!sortLoadAccesses(VL, DL, SE, List)) + return false; + return std::equal(List.begin(), List.end(), Scalars.begin()); + } + /// A vector of scalars. ValueList Scalars; /// The Scalars are vectorized into this value. It is initialized to Null. - Value *VectorizedValue; + Value *VectorizedValue = nullptr; /// Do we need to gather this sequence ? - bool NeedToGather; + bool NeedToGather = false; + + /// Records optional shuffle mask for the uses of jumbled memory accesses. + /// For example, a non-empty ShuffleMask[1] represents the permutation of + /// lanes that operand #1 of this vectorized instruction should undergo + /// before feeding this vectorized instruction, whereas an empty + /// ShuffleMask[0] indicates that the lanes of operand #0 of this vectorized + /// instruction need not be permuted at all. + SmallVector<SmallVector<unsigned, 4>, 2> ShuffleMask; /// Points back to the VectorizableTree. /// @@ -501,12 +750,31 @@ private: /// Create a new VectorizableTree entry. TreeEntry *newTreeEntry(ArrayRef<Value *> VL, bool Vectorized, - int &UserTreeIdx) { + int &UserTreeIdx, const InstructionsState &S, + ArrayRef<unsigned> ShuffleMask = None, + int OpdNum = 0) { + assert((!Vectorized || S.Opcode != 0) && + "Vectorized TreeEntry without opcode"); VectorizableTree.emplace_back(VectorizableTree); + int idx = VectorizableTree.size() - 1; TreeEntry *Last = &VectorizableTree[idx]; Last->Scalars.insert(Last->Scalars.begin(), VL.begin(), VL.end()); Last->NeedToGather = !Vectorized; + + TreeEntry *UserTreeEntry = nullptr; + if (UserTreeIdx != -1) + UserTreeEntry = &VectorizableTree[UserTreeIdx]; + + if (UserTreeEntry && !ShuffleMask.empty()) { + if ((unsigned)OpdNum >= UserTreeEntry->ShuffleMask.size()) + UserTreeEntry->ShuffleMask.resize(OpdNum + 1); + assert(UserTreeEntry->ShuffleMask[OpdNum].empty() && + "Mask already present"); + using mask = SmallVector<unsigned, 4>; + mask tempMask(ShuffleMask.begin(), ShuffleMask.end()); + UserTreeEntry->ShuffleMask[OpdNum] = tempMask; + } if (Vectorized) { for (int i = 0, e = VL.size(); i != e; ++i) { assert(!getTreeEntry(VL[i]) && "Scalar already in tree!"); @@ -548,16 +816,19 @@ private: /// This POD struct describes one external user in the vectorized tree. struct ExternalUser { - ExternalUser (Value *S, llvm::User *U, int L) : - Scalar(S), User(U), Lane(L){} + ExternalUser(Value *S, llvm::User *U, int L) + : Scalar(S), User(U), Lane(L) {} + // Which scalar in our function. Value *Scalar; + // Which user that uses the scalar. llvm::User *User; + // Which lane does the scalar belong to. 
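(A small hedged sketch of what the ShuffleMask recorded above means in practice: the user of a jumbled operand permutes that operand's lanes before consuming them. The orientation shown here, Result[I] = Lanes[Mask[I]], is one plausible convention for illustration; the exact orientation is whatever sortLoadAccesses produces.)

#include <cassert>
#include <vector>

template <typename T>
std::vector<T> applyLaneMask(const std::vector<T> &Lanes,
                             const std::vector<unsigned> &Mask) {
  assert(Lanes.size() == Mask.size() && "mask must cover every lane");
  std::vector<T> Result(Lanes.size());
  for (unsigned I = 0; I < Lanes.size(); ++I)
    Result[I] = Lanes[Mask[I]];
  return Result;
}
// With Lanes = {a, b, c, d} and Mask = {2, 0, 3, 1}, the permuted operand is
// {c, a, d, b}.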
int Lane; }; - typedef SmallVector<ExternalUser, 16> UserList; + using UserList = SmallVector<ExternalUser, 16>; /// Checks if two instructions may access the same memory. /// @@ -565,7 +836,6 @@ private: /// is invariant in the calling loop. bool isAliased(const MemoryLocation &Loc1, Instruction *Inst1, Instruction *Inst2) { - // First check if the result is already in the cache. AliasCacheKey key = std::make_pair(Inst1, Inst2); Optional<bool> &result = AliasCache[key]; @@ -583,7 +853,7 @@ private: return aliased; } - typedef std::pair<Instruction *, Instruction *> AliasCacheKey; + using AliasCacheKey = std::pair<Instruction *, Instruction *>; /// Cache for alias results. /// TODO: consider moving this to the AliasAnalysis itself. @@ -616,6 +886,7 @@ private: /// Holds all of the instructions that we gathered. SetVector<Instruction *> GatherSeq; + /// A list of blocks that we are going to CSE. SetVector<BasicBlock *> CSEBlocks; @@ -624,18 +895,13 @@ private: /// instruction bundle (= a group of instructions which is combined into a /// vector instruction). struct ScheduleData { - // The initial value for the dependency counters. It means that the // dependencies are not calculated yet. enum { InvalidDeps = -1 }; - ScheduleData() - : Inst(nullptr), FirstInBundle(nullptr), NextInBundle(nullptr), - NextLoadStore(nullptr), SchedulingRegionID(0), SchedulingPriority(0), - Dependencies(InvalidDeps), UnscheduledDeps(InvalidDeps), - UnscheduledDepsInBundle(InvalidDeps), IsScheduled(false) {} + ScheduleData() = default; - void init(int BlockSchedulingRegionID) { + void init(int BlockSchedulingRegionID, Value *OpVal) { FirstInBundle = this; NextInBundle = nullptr; NextLoadStore = nullptr; @@ -643,6 +909,7 @@ private: SchedulingRegionID = BlockSchedulingRegionID; UnscheduledDepsInBundle = UnscheduledDeps; clearDependencies(); + OpValue = OpVal; } /// Returns true if the dependency information has been calculated. @@ -702,19 +969,19 @@ private: } } - Instruction *Inst; + Instruction *Inst = nullptr; /// Points to the head in an instruction bundle (and always to this for /// single instructions). - ScheduleData *FirstInBundle; + ScheduleData *FirstInBundle = nullptr; /// Single linked list of all instructions in a bundle. Null if it is a /// single instruction. - ScheduleData *NextInBundle; + ScheduleData *NextInBundle = nullptr; /// Single linked list of all memory instructions (e.g. load, store, call) /// in the block - until the end of the scheduling region. - ScheduleData *NextLoadStore; + ScheduleData *NextLoadStore = nullptr; /// The dependent memory instructions. /// This list is derived on demand in calculateDependencies(). @@ -722,31 +989,33 @@ private: /// This ScheduleData is in the current scheduling region if this matches /// the current SchedulingRegionID of BlockScheduling. - int SchedulingRegionID; + int SchedulingRegionID = 0; /// Used for getting a "good" final ordering of instructions. - int SchedulingPriority; + int SchedulingPriority = 0; /// The number of dependencies. Constitutes of the number of users of the /// instruction plus the number of dependent memory instructions (if any). /// This value is calculated on demand. /// If InvalidDeps, the number of dependencies is not calculated yet. - /// - int Dependencies; + int Dependencies = InvalidDeps; /// The number of dependencies minus the number of dependencies of scheduled /// instructions. As soon as this is zero, the instruction/bundle gets ready /// for scheduling. 
/// Note that this is negative as long as Dependencies is not calculated. - int UnscheduledDeps; + int UnscheduledDeps = InvalidDeps; /// The sum of UnscheduledDeps in a bundle. Equals to UnscheduledDeps for /// single instructions. - int UnscheduledDepsInBundle; + int UnscheduledDepsInBundle = InvalidDeps; /// True if this instruction is scheduled (or considered as scheduled in the /// dry-run). - bool IsScheduled; + bool IsScheduled = false; + + /// Opcode of the current instruction in the schedule data. + Value *OpValue = nullptr; }; #ifndef NDEBUG @@ -756,22 +1025,14 @@ private: return os; } #endif + friend struct GraphTraits<BoUpSLP *>; friend struct DOTGraphTraits<BoUpSLP *>; /// Contains all scheduling data for a basic block. - /// struct BlockScheduling { - BlockScheduling(BasicBlock *BB) - : BB(BB), ChunkSize(BB->size()), ChunkPos(ChunkSize), - ScheduleStart(nullptr), ScheduleEnd(nullptr), - FirstLoadStoreInRegion(nullptr), LastLoadStoreInRegion(nullptr), - ScheduleRegionSize(0), - ScheduleRegionSizeLimit(ScheduleRegionSizeBudget), - // Make sure that the initial SchedulingRegionID is greater than the - // initial SchedulingRegionID in ScheduleData (which is 0). - SchedulingRegionID(1) {} + : BB(BB), ChunkSize(BB->size()), ChunkPos(ChunkSize) {} void clear() { ReadyInsts.clear(); @@ -799,6 +1060,18 @@ private: return nullptr; } + ScheduleData *getScheduleData(Value *V, Value *Key) { + if (V == Key) + return getScheduleData(V); + auto I = ExtraScheduleDataMap.find(V); + if (I != ExtraScheduleDataMap.end()) { + ScheduleData *SD = I->second[Key]; + if (SD && SD->SchedulingRegionID == SchedulingRegionID) + return SD; + } + return nullptr; + } + bool isInSchedulingRegion(ScheduleData *SD) { return SD->SchedulingRegionID == SchedulingRegionID; } @@ -812,19 +1085,29 @@ private: ScheduleData *BundleMember = SD; while (BundleMember) { + if (BundleMember->Inst != BundleMember->OpValue) { + BundleMember = BundleMember->NextInBundle; + continue; + } // Handle the def-use chain dependencies. for (Use &U : BundleMember->Inst->operands()) { - ScheduleData *OpDef = getScheduleData(U.get()); - if (OpDef && OpDef->hasValidDependencies() && - OpDef->incrementUnscheduledDeps(-1) == 0) { - // There are no more unscheduled dependencies after decrementing, - // so we can put the dependent instruction into the ready list. - ScheduleData *DepBundle = OpDef->FirstInBundle; - assert(!DepBundle->IsScheduled && - "already scheduled bundle gets ready"); - ReadyList.insert(DepBundle); - DEBUG(dbgs() << "SLP: gets ready (def): " << *DepBundle << "\n"); - } + auto *I = dyn_cast<Instruction>(U.get()); + if (!I) + continue; + doForAllOpcodes(I, [&ReadyList](ScheduleData *OpDef) { + if (OpDef && OpDef->hasValidDependencies() && + OpDef->incrementUnscheduledDeps(-1) == 0) { + // There are no more unscheduled dependencies after + // decrementing, so we can put the dependent instruction + // into the ready list. + ScheduleData *DepBundle = OpDef->FirstInBundle; + assert(!DepBundle->IsScheduled && + "already scheduled bundle gets ready"); + ReadyList.insert(DepBundle); + DEBUG(dbgs() + << "SLP: gets ready (def): " << *DepBundle << "\n"); + } + }); } // Handle the memory dependencies. 
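(The def-use and memory handling around here is the classic counting scheme for list scheduling; the following is a self-contained sketch of just that mechanism, with plain indices instead of ScheduleData and without bundles or scheduling regions.)

#include <queue>
#include <vector>

struct Node {
  int UnscheduledDeps = 0;     // dependencies not yet scheduled
  std::vector<int> Dependents; // nodes whose count drops when this schedules
};

std::vector<int> scheduleOrder(std::vector<Node> Nodes,
                               const std::vector<int> &InitiallyReady) {
  std::queue<int> Ready;
  for (int N : InitiallyReady)
    Ready.push(N);
  std::vector<int> Order;
  while (!Ready.empty()) {
    int N = Ready.front();
    Ready.pop();
    Order.push_back(N);
    for (int Dep : Nodes[N].Dependents)
      if (--Nodes[Dep].UnscheduledDeps == 0)
        Ready.push(Dep); // no unscheduled dependencies left: it gets ready
  }
  return Order;
}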
for (ScheduleData *MemoryDepSD : BundleMember->MemoryDependencies) { @@ -835,22 +1118,35 @@ private: assert(!DepBundle->IsScheduled && "already scheduled bundle gets ready"); ReadyList.insert(DepBundle); - DEBUG(dbgs() << "SLP: gets ready (mem): " << *DepBundle << "\n"); + DEBUG(dbgs() << "SLP: gets ready (mem): " << *DepBundle + << "\n"); } } BundleMember = BundleMember->NextInBundle; } } + void doForAllOpcodes(Value *V, + function_ref<void(ScheduleData *SD)> Action) { + if (ScheduleData *SD = getScheduleData(V)) + Action(SD); + auto I = ExtraScheduleDataMap.find(V); + if (I != ExtraScheduleDataMap.end()) + for (auto &P : I->second) + if (P.second->SchedulingRegionID == SchedulingRegionID) + Action(P.second); + } + /// Put all instructions into the ReadyList which are ready for scheduling. template <typename ReadyListType> void initialFillReadyList(ReadyListType &ReadyList) { for (auto *I = ScheduleStart; I != ScheduleEnd; I = I->getNextNode()) { - ScheduleData *SD = getScheduleData(I); - if (SD->isSchedulingEntity() && SD->isReady()) { - ReadyList.insert(SD); - DEBUG(dbgs() << "SLP: initially in ready list: " << *I << "\n"); - } + doForAllOpcodes(I, [&](ScheduleData *SD) { + if (SD->isSchedulingEntity() && SD->isReady()) { + ReadyList.insert(SD); + DEBUG(dbgs() << "SLP: initially in ready list: " << *I << "\n"); + } + }); } } @@ -862,9 +1158,12 @@ private: /// Un-bundles a group of instructions. void cancelScheduling(ArrayRef<Value *> VL, Value *OpValue); + /// Allocates schedule data chunk. + ScheduleData *allocateScheduleDataChunks(); + /// Extends the scheduling region so that V is inside the region. /// \returns true if the region size is within the limit. - bool extendSchedulingRegion(Value *V); + bool extendSchedulingRegion(Value *V, Value *OpValue); /// Initialize the ScheduleData structures for new instructions in the /// scheduling region. @@ -897,6 +1196,10 @@ private: /// ScheduleData structures are recycled. DenseMap<Value *, ScheduleData *> ScheduleDataMap; + /// Attaches ScheduleData to Instruction with the leading key. + DenseMap<Value *, SmallDenseMap<Value *, ScheduleData *>> + ExtraScheduleDataMap; + struct ReadyList : SmallVector<ScheduleData *, 8> { void insert(ScheduleData *SD) { push_back(SD); } }; @@ -905,28 +1208,30 @@ private: ReadyList ReadyInsts; /// The first instruction of the scheduling region. - Instruction *ScheduleStart; + Instruction *ScheduleStart = nullptr; /// The first instruction _after_ the scheduling region. - Instruction *ScheduleEnd; + Instruction *ScheduleEnd = nullptr; /// The first memory accessing instruction in the scheduling region /// (can be null). - ScheduleData *FirstLoadStoreInRegion; + ScheduleData *FirstLoadStoreInRegion = nullptr; /// The last memory accessing instruction in the scheduling region /// (can be null). - ScheduleData *LastLoadStoreInRegion; + ScheduleData *LastLoadStoreInRegion = nullptr; /// The current size of the scheduling region. - int ScheduleRegionSize; + int ScheduleRegionSize = 0; /// The maximum size allowed for the scheduling region. - int ScheduleRegionSizeLimit; + int ScheduleRegionSizeLimit = ScheduleRegionSizeBudget; /// The ID of the scheduling region. For a new vectorization iteration this /// is incremented which "removes" all ScheduleData from the region. - int SchedulingRegionID; + // Make sure that the initial SchedulingRegionID is greater than the + // initial SchedulingRegionID in ScheduleData (which is 0). 
+ int SchedulingRegionID = 1; }; /// Attaches the BlockScheduling structures to basic blocks. @@ -940,10 +1245,10 @@ private: ArrayRef<Value *> UserIgnoreList; // Number of load bundles that contain consecutive loads. - int NumLoadsWantToKeepOrder; + int NumLoadsWantToKeepOrder = 0; // Number of load bundles that contain consecutive loads in reversed order. - int NumLoadsWantToChangeOrder; + int NumLoadsWantToChangeOrder = 0; // Analysis and block reference. Function *F; @@ -960,6 +1265,7 @@ private: unsigned MaxVecRegSize; // This is set by TTI or overridden by cl::opt. unsigned MinVecRegSize; // Set by cl::opt (default: 128). + /// Instruction builder to construct the vectorized tree. IRBuilder<> Builder; @@ -970,20 +1276,20 @@ private: /// original width. MapVector<Value *, std::pair<uint64_t, bool>> MinBWs; }; + } // end namespace slpvectorizer template <> struct GraphTraits<BoUpSLP *> { - typedef BoUpSLP::TreeEntry TreeEntry; + using TreeEntry = BoUpSLP::TreeEntry; /// NodeRef has to be a pointer per the GraphWriter. - typedef TreeEntry *NodeRef; + using NodeRef = TreeEntry *; /// \brief Add the VectorizableTree to the index iterator to be able to return /// TreeEntry pointers. struct ChildIteratorType : public iterator_adaptor_base<ChildIteratorType, SmallVector<int, 1>::iterator> { - std::vector<TreeEntry> &VectorizableTree; ChildIteratorType(SmallVector<int, 1>::iterator W, @@ -998,17 +1304,19 @@ template <> struct GraphTraits<BoUpSLP *> { static ChildIteratorType child_begin(NodeRef N) { return {N->UserTreeIndices.begin(), N->Container}; } + static ChildIteratorType child_end(NodeRef N) { return {N->UserTreeIndices.end(), N->Container}; } /// For the node iterator we just need to turn the TreeEntry iterator into a /// TreeEntry* iterator so that it dereferences to NodeRef. - typedef pointer_iterator<std::vector<TreeEntry>::iterator> nodes_iterator; + using nodes_iterator = pointer_iterator<std::vector<TreeEntry>::iterator>; static nodes_iterator nodes_begin(BoUpSLP *R) { return nodes_iterator(R->VectorizableTree.begin()); } + static nodes_iterator nodes_end(BoUpSLP *R) { return nodes_iterator(R->VectorizableTree.end()); } @@ -1017,7 +1325,7 @@ template <> struct GraphTraits<BoUpSLP *> { }; template <> struct DOTGraphTraits<BoUpSLP *> : public DefaultDOTGraphTraits { - typedef BoUpSLP::TreeEntry TreeEntry; + using TreeEntry = BoUpSLP::TreeEntry; DOTGraphTraits(bool isSimple = false) : DefaultDOTGraphTraits(isSimple) {} @@ -1054,6 +1362,7 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, ExtraValueToDebugLocsMap ExternallyUsedValues; buildTree(Roots, ExternallyUsedValues, UserIgnoreLst); } + void BoUpSLP::buildTree(ArrayRef<Value *> Roots, ExtraValueToDebugLocsMap &ExternallyUsedValues, ArrayRef<Value *> UserIgnoreLst) { @@ -1118,44 +1427,34 @@ void BoUpSLP::buildTree(ArrayRef<Value *> Roots, } void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, - int UserTreeIdx) { - bool isAltShuffle = false; + int UserTreeIdx, int OpdNum) { assert((allConstant(VL) || allSameType(VL)) && "Invalid types!"); + InstructionsState S = getSameOpcode(VL); if (Depth == RecursionMaxDepth) { DEBUG(dbgs() << "SLP: Gathering due to max recursion depth.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } // Don't handle vectors. 
- if (VL[0]->getType()->isVectorTy()) { + if (S.OpValue->getType()->isVectorTy()) { DEBUG(dbgs() << "SLP: Gathering due to vector type.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } - if (StoreInst *SI = dyn_cast<StoreInst>(VL[0])) + if (StoreInst *SI = dyn_cast<StoreInst>(S.OpValue)) if (SI->getValueOperand()->getType()->isVectorTy()) { DEBUG(dbgs() << "SLP: Gathering due to store vector type.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } - unsigned Opcode = getSameOpcode(VL); - - // Check that this shuffle vector refers to the alternate - // sequence of opcodes. - if (Opcode == Instruction::ShuffleVector) { - Instruction *I0 = dyn_cast<Instruction>(VL[0]); - unsigned Op = I0->getOpcode(); - if (Op != Instruction::ShuffleVector) - isAltShuffle = true; - } // If all of the operands are identical or constant we have a simple solution. - if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !Opcode) { + if (allConstant(VL) || isSplat(VL) || !allSameBlock(VL) || !S.Opcode) { DEBUG(dbgs() << "SLP: Gathering due to C,S,B,O. \n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } @@ -1167,87 +1466,92 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (EphValues.count(VL[i])) { DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] << ") is ephemeral.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } } // Check if this is a duplicate of another entry. - if (TreeEntry *E = getTreeEntry(VL[0])) { + if (TreeEntry *E = getTreeEntry(S.OpValue)) { for (unsigned i = 0, e = VL.size(); i != e; ++i) { DEBUG(dbgs() << "SLP: \tChecking bundle: " << *VL[i] << ".\n"); if (E->Scalars[i] != VL[i]) { DEBUG(dbgs() << "SLP: Gathering due to partial overlap.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } } // Record the reuse of the tree node. FIXME, currently this is only used to // properly draw the graph rather than for the actual vectorization. E->UserTreeIndices.push_back(UserTreeIdx); - DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *VL[0] << ".\n"); + DEBUG(dbgs() << "SLP: Perfect diamond merge at " << *S.OpValue << ".\n"); return; } // Check that none of the instructions in the bundle are already in the tree. for (unsigned i = 0, e = VL.size(); i != e; ++i) { - if (ScalarToTreeEntry.count(VL[i])) { + auto *I = dyn_cast<Instruction>(VL[i]); + if (!I) + continue; + if (getTreeEntry(I)) { DEBUG(dbgs() << "SLP: The instruction (" << *VL[i] << ") is already in tree.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } } - // If any of the scalars is marked as a value that needs to stay scalar then + // If any of the scalars is marked as a value that needs to stay scalar, then // we need to gather the scalars. for (unsigned i = 0, e = VL.size(); i != e; ++i) { if (MustGather.count(VL[i])) { DEBUG(dbgs() << "SLP: Gathering due to gathered scalar.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } } // Check that all of the users of the scalars that we want to vectorize are // schedulable. - Instruction *VL0 = cast<Instruction>(VL[0]); + auto *VL0 = cast<Instruction>(S.OpValue); BasicBlock *BB = VL0->getParent(); if (!DT->isReachableFromEntry(BB)) { // Don't go into unreachable blocks. 
They may contain instructions with // dependency cycles which confuse the final scheduling. DEBUG(dbgs() << "SLP: bundle in unreachable block.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } - // Check that every instructions appears once in this bundle. + // Check that every instruction appears once in this bundle. for (unsigned i = 0, e = VL.size(); i < e; ++i) - for (unsigned j = i+1; j < e; ++j) + for (unsigned j = i + 1; j < e; ++j) if (VL[i] == VL[j]) { DEBUG(dbgs() << "SLP: Scalar used twice in bundle.\n"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } auto &BSRef = BlocksSchedules[BB]; - if (!BSRef) { + if (!BSRef) BSRef = llvm::make_unique<BlockScheduling>(BB); - } + BlockScheduling &BS = *BSRef.get(); - if (!BS.tryScheduleBundle(VL, this, VL0)) { + if (!BS.tryScheduleBundle(VL, this, S.OpValue)) { DEBUG(dbgs() << "SLP: We are not able to schedule this bundle!\n"); - assert((!BS.getScheduleData(VL[0]) || - !BS.getScheduleData(VL[0])->isPartOfBundle()) && + assert((!BS.getScheduleData(VL0) || + !BS.getScheduleData(VL0)->isPartOfBundle()) && "tryScheduleBundle should cancelScheduling on failure"); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } DEBUG(dbgs() << "SLP: We are able to schedule this bundle.\n"); - switch (Opcode) { + unsigned ShuffleOrOp = S.IsAltShuffle ? + (unsigned) Instruction::ShuffleVector : S.Opcode; + switch (ShuffleOrOp) { case Instruction::PHI: { PHINode *PH = dyn_cast<PHINode>(VL0); @@ -1259,12 +1563,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (Term) { DEBUG(dbgs() << "SLP: Need to swizzle PHINodes (TerminatorInst use).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } } - newTreeEntry(VL, true, UserTreeIdx); + newTreeEntry(VL, true, UserTreeIdx, S); DEBUG(dbgs() << "SLP: added a vector of PHINodes.\n"); for (unsigned i = 0, e = PH->getNumIncomingValues(); i < e; ++i) { @@ -1274,35 +1578,34 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Operands.push_back(cast<PHINode>(j)->getIncomingValueForBlock( PH->getIncomingBlock(i))); - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, UserTreeIdx, i); } return; } case Instruction::ExtractValue: case Instruction::ExtractElement: { - bool Reuse = canReuseExtract(VL, Opcode); + bool Reuse = canReuseExtract(VL, VL0); if (Reuse) { DEBUG(dbgs() << "SLP: Reusing extract sequence.\n"); } else { BS.cancelScheduling(VL, VL0); } - newTreeEntry(VL, Reuse, UserTreeIdx); + newTreeEntry(VL, Reuse, UserTreeIdx, S); return; } case Instruction::Load: { // Check that a vectorized load would load the same memory as a scalar - // load. - // For example we don't want vectorize loads that are smaller than 8 bit. - // Even though we have a packed struct {<i2, i2, i2, i2>} LLVM treats - // loading/storing it as an i8 struct. If we vectorize loads/stores from - // such a struct we read/write packed bits disagreeing with the + // load. For example, we don't want to vectorize loads that are smaller + // than 8-bit. Even though we have a packed struct {<i2, i2, i2, i2>} LLVM + // treats loading/storing it as an i8 struct. If we vectorize loads/stores + // from such a struct, we read/write packed bits disagreeing with the // unvectorized version. 
- Type *ScalarTy = VL[0]->getType(); + Type *ScalarTy = VL0->getType(); if (DL->getTypeSizeInBits(ScalarTy) != DL->getTypeAllocSizeInBits(ScalarTy)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); DEBUG(dbgs() << "SLP: Gathering loads of non-packed type.\n"); return; } @@ -1313,15 +1616,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, LoadInst *L = cast<LoadInst>(VL[i]); if (!L->isSimple()) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); DEBUG(dbgs() << "SLP: Gathering non-simple loads.\n"); return; } } // Check if the loads are consecutive, reversed, or neither. - // TODO: What we really want is to sort the loads, but for now, check - // the two likely directions. bool Consecutive = true; bool ReverseConsecutive = true; for (unsigned i = 0, e = VL.size() - 1; i < e; ++i) { @@ -1335,7 +1636,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, if (Consecutive) { ++NumLoadsWantToKeepOrder; - newTreeEntry(VL, true, UserTreeIdx); + newTreeEntry(VL, true, UserTreeIdx, S); DEBUG(dbgs() << "SLP: added a vector of loads.\n"); return; } @@ -1349,15 +1650,41 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, break; } - BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); - if (ReverseConsecutive) { - ++NumLoadsWantToChangeOrder; DEBUG(dbgs() << "SLP: Gathering reversed loads.\n"); - } else { - DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n"); + ++NumLoadsWantToChangeOrder; + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx, S); + return; + } + + if (VL.size() > 2) { + bool ShuffledLoads = true; + SmallVector<Value *, 8> Sorted; + SmallVector<unsigned, 4> Mask; + if (sortLoadAccesses(VL, *DL, *SE, Sorted, &Mask)) { + auto NewVL = makeArrayRef(Sorted.begin(), Sorted.end()); + for (unsigned i = 0, e = NewVL.size() - 1; i < e; ++i) { + if (!isConsecutiveAccess(NewVL[i], NewVL[i + 1], *DL, *SE)) { + ShuffledLoads = false; + break; + } + } + // TODO: Tracking how many load wants to have arbitrary shuffled order + // would be usefull. 
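(Before the ShuffledLoads check below, a hedged standalone sketch of the idea: sort the accesses by offset, remember the permutation as a mask, and only treat the bundle as a jumbled vector load when the sorted accesses are consecutive. Offsets and element size stand in for the pointer reasoning done by sortLoadAccesses and isConsecutiveAccess, and the mask orientation here is illustrative.)

#include <algorithm>
#include <numeric>
#include <vector>

bool sortLoadsByOffset(const std::vector<long> &Offsets, long ElemSize,
                       std::vector<unsigned> &Mask) {
  Mask.resize(Offsets.size());
  std::iota(Mask.begin(), Mask.end(), 0u);
  std::sort(Mask.begin(), Mask.end(),
            [&](unsigned A, unsigned B) { return Offsets[A] < Offsets[B]; });
  // Accept only if the sorted accesses form one consecutive run.
  for (unsigned I = 0; I + 1 < Mask.size(); ++I)
    if (Offsets[Mask[I + 1]] - Offsets[Mask[I]] != ElemSize)
      return false;
  return true;
}
// Offsets {8, 0, 12, 4} with ElemSize 4 give Mask {1, 3, 0, 2}: reading the
// lanes in that order covers a single consecutive 16-byte block.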
+ if (ShuffledLoads) { + DEBUG(dbgs() << "SLP: added a vector of loads which needs " + "permutation of loaded lanes.\n"); + newTreeEntry(NewVL, true, UserTreeIdx, S, + makeArrayRef(Mask.begin(), Mask.end()), OpdNum); + return; + } + } } + + DEBUG(dbgs() << "SLP: Gathering non-consecutive loads.\n"); + BS.cancelScheduling(VL, VL0); + newTreeEntry(VL, false, UserTreeIdx, S); return; } case Instruction::ZExt: @@ -1377,12 +1704,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Type *Ty = cast<Instruction>(VL[i])->getOperand(0)->getType(); if (Ty != SrcTy || !isValidElementType(Ty)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); DEBUG(dbgs() << "SLP: Gathering casts with different src types.\n"); return; } } - newTreeEntry(VL, true, UserTreeIdx); + newTreeEntry(VL, true, UserTreeIdx, S); DEBUG(dbgs() << "SLP: added a vector of casts.\n"); for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { @@ -1391,7 +1718,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, UserTreeIdx, i); } return; } @@ -1399,19 +1726,19 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::FCmp: { // Check that all of the compares have the same predicate. CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); - Type *ComparedTy = cast<Instruction>(VL[0])->getOperand(0)->getType(); + Type *ComparedTy = VL0->getOperand(0)->getType(); for (unsigned i = 1, e = VL.size(); i < e; ++i) { CmpInst *Cmp = cast<CmpInst>(VL[i]); if (Cmp->getPredicate() != P0 || Cmp->getOperand(0)->getType() != ComparedTy) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); DEBUG(dbgs() << "SLP: Gathering cmp with different predicate.\n"); return; } } - newTreeEntry(VL, true, UserTreeIdx); + newTreeEntry(VL, true, UserTreeIdx, S); DEBUG(dbgs() << "SLP: added a vector of compares.\n"); for (unsigned i = 0, e = VL0->getNumOperands(); i < e; ++i) { @@ -1420,7 +1747,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, UserTreeIdx, i); } return; } @@ -1442,17 +1769,17 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, case Instruction::AShr: case Instruction::And: case Instruction::Or: - case Instruction::Xor: { - newTreeEntry(VL, true, UserTreeIdx); + case Instruction::Xor: + newTreeEntry(VL, true, UserTreeIdx, S); DEBUG(dbgs() << "SLP: added a vector of bin op.\n"); // Sort operands of the instructions so that each side is more likely to // have the same opcode. 
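The cast and compare cases above only vectorize uniform bundles: every cast must share the source type of lane 0, and every compare must share lane 0's predicate and compared operand type. A small standalone sketch of the compare gate, with predicates and types reduced to an enum and a string:

#include <string>
#include <vector>

enum Predicate { ICMP_SLT, ICMP_SGT, ICMP_EQ };

struct CmpDesc {
  Predicate Pred;
  std::string OperandTy; // stand-in for getOperand(0)->getType()
};

// Vectorizable only if every compare matches lane 0's predicate and type.
static bool allComparesUniform(const std::vector<CmpDesc> &VL) {
  for (size_t I = 1; I < VL.size(); ++I)
    if (VL[I].Pred != VL[0].Pred || VL[I].OperandTy != VL[0].OperandTy)
      return false; // mismatch: cancel scheduling and gather
  return true;
}

int main() {
  std::vector<CmpDesc> Good = {{ICMP_SLT, "i32"}, {ICMP_SLT, "i32"}};
  std::vector<CmpDesc> Bad  = {{ICMP_SLT, "i32"}, {ICMP_SGT, "i32"}};
  return allComparesUniform(Good) && !allComparesUniform(Bad) ? 0 : 1;
}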
if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) { ValueList Left, Right; - reorderInputsAccordingToOpcode(VL, Left, Right); + reorderInputsAccordingToOpcode(S.Opcode, VL, Left, Right); buildTree_rec(Left, Depth + 1, UserTreeIdx); - buildTree_rec(Right, Depth + 1, UserTreeIdx); + buildTree_rec(Right, Depth + 1, UserTreeIdx, 1); return; } @@ -1462,30 +1789,30 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, UserTreeIdx, i); } return; - } + case Instruction::GetElementPtr: { // We don't combine GEPs with complicated (nested) indexing. for (unsigned j = 0; j < VL.size(); ++j) { if (cast<Instruction>(VL[j])->getNumOperands() != 2) { DEBUG(dbgs() << "SLP: not-vectorizable GEP (nested indexes).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } } // We can't combine several GEPs into one vector if they operate on // different types. - Type *Ty0 = cast<Instruction>(VL0)->getOperand(0)->getType(); + Type *Ty0 = VL0->getOperand(0)->getType(); for (unsigned j = 0; j < VL.size(); ++j) { Type *CurTy = cast<Instruction>(VL[j])->getOperand(0)->getType(); if (Ty0 != CurTy) { DEBUG(dbgs() << "SLP: not-vectorizable GEP (different types).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } } @@ -1497,12 +1824,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, DEBUG( dbgs() << "SLP: not-vectorizable GEP (non-constant indexes).\n"); BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); return; } } - newTreeEntry(VL, true, UserTreeIdx); + newTreeEntry(VL, true, UserTreeIdx, S); DEBUG(dbgs() << "SLP: added a vector of GEPs.\n"); for (unsigned i = 0, e = 2; i < e; ++i) { ValueList Operands; @@ -1510,7 +1837,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, UserTreeIdx, i); } return; } @@ -1519,12 +1846,12 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (unsigned i = 0, e = VL.size() - 1; i < e; ++i) if (!isConsecutiveAccess(VL[i], VL[i + 1], *DL, *SE)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); DEBUG(dbgs() << "SLP: Non-consecutive store.\n"); return; } - newTreeEntry(VL, true, UserTreeIdx); + newTreeEntry(VL, true, UserTreeIdx, S); DEBUG(dbgs() << "SLP: added a vector of stores.\n"); ValueList Operands; @@ -1536,13 +1863,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, } case Instruction::Call: { // Check if the calls are all to the same vectorizable intrinsic. 
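The GetElementPtr case above keeps only the simplest GEPs: exactly two operands (no nested indexing), the same operand type across all lanes, and a constant index in every lane; anything else is gathered. A standalone sketch of those three gates over a plain record type (GEPDesc is invented for illustration):

#include <string>
#include <vector>

struct GEPDesc {
  unsigned NumOperands;     // pointer operand plus indices
  std::string PtrOperandTy; // stand-in for getOperand(0)->getType()
  bool IndexIsConstant;     // the single index is a ConstantInt
};

static bool gepsVectorizable(const std::vector<GEPDesc> &VL) {
  if (VL.empty())
    return false;
  for (const GEPDesc &G : VL)
    if (G.NumOperands != 2) // nested indexing: gather
      return false;
  for (const GEPDesc &G : VL)
    if (G.PtrOperandTy != VL[0].PtrOperandTy) // mixed types: gather
      return false;
  for (const GEPDesc &G : VL)
    if (!G.IndexIsConstant) // non-constant index: gather
      return false;
  return true;
}

int main() {
  std::vector<GEPDesc> VL = {{2, "float*", true}, {2, "float*", true}};
  return gepsVectorizable(VL) ? 0 : 1;
}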
- CallInst *CI = cast<CallInst>(VL[0]); + CallInst *CI = cast<CallInst>(VL0); // Check if this is an Intrinsic call or something that can be // represented by an intrinsic call Intrinsic::ID ID = getVectorIntrinsicIDForCall(CI, TLI); if (!isTriviallyVectorizable(ID)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); DEBUG(dbgs() << "SLP: Non-vectorizable call.\n"); return; } @@ -1556,7 +1883,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, getVectorIntrinsicIDForCall(CI2, TLI) != ID || !CI->hasIdenticalOperandBundleSchema(*CI2)) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); DEBUG(dbgs() << "SLP: mismatched calls:" << *CI << "!=" << *VL[i] << "\n"); return; @@ -1567,7 +1894,7 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, Value *A1J = CI2->getArgOperand(1); if (A1I != A1J) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); DEBUG(dbgs() << "SLP: mismatched arguments in call:" << *CI << " argument "<< A1I<<"!=" << A1J << "\n"); @@ -1580,14 +1907,14 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, CI->op_begin() + CI->getBundleOperandsEndIndex(), CI2->op_begin() + CI2->getBundleOperandsStartIndex())) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); DEBUG(dbgs() << "SLP: mismatched bundle operands in calls:" << *CI << "!=" << *VL[i] << '\n'); return; } } - newTreeEntry(VL, true, UserTreeIdx); + newTreeEntry(VL, true, UserTreeIdx, S); for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i) { ValueList Operands; // Prepare the operand vector. @@ -1595,28 +1922,28 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, CallInst *CI2 = dyn_cast<CallInst>(j); Operands.push_back(CI2->getArgOperand(i)); } - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, UserTreeIdx, i); } return; } - case Instruction::ShuffleVector: { + case Instruction::ShuffleVector: // If this is not an alternate sequence of opcode like add-sub // then do not vectorize this instruction. - if (!isAltShuffle) { + if (!S.IsAltShuffle) { BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); DEBUG(dbgs() << "SLP: ShuffleVector are not vectorized.\n"); return; } - newTreeEntry(VL, true, UserTreeIdx); + newTreeEntry(VL, true, UserTreeIdx, S); DEBUG(dbgs() << "SLP: added a ShuffleVector op.\n"); // Reorder operands if reordering would enable vectorization. 
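canReuseExtract, rewritten above to key off OpValue, accepts an extract bundle only when every lane pulls element i from one and the same source vector, so the source vector can simply be reused. A standalone sketch with source vectors reduced to integer ids:

#include <vector>

struct ExtractDesc {
  int SourceVecId; // stand-in for getOperand(0), the vector extracted from
  int Index;       // the constant extract index
};

static bool canReuseExtracts(const std::vector<ExtractDesc> &VL,
                             unsigned SourceVecWidth) {
  if (VL.empty() || VL.size() != SourceVecWidth)
    return false; // bundle must cover the whole source vector
  for (size_t I = 0; I < VL.size(); ++I) {
    if (VL[I].SourceVecId != VL[0].SourceVecId)
      return false; // lanes extract from different vectors
    if (VL[I].Index != (int)I)
      return false; // lane I must read element I
  }
  return true;
}

int main() {
  std::vector<ExtractDesc> VL = {{7, 0}, {7, 1}, {7, 2}, {7, 3}};
  return canReuseExtracts(VL, 4) ? 0 : 1;
}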
if (isa<BinaryOperator>(VL0)) { ValueList Left, Right; - reorderAltShuffleOperands(VL, Left, Right); + reorderAltShuffleOperands(S.Opcode, VL, Left, Right); buildTree_rec(Left, Depth + 1, UserTreeIdx); - buildTree_rec(Right, Depth + 1, UserTreeIdx); + buildTree_rec(Right, Depth + 1, UserTreeIdx, 1); return; } @@ -1626,13 +1953,13 @@ void BoUpSLP::buildTree_rec(ArrayRef<Value *> VL, unsigned Depth, for (Value *j : VL) Operands.push_back(cast<Instruction>(j)->getOperand(i)); - buildTree_rec(Operands, Depth + 1, UserTreeIdx); + buildTree_rec(Operands, Depth + 1, UserTreeIdx, i); } return; - } + default: BS.cancelScheduling(VL, VL0); - newTreeEntry(VL, false, UserTreeIdx); + newTreeEntry(VL, false, UserTreeIdx, S); DEBUG(dbgs() << "SLP: Gathering unknown instruction.\n"); return; } @@ -1663,19 +1990,18 @@ unsigned BoUpSLP::canMapToVector(Type *T, const DataLayout &DL) const { return N; } -bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, unsigned Opcode) const { - assert(Opcode == Instruction::ExtractElement || - Opcode == Instruction::ExtractValue); - assert(Opcode == getSameOpcode(VL) && "Invalid opcode"); +bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, Value *OpValue) const { + Instruction *E0 = cast<Instruction>(OpValue); + assert(E0->getOpcode() == Instruction::ExtractElement || + E0->getOpcode() == Instruction::ExtractValue); + assert(E0->getOpcode() == getSameOpcode(VL).Opcode && "Invalid opcode"); // Check if all of the extracts come from the same vector and from the // correct offset. - Value *VL0 = VL[0]; - Instruction *E0 = cast<Instruction>(VL0); Value *Vec = E0->getOperand(0); // We have to extract from a vector/aggregate with the same number of elements. unsigned NElts; - if (Opcode == Instruction::ExtractValue) { + if (E0->getOpcode() == Instruction::ExtractValue) { const DataLayout &DL = E0->getModule()->getDataLayout(); NElts = canMapToVector(Vec->getType(), DL); if (!NElts) @@ -1692,20 +2018,24 @@ bool BoUpSLP::canReuseExtract(ArrayRef<Value *> VL, unsigned Opcode) const { return false; // Check that all of the indices extract from the correct offset. - if (!matchExtractIndex(E0, 0, Opcode)) - return false; - - for (unsigned i = 1, e = VL.size(); i < e; ++i) { - Instruction *E = cast<Instruction>(VL[i]); - if (!matchExtractIndex(E, i, Opcode)) + for (unsigned I = 0, E = VL.size(); I < E; ++I) { + Instruction *Inst = cast<Instruction>(VL[I]); + if (!matchExtractIndex(Inst, I, Inst->getOpcode())) return false; - if (E->getOperand(0) != Vec) + if (Inst->getOperand(0) != Vec) return false; } return true; } +bool BoUpSLP::areAllUsersVectorized(Instruction *I) const { + return I->hasOneUse() || + std::all_of(I->user_begin(), I->user_end(), [this](User *U) { + return ScalarToTreeEntry.count(U) > 0; + }); +} + int BoUpSLP::getEntryCost(TreeEntry *E) { ArrayRef<Value*> VL = E->Scalars; @@ -1728,28 +2058,47 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { if (isSplat(VL)) { return TTI->getShuffleCost(TargetTransformInfo::SK_Broadcast, VecTy, 0); } + if (getSameOpcode(VL).Opcode == Instruction::ExtractElement) { + Optional<TargetTransformInfo::ShuffleKind> ShuffleKind = isShuffle(VL); + if (ShuffleKind.hasValue()) { + int Cost = TTI->getShuffleCost(ShuffleKind.getValue(), VecTy); + for (auto *V : VL) { + // If all users of instruction are going to be vectorized and this + // instruction itself is not going to be vectorized, consider this + // instruction as dead and remove its cost from the final cost of the + // vectorized tree. 
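The cost model above gives credit for extracts that become dead: if every user of an extractelement is itself being vectorized and the extract is not already part of the tree, its scalar cost is subtracted (via DeadCost for reusable extract bundles, and in the new shuffle-based gather costing that continues just below). A tiny standalone arithmetic sketch with invented per-lane costs:

#include <cstdio>
#include <vector>

struct ExtractLane {
  int ExtractCost;         // invented TTI-style cost of the scalar extract
  bool AllUsersVectorized; // stand-in for areAllUsersVectorized()
  bool AlreadyInTree;      // stand-in for ScalarToTreeEntry.count()
};

static int gatherCostWithDeadExtractCredit(int ShuffleCost,
                                           const std::vector<ExtractLane> &VL) {
  int Cost = ShuffleCost;
  for (const ExtractLane &L : VL)
    if (L.AllUsersVectorized && !L.AlreadyInTree)
      Cost -= L.ExtractCost; // the scalar extract becomes dead; take credit
  return Cost;
}

int main() {
  std::vector<ExtractLane> VL = {{1, true, false}, {1, true, false},
                                 {1, false, false}, {1, true, true}};
  std::printf("cost = %d\n", gatherCostWithDeadExtractCredit(4, VL)); // 4 - 2 = 2
}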
+ if (areAllUsersVectorized(cast<Instruction>(V)) && + !ScalarToTreeEntry.count(V)) { + auto *IO = cast<ConstantInt>( + cast<ExtractElementInst>(V)->getIndexOperand()); + Cost -= TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, + IO->getZExtValue()); + } + } + return Cost; + } + } return getGatherCost(E->Scalars); } - unsigned Opcode = getSameOpcode(VL); - assert(Opcode && allSameType(VL) && allSameBlock(VL) && "Invalid VL"); - Instruction *VL0 = cast<Instruction>(VL[0]); - switch (Opcode) { - case Instruction::PHI: { + InstructionsState S = getSameOpcode(VL); + assert(S.Opcode && allSameType(VL) && allSameBlock(VL) && "Invalid VL"); + Instruction *VL0 = cast<Instruction>(S.OpValue); + unsigned ShuffleOrOp = S.IsAltShuffle ? + (unsigned) Instruction::ShuffleVector : S.Opcode; + switch (ShuffleOrOp) { + case Instruction::PHI: return 0; - } + case Instruction::ExtractValue: - case Instruction::ExtractElement: { - if (canReuseExtract(VL, Opcode)) { + case Instruction::ExtractElement: + if (canReuseExtract(VL, S.OpValue)) { int DeadCost = 0; for (unsigned i = 0, e = VL.size(); i < e; ++i) { Instruction *E = cast<Instruction>(VL[i]); // If all users are going to be vectorized, instruction can be // considered as dead. // The same, if have only one user, it will be vectorized for sure. - if (E->hasOneUse() || - std::all_of(E->user_begin(), E->user_end(), [this](User *U) { - return ScalarToTreeEntry.count(U) > 0; - })) + if (areAllUsersVectorized(E)) // Take credit for instruction that will become dead. DeadCost += TTI->getVectorInstrCost(Instruction::ExtractElement, VecTy, i); @@ -1757,7 +2106,7 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { return -DeadCost; } return getGatherCost(VecTy); - } + case Instruction::ZExt: case Instruction::SExt: case Instruction::FPToUI: @@ -1786,8 +2135,8 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { // Calculate the cost of this instruction. VectorType *MaskTy = VectorType::get(Builder.getInt1Ty(), VL.size()); int ScalarCost = VecTy->getNumElements() * - TTI->getCmpSelInstrCost(Opcode, ScalarTy, Builder.getInt1Ty(), VL0); - int VecCost = TTI->getCmpSelInstrCost(Opcode, VecTy, MaskTy, VL0); + TTI->getCmpSelInstrCost(S.Opcode, ScalarTy, Builder.getInt1Ty(), VL0); + int VecCost = TTI->getCmpSelInstrCost(S.Opcode, VecTy, MaskTy, VL0); return VecCost - ScalarCost; } case Instruction::Add: @@ -1848,9 +2197,9 @@ int BoUpSLP::getEntryCost(TreeEntry *E) { SmallVector<const Value *, 4> Operands(VL0->operand_values()); int ScalarCost = VecTy->getNumElements() * - TTI->getArithmeticInstrCost(Opcode, ScalarTy, Op1VK, Op2VK, Op1VP, + TTI->getArithmeticInstrCost(S.Opcode, ScalarTy, Op1VK, Op2VK, Op1VP, Op2VP, Operands); - int VecCost = TTI->getArithmeticInstrCost(Opcode, VecTy, Op1VK, Op2VK, + int VecCost = TTI->getArithmeticInstrCost(S.Opcode, VecTy, Op1VK, Op2VK, Op1VP, Op2VP, Operands); return VecCost - ScalarCost; } @@ -1968,7 +2317,6 @@ bool BoUpSLP::isFullyVectorizableTinyTree() { } bool BoUpSLP::isTreeTinyAndNotFullyVectorizable() { - // We can vectorize the tree if its size is greater than or equal to the // minimum size specified by the MinTreeSize command line option. if (VectorizableTree.size() >= MinTreeSize) @@ -2139,13 +2487,18 @@ int BoUpSLP::getGatherCost(ArrayRef<Value *> VL) { // load a[3] + load b[3] // Reordering the second load b[1] load a[1] would allow us to vectorize this // code. 
-void BoUpSLP::reorderAltShuffleOperands(ArrayRef<Value *> VL, +void BoUpSLP::reorderAltShuffleOperands(unsigned Opcode, ArrayRef<Value *> VL, SmallVectorImpl<Value *> &Left, SmallVectorImpl<Value *> &Right) { // Push left and right operands of binary operation into Left and Right - for (Value *i : VL) { - Left.push_back(cast<Instruction>(i)->getOperand(0)); - Right.push_back(cast<Instruction>(i)->getOperand(1)); + unsigned AltOpcode = getAltOpcode(Opcode); + (void)AltOpcode; + for (Value *V : VL) { + auto *I = cast<Instruction>(V); + assert(sameOpcodeOrAlt(Opcode, AltOpcode, I->getOpcode()) && + "Incorrect instruction in vector"); + Left.push_back(I->getOperand(0)); + Right.push_back(I->getOperand(1)); } // Reorder if we have a commutative operation and consecutive access @@ -2190,14 +2543,12 @@ void BoUpSLP::reorderAltShuffleOperands(ArrayRef<Value *> VL, // The vectorizer is trying to either have all elements one side being // instruction with the same opcode to enable further vectorization, or having // a splat to lower the vectorizing cost. -static bool shouldReorderOperands(int i, Instruction &I, - SmallVectorImpl<Value *> &Left, - SmallVectorImpl<Value *> &Right, - bool AllSameOpcodeLeft, - bool AllSameOpcodeRight, bool SplatLeft, - bool SplatRight) { - Value *VLeft = I.getOperand(0); - Value *VRight = I.getOperand(1); +static bool shouldReorderOperands( + int i, unsigned Opcode, Instruction &I, ArrayRef<Value *> Left, + ArrayRef<Value *> Right, bool AllSameOpcodeLeft, bool AllSameOpcodeRight, + bool SplatLeft, bool SplatRight, Value *&VLeft, Value *&VRight) { + VLeft = I.getOperand(0); + VRight = I.getOperand(1); // If we have "SplatRight", try to see if commuting is needed to preserve it. if (SplatRight) { if (VRight == Right[i - 1]) @@ -2253,15 +2604,16 @@ static bool shouldReorderOperands(int i, Instruction &I, return false; } -void BoUpSLP::reorderInputsAccordingToOpcode(ArrayRef<Value *> VL, +void BoUpSLP::reorderInputsAccordingToOpcode(unsigned Opcode, + ArrayRef<Value *> VL, SmallVectorImpl<Value *> &Left, SmallVectorImpl<Value *> &Right) { - - if (VL.size()) { + if (!VL.empty()) { // Peel the first iteration out of the loop since there's nothing // interesting to do anyway and it simplifies the checks in the loop. - auto VLeft = cast<Instruction>(VL[0])->getOperand(0); - auto VRight = cast<Instruction>(VL[0])->getOperand(1); + auto *I = cast<Instruction>(VL[0]); + Value *VLeft = I->getOperand(0); + Value *VRight = I->getOperand(1); if (!isa<Instruction>(VRight) && isa<Instruction>(VLeft)) // Favor having instruction to the right. FIXME: why? std::swap(VLeft, VRight); @@ -2278,16 +2630,21 @@ void BoUpSLP::reorderInputsAccordingToOpcode(ArrayRef<Value *> VL, for (unsigned i = 1, e = VL.size(); i != e; ++i) { Instruction *I = cast<Instruction>(VL[i]); - assert(I->isCommutative() && "Can only process commutative instruction"); + assert(((I->getOpcode() == Opcode && I->isCommutative()) || + (I->getOpcode() != Opcode && Instruction::isCommutative(Opcode))) && + "Can only process commutative instruction"); // Commute to favor either a splat or maximizing having the same opcodes on // one side. 
- if (shouldReorderOperands(i, *I, Left, Right, AllSameOpcodeLeft, - AllSameOpcodeRight, SplatLeft, SplatRight)) { - Left.push_back(I->getOperand(1)); - Right.push_back(I->getOperand(0)); + Value *VLeft; + Value *VRight; + if (shouldReorderOperands(i, Opcode, *I, Left, Right, AllSameOpcodeLeft, + AllSameOpcodeRight, SplatLeft, SplatRight, VLeft, + VRight)) { + Left.push_back(VRight); + Right.push_back(VLeft); } else { - Left.push_back(I->getOperand(0)); - Right.push_back(I->getOperand(1)); + Left.push_back(VLeft); + Right.push_back(VRight); } // Update Splat* and AllSameOpcode* after the insertion. SplatRight = SplatRight && (Right[i - 1] == Right[i]); @@ -2340,14 +2697,17 @@ void BoUpSLP::reorderInputsAccordingToOpcode(ArrayRef<Value *> VL, } } -void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL) { - +void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL, Value *OpValue) { // Get the basic block this bundle is in. All instructions in the bundle // should be in this block. - auto *Front = cast<Instruction>(VL.front()); + auto *Front = cast<Instruction>(OpValue); auto *BB = Front->getParent(); - assert(all_of(make_range(VL.begin(), VL.end()), [&](Value *V) -> bool { - return cast<Instruction>(V)->getParent() == BB; + const unsigned Opcode = cast<Instruction>(OpValue)->getOpcode(); + const unsigned AltOpcode = getAltOpcode(Opcode); + assert(llvm::all_of(make_range(VL.begin(), VL.end()), [=](Value *V) -> bool { + return !sameOpcodeOrAlt(Opcode, AltOpcode, + cast<Instruction>(V)->getOpcode()) || + cast<Instruction>(V)->getParent() == BB; })); // The last instruction in the bundle in program order. @@ -2358,10 +2718,12 @@ void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL) { // VL.back() and iterate over schedule data until we reach the end of the // bundle. The end of the bundle is marked by null ScheduleData. 
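reorderInputsAccordingToOpcode and shouldReorderOperands, shown above, split each commutative instruction into left and right operands and may swap a lane so that one side keeps a splat or a run of the same opcode, which makes the operand bundles themselves easier to vectorize. A simplified standalone sketch of just the splat-preserving part; the same-opcode and load-consecutiveness heuristics of the real code are omitted:

#include <string>
#include <vector>

struct BinOp {
  std::string L, R; // operand names stand in for Value*
};

// Split operands into Left/Right, swapping a lane when that keeps the value
// seen in the previous right lane on the right-hand side (a "splat right").
static void reorderKeepingSplatRight(const std::vector<BinOp> &VL,
                                     std::vector<std::string> &Left,
                                     std::vector<std::string> &Right) {
  if (VL.empty())
    return;
  Left.push_back(VL[0].L);
  Right.push_back(VL[0].R);
  for (size_t I = 1; I < VL.size(); ++I) {
    bool SwapKeepsSplat = VL[I].L == Right[I - 1] && VL[I].R != Right[I - 1];
    if (SwapKeepsSplat) {
      Left.push_back(VL[I].R);
      Right.push_back(VL[I].L);
    } else {
      Left.push_back(VL[I].L);
      Right.push_back(VL[I].R);
    }
  }
}

int main() {
  // a+s, s+b, c+s: swapping lane 1 keeps the splat "s" on the right.
  std::vector<BinOp> VL = {{"a", "s"}, {"s", "b"}, {"c", "s"}};
  std::vector<std::string> Left, Right;
  reorderKeepingSplatRight(VL, Left, Right);
  return Right[1] == "s" ? 0 : 1; // expect "s" on the right in every lane
}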
if (BlocksSchedules.count(BB)) { - auto *Bundle = BlocksSchedules[BB]->getScheduleData(VL.back()); + auto *Bundle = + BlocksSchedules[BB]->getScheduleData(isOneOf(OpValue, VL.back())); if (Bundle && Bundle->isPartOfBundle()) for (; Bundle; Bundle = Bundle->NextInBundle) - LastInst = Bundle->Inst; + if (Bundle->OpValue == Bundle->Inst) + LastInst = Bundle->Inst; } // LastInst can still be null at this point if there's either not an entry @@ -2385,7 +2747,7 @@ void BoUpSLP::setInsertPointAfterBundle(ArrayRef<Value *> VL) { if (!LastInst) { SmallPtrSet<Value *, 16> Bundle(VL.begin(), VL.end()); for (auto &I : make_range(BasicBlock::iterator(Front), BB->end())) { - if (Bundle.erase(&I)) + if (Bundle.erase(&I) && sameOpcodeOrAlt(Opcode, AltOpcode, I.getOpcode())) LastInst = &I; if (Bundle.empty()) break; @@ -2435,27 +2797,41 @@ Value *BoUpSLP::alreadyVectorized(ArrayRef<Value *> VL, Value *OpValue) const { return nullptr; } -Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL) { - if (TreeEntry *E = getTreeEntry(VL[0])) - if (E->isSame(VL)) - return vectorizeTree(E); +Value *BoUpSLP::vectorizeTree(ArrayRef<Value *> VL, int OpdNum, int UserIndx) { + InstructionsState S = getSameOpcode(VL); + if (S.Opcode) { + if (TreeEntry *E = getTreeEntry(S.OpValue)) { + TreeEntry *UserTreeEntry = nullptr; + if (UserIndx != -1) + UserTreeEntry = &VectorizableTree[UserIndx]; + + if (E->isSame(VL) || + (UserTreeEntry && + (unsigned)OpdNum < UserTreeEntry->ShuffleMask.size() && + !UserTreeEntry->ShuffleMask[OpdNum].empty() && + E->isFoundJumbled(VL, *DL, *SE))) + return vectorizeTree(E, OpdNum, UserIndx); + } + } - Type *ScalarTy = VL[0]->getType(); - if (StoreInst *SI = dyn_cast<StoreInst>(VL[0])) + Type *ScalarTy = S.OpValue->getType(); + if (StoreInst *SI = dyn_cast<StoreInst>(S.OpValue)) ScalarTy = SI->getValueOperand()->getType(); VectorType *VecTy = VectorType::get(ScalarTy, VL.size()); return Gather(VL, VecTy); } -Value *BoUpSLP::vectorizeTree(TreeEntry *E) { +Value *BoUpSLP::vectorizeTree(TreeEntry *E, int OpdNum, int UserIndx) { IRBuilder<>::InsertPointGuard Guard(Builder); + TreeEntry *UserTreeEntry = nullptr; if (E->VectorizedValue) { DEBUG(dbgs() << "SLP: Diamond merged for " << *E->Scalars[0] << ".\n"); return E->VectorizedValue; } + InstructionsState S = getSameOpcode(E->Scalars); Instruction *VL0 = cast<Instruction>(E->Scalars[0]); Type *ScalarTy = VL0->getType(); if (StoreInst *SI = dyn_cast<StoreInst>(VL0)) @@ -2463,15 +2839,19 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { VectorType *VecTy = VectorType::get(ScalarTy, E->Scalars.size()); if (E->NeedToGather) { - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); auto *V = Gather(E->Scalars, VecTy); E->VectorizedValue = V; return V; } - unsigned Opcode = getSameOpcode(E->Scalars); + assert(ScalarToTreeEntry.count(E->Scalars[0]) && + "Expected user tree entry, missing!"); + int CurrIndx = ScalarToTreeEntry[E->Scalars[0]]; - switch (Opcode) { + unsigned ShuffleOrOp = S.IsAltShuffle ? 
+ (unsigned) Instruction::ShuffleVector : S.Opcode; + switch (ShuffleOrOp) { case Instruction::PHI: { PHINode *PH = dyn_cast<PHINode>(VL0); Builder.SetInsertPoint(PH->getParent()->getFirstNonPHI()); @@ -2498,7 +2878,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { Builder.SetInsertPoint(IBB->getTerminator()); Builder.SetCurrentDebugLocation(PH->getDebugLoc()); - Value *Vec = vectorizeTree(Operands); + Value *Vec = vectorizeTree(Operands, i, CurrIndx); NewPhi->addIncoming(Vec, IBB); } @@ -2508,18 +2888,18 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } case Instruction::ExtractElement: { - if (canReuseExtract(E->Scalars, Instruction::ExtractElement)) { + if (canReuseExtract(E->Scalars, VL0)) { Value *V = VL0->getOperand(0); E->VectorizedValue = V; return V; } - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); auto *V = Gather(E->Scalars, VecTy); E->VectorizedValue = V; return V; } case Instruction::ExtractValue: { - if (canReuseExtract(E->Scalars, Instruction::ExtractValue)) { + if (canReuseExtract(E->Scalars, VL0)) { LoadInst *LI = cast<LoadInst>(VL0->getOperand(0)); Builder.SetInsertPoint(LI); PointerType *PtrTy = PointerType::get(VecTy, LI->getPointerAddressSpace()); @@ -2528,7 +2908,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { E->VectorizedValue = V; return propagateMetadata(V, E->Scalars); } - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); auto *V = Gather(E->Scalars, VecTy); E->VectorizedValue = V; return V; @@ -2549,9 +2929,9 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { for (Value *V : E->Scalars) INVL.push_back(cast<Instruction>(V)->getOperand(0)); - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); - Value *InVec = vectorizeTree(INVL); + Value *InVec = vectorizeTree(INVL, 0, CurrIndx); if (Value *V = alreadyVectorized(E->Scalars, VL0)) return V; @@ -2570,23 +2950,23 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { RHSV.push_back(cast<Instruction>(V)->getOperand(1)); } - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); - Value *L = vectorizeTree(LHSV); - Value *R = vectorizeTree(RHSV); + Value *L = vectorizeTree(LHSV, 0, CurrIndx); + Value *R = vectorizeTree(RHSV, 1, CurrIndx); if (Value *V = alreadyVectorized(E->Scalars, VL0)) return V; CmpInst::Predicate P0 = cast<CmpInst>(VL0)->getPredicate(); Value *V; - if (Opcode == Instruction::FCmp) + if (S.Opcode == Instruction::FCmp) V = Builder.CreateFCmp(P0, L, R); else V = Builder.CreateICmp(P0, L, R); E->VectorizedValue = V; - propagateIRFlags(E->VectorizedValue, E->Scalars); + propagateIRFlags(E->VectorizedValue, E->Scalars, VL0); ++NumVectorInstructions; return V; } @@ -2598,11 +2978,11 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { FalseVec.push_back(cast<Instruction>(V)->getOperand(2)); } - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); - Value *Cond = vectorizeTree(CondVec); - Value *True = vectorizeTree(TrueVec); - Value *False = vectorizeTree(FalseVec); + Value *Cond = vectorizeTree(CondVec, 0, CurrIndx); + Value *True = vectorizeTree(TrueVec, 1, CurrIndx); + Value *False = vectorizeTree(FalseVec, 2, CurrIndx); if (Value *V = alreadyVectorized(E->Scalars, VL0)) return V; @@ -2632,25 +3012,27 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::Xor: { ValueList LHSVL, RHSVL; if (isa<BinaryOperator>(VL0) && VL0->isCommutative()) - reorderInputsAccordingToOpcode(E->Scalars, LHSVL, RHSVL); + 
reorderInputsAccordingToOpcode(S.Opcode, E->Scalars, LHSVL, + RHSVL); else for (Value *V : E->Scalars) { - LHSVL.push_back(cast<Instruction>(V)->getOperand(0)); - RHSVL.push_back(cast<Instruction>(V)->getOperand(1)); + auto *I = cast<Instruction>(V); + LHSVL.push_back(I->getOperand(0)); + RHSVL.push_back(I->getOperand(1)); } - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); - Value *LHS = vectorizeTree(LHSVL); - Value *RHS = vectorizeTree(RHSVL); + Value *LHS = vectorizeTree(LHSVL, 0, CurrIndx); + Value *RHS = vectorizeTree(RHSVL, 1, CurrIndx); if (Value *V = alreadyVectorized(E->Scalars, VL0)) return V; - BinaryOperator *BinOp = cast<BinaryOperator>(VL0); - Value *V = Builder.CreateBinOp(BinOp->getOpcode(), LHS, RHS); + Value *V = Builder.CreateBinOp( + static_cast<Instruction::BinaryOps>(S.Opcode), LHS, RHS); E->VectorizedValue = V; - propagateIRFlags(E->VectorizedValue, E->Scalars); + propagateIRFlags(E->VectorizedValue, E->Scalars, VL0); ++NumVectorInstructions; if (Instruction *I = dyn_cast<Instruction>(V)) @@ -2661,9 +3043,22 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { case Instruction::Load: { // Loads are inserted at the head of the tree because we don't want to // sink them all the way down past store instructions. - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); + + if (UserIndx != -1) + UserTreeEntry = &VectorizableTree[UserIndx]; + + bool isJumbled = false; + LoadInst *LI = NULL; + if (UserTreeEntry && + (unsigned)OpdNum < UserTreeEntry->ShuffleMask.size() && + !UserTreeEntry->ShuffleMask[OpdNum].empty()) { + isJumbled = true; + LI = cast<LoadInst>(E->Scalars[0]); + } else { + LI = cast<LoadInst>(VL0); + } - LoadInst *LI = cast<LoadInst>(VL0); Type *ScalarLoadTy = LI->getType(); unsigned AS = LI->getPointerAddressSpace(); @@ -2685,47 +3080,60 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { LI->setAlignment(Alignment); E->VectorizedValue = LI; ++NumVectorInstructions; - return propagateMetadata(LI, E->Scalars); + propagateMetadata(LI, E->Scalars); + + if (isJumbled) { + SmallVector<Constant *, 8> Mask; + for (unsigned LaneEntry : UserTreeEntry->ShuffleMask[OpdNum]) + Mask.push_back(Builder.getInt32(LaneEntry)); + // Generate shuffle for jumbled memory access + Value *Undef = UndefValue::get(VecTy); + Value *Shuf = Builder.CreateShuffleVector((Value *)LI, Undef, + ConstantVector::get(Mask)); + E->VectorizedValue = Shuf; + ++NumVectorInstructions; + return Shuf; + } + return LI; } case Instruction::Store: { StoreInst *SI = cast<StoreInst>(VL0); unsigned Alignment = SI->getAlignment(); unsigned AS = SI->getPointerAddressSpace(); - ValueList ValueOp; + ValueList ScalarStoreValues; for (Value *V : E->Scalars) - ValueOp.push_back(cast<StoreInst>(V)->getValueOperand()); + ScalarStoreValues.push_back(cast<StoreInst>(V)->getValueOperand()); - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); - Value *VecValue = vectorizeTree(ValueOp); - Value *VecPtr = Builder.CreateBitCast(SI->getPointerOperand(), - VecTy->getPointerTo(AS)); + Value *VecValue = vectorizeTree(ScalarStoreValues, 0, CurrIndx); + Value *ScalarPtr = SI->getPointerOperand(); + Value *VecPtr = Builder.CreateBitCast(ScalarPtr, VecTy->getPointerTo(AS)); StoreInst *S = Builder.CreateStore(VecValue, VecPtr); - // The pointer operand uses an in-tree scalar so we add the new BitCast to - // ExternalUses list to make sure that an extract will be generated in the + // The pointer operand uses an in-tree scalar, so add 
the new BitCast to + // ExternalUses to make sure that an extract will be generated in the // future. - Value *PO = SI->getPointerOperand(); - if (getTreeEntry(PO)) - ExternalUses.push_back(ExternalUser(PO, cast<User>(VecPtr), 0)); + if (getTreeEntry(ScalarPtr)) + ExternalUses.push_back(ExternalUser(ScalarPtr, cast<User>(VecPtr), 0)); - if (!Alignment) { + if (!Alignment) Alignment = DL->getABITypeAlignment(SI->getValueOperand()->getType()); - } + S->setAlignment(Alignment); E->VectorizedValue = S; ++NumVectorInstructions; return propagateMetadata(S, E->Scalars); } case Instruction::GetElementPtr: { - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); ValueList Op0VL; for (Value *V : E->Scalars) Op0VL.push_back(cast<GetElementPtrInst>(V)->getOperand(0)); - Value *Op0 = vectorizeTree(Op0VL); + Value *Op0 = vectorizeTree(Op0VL, 0, CurrIndx); std::vector<Value *> OpVecs; for (int j = 1, e = cast<GetElementPtrInst>(VL0)->getNumOperands(); j < e; @@ -2734,7 +3142,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { for (Value *V : E->Scalars) OpVL.push_back(cast<GetElementPtrInst>(V)->getOperand(j)); - Value *OpVec = vectorizeTree(OpVL); + Value *OpVec = vectorizeTree(OpVL, j, CurrIndx); OpVecs.push_back(OpVec); } @@ -2750,7 +3158,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { } case Instruction::Call: { CallInst *CI = cast<CallInst>(VL0); - setInsertPointAfterBundle(E->Scalars); + setInsertPointAfterBundle(E->Scalars, VL0); Function *FI; Intrinsic::ID IID = Intrinsic::not_intrinsic; Value *ScalarArg = nullptr; @@ -2763,7 +3171,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { // ctlz,cttz and powi are special intrinsics whose second argument is // a scalar. This argument should not be vectorized. if (hasVectorInstrinsicScalarOpd(IID, 1) && j == 1) { - CallInst *CEI = cast<CallInst>(E->Scalars[0]); + CallInst *CEI = cast<CallInst>(VL0); ScalarArg = CEI->getArgOperand(j); OpVecs.push_back(CEI->getArgOperand(j)); continue; @@ -2773,7 +3181,7 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { OpVL.push_back(CEI->getArgOperand(j)); } - Value *OpVec = vectorizeTree(OpVL); + Value *OpVec = vectorizeTree(OpVL, j, CurrIndx); DEBUG(dbgs() << "SLP: OpVec[" << j << "]: " << *OpVec << "\n"); OpVecs.push_back(OpVec); } @@ -2793,30 +3201,31 @@ Value *BoUpSLP::vectorizeTree(TreeEntry *E) { ExternalUses.push_back(ExternalUser(ScalarArg, cast<User>(V), 0)); E->VectorizedValue = V; - propagateIRFlags(E->VectorizedValue, E->Scalars); + propagateIRFlags(E->VectorizedValue, E->Scalars, VL0); ++NumVectorInstructions; return V; } case Instruction::ShuffleVector: { ValueList LHSVL, RHSVL; - assert(isa<BinaryOperator>(VL0) && "Invalid Shuffle Vector Operand"); - reorderAltShuffleOperands(E->Scalars, LHSVL, RHSVL); - setInsertPointAfterBundle(E->Scalars); + assert(Instruction::isBinaryOp(S.Opcode) && + "Invalid Shuffle Vector Operand"); + reorderAltShuffleOperands(S.Opcode, E->Scalars, LHSVL, RHSVL); + setInsertPointAfterBundle(E->Scalars, VL0); - Value *LHS = vectorizeTree(LHSVL); - Value *RHS = vectorizeTree(RHSVL); + Value *LHS = vectorizeTree(LHSVL, 0, CurrIndx); + Value *RHS = vectorizeTree(RHSVL, 1, CurrIndx); if (Value *V = alreadyVectorized(E->Scalars, VL0)) return V; // Create a vector of LHS op1 RHS - BinaryOperator *BinOp0 = cast<BinaryOperator>(VL0); - Value *V0 = Builder.CreateBinOp(BinOp0->getOpcode(), LHS, RHS); + Value *V0 = Builder.CreateBinOp( + static_cast<Instruction::BinaryOps>(S.Opcode), LHS, RHS); + unsigned AltOpcode = getAltOpcode(S.Opcode); // Create a vector 
of LHS op2 RHS - Instruction *VL1 = cast<Instruction>(E->Scalars[1]); - BinaryOperator *BinOp1 = cast<BinaryOperator>(VL1); - Value *V1 = Builder.CreateBinOp(BinOp1->getOpcode(), LHS, RHS); + Value *V1 = Builder.CreateBinOp( + static_cast<Instruction::BinaryOps>(AltOpcode), LHS, RHS); // Create shuffle to take alternate operations from the vector. // Also, gather up odd and even scalar ops to propagate IR flags to @@ -2859,7 +3268,6 @@ Value *BoUpSLP::vectorizeTree() { Value * BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { - // All blocks must be scheduled before any instructions are inserted. for (auto &BSIter : BlocksSchedules) { scheduleBlock(BSIter.second.get()); @@ -2905,9 +3313,14 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { continue; TreeEntry *E = getTreeEntry(Scalar); assert(E && "Invalid scalar"); - assert(!E->NeedToGather && "Extracting from a gather list"); + assert((!E->NeedToGather) && "Extracting from a gather list"); - Value *Vec = E->VectorizedValue; + Value *Vec = dyn_cast<ShuffleVectorInst>(E->VectorizedValue); + if (Vec && dyn_cast<LoadInst>(cast<Instruction>(Vec)->getOperand(0))) { + Vec = cast<Instruction>(E->VectorizedValue)->getOperand(0); + } else { + Vec = E->VectorizedValue; + } assert(Vec && "Can't find vectorizable value"); Value *Lane = Builder.getInt32(ExternalUse.Lane); @@ -2975,14 +3388,15 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { for (TreeEntry &EIdx : VectorizableTree) { TreeEntry *Entry = &EIdx; + // No need to handle users of gathered values. + if (Entry->NeedToGather) + continue; + + assert(Entry->VectorizedValue && "Can't find vectorizable value"); + // For each lane: for (int Lane = 0, LE = Entry->Scalars.size(); Lane != LE; ++Lane) { Value *Scalar = Entry->Scalars[Lane]; - // No need to handle users of gathered values. - if (Entry->NeedToGather) - continue; - - assert(Entry->VectorizedValue && "Can't find vectorizable value"); Type *Ty = Scalar->getType(); if (!Ty->isVoidTy()) { @@ -2990,9 +3404,8 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { for (User *U : Scalar->users()) { DEBUG(dbgs() << "SLP: \tvalidating user:" << *U << ".\n"); - assert((getTreeEntry(U) || - // It is legal to replace users in the ignorelist by undef. - is_contained(UserIgnoreList, U)) && + // It is legal to replace users in the ignorelist by undef. + assert((getTreeEntry(U) || is_contained(UserIgnoreList, U)) && "Replacing out-of-tree value with undef"); } #endif @@ -3009,7 +3422,7 @@ BoUpSLP::vectorizeTree(ExtraValueToDebugLocsMap &ExternallyUsedValues) { return VectorizableTree[0].VectorizedValue; } -void BoUpSLP::optimizeGatherSequence() { +void BoUpSLP::optimizeGatherSequence(Function &F) { DEBUG(dbgs() << "SLP: Optimizing " << GatherSeq.size() << " gather sequences instructions.\n"); // LICM InsertElementInst sequences. @@ -3043,30 +3456,16 @@ void BoUpSLP::optimizeGatherSequence() { Insert->moveBefore(PreHeader->getTerminator()); } - // Make a list of all reachable blocks in our CSE queue. - SmallVector<const DomTreeNode *, 8> CSEWorkList; - CSEWorkList.reserve(CSEBlocks.size()); - for (BasicBlock *BB : CSEBlocks) - if (DomTreeNode *N = DT->getNode(BB)) { - assert(DT->isReachableFromEntry(N)); - CSEWorkList.push_back(N); - } - - // Sort blocks by domination. This ensures we visit a block after all blocks - // dominating it are visited. 
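For an alternating bundle the ShuffleVector case emits two full-width vector ops, one with the main opcode (V0) and one with the alternate opcode (V1), and then a shufflevector that merges them lane by lane. The even/odd add-sub interleave below is the typical pattern; the exact mask the patch builds depends on which lanes carried which opcode and is not shown in this excerpt, so treat this as an illustrative assumption:

#include <cstdio>
#include <vector>

// Mask <0, n+1, 2, n+3, ...>: even lanes from the first input (V0), odd lanes
// from the second (V1); indices >= n address the second shuffle input.
static std::vector<unsigned> altShuffleMask(unsigned N) {
  std::vector<unsigned> Mask(N);
  for (unsigned I = 0; I < N; ++I)
    Mask[I] = (I % 2 == 0) ? I : N + I;
  return Mask;
}

int main() {
  std::vector<int> A = {1, 2, 3, 4}, B = {10, 20, 30, 40};
  std::vector<int> V0(4), V1(4), Out(4);
  for (int I = 0; I < 4; ++I) {
    V0[I] = A[I] + B[I]; // main opcode vector (add)
    V1[I] = A[I] - B[I]; // alternate opcode vector (sub)
  }
  std::vector<unsigned> Mask = altShuffleMask(4); // 0, 5, 2, 7
  for (unsigned I = 0; I < 4; ++I)
    Out[I] = Mask[I] < 4 ? V0[Mask[I]] : V1[Mask[I] - 4];
  for (int X : Out)
    std::printf("%d ", X); // 11 -18 33 -36: add, sub, add, sub
  std::printf("\n");
}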
- std::stable_sort(CSEWorkList.begin(), CSEWorkList.end(), - [this](const DomTreeNode *A, const DomTreeNode *B) { - return DT->properlyDominates(A, B); - }); - // Perform O(N^2) search over the gather sequences and merge identical // instructions. TODO: We can further optimize this scan if we split the // instructions into different buckets based on the insert lane. SmallVector<Instruction *, 16> Visited; - for (auto I = CSEWorkList.begin(), E = CSEWorkList.end(); I != E; ++I) { - assert((I == CSEWorkList.begin() || !DT->dominates(*I, *std::prev(I))) && - "Worklist not sorted properly!"); - BasicBlock *BB = (*I)->getBlock(); + ReversePostOrderTraversal<Function *> RPOT(&F); + for (auto BB : RPOT) { + // Traverse CSEBlocks by RPOT order. + if (!CSEBlocks.count(BB)) + continue; + // For all instructions in blocks containing gather sequences: for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e;) { Instruction *In = &*it++; @@ -3111,7 +3510,7 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, // Make sure that the scheduling region contains all // instructions of the bundle. for (Value *V : VL) { - if (!extendSchedulingRegion(V)) + if (!extendSchedulingRegion(V, OpValue)) return false; } @@ -3148,8 +3547,9 @@ bool BoUpSLP::BlockScheduling::tryScheduleBundle(ArrayRef<Value *> VL, // It is seldom that this needs to be done a second time after adding the // initial bundle to the region. for (auto *I = ScheduleStart; I != ScheduleEnd; I = I->getNextNode()) { - ScheduleData *SD = getScheduleData(I); - SD->clearDependencies(); + doForAllOpcodes(I, [](ScheduleData *SD) { + SD->clearDependencies(); + }); } ReSchedule = true; } @@ -3210,17 +3610,43 @@ void BoUpSLP::BlockScheduling::cancelScheduling(ArrayRef<Value *> VL, } } -bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V) { - if (getScheduleData(V)) +BoUpSLP::ScheduleData *BoUpSLP::BlockScheduling::allocateScheduleDataChunks() { + // Allocate a new ScheduleData for the instruction. + if (ChunkPos >= ChunkSize) { + ScheduleDataChunks.push_back(llvm::make_unique<ScheduleData[]>(ChunkSize)); + ChunkPos = 0; + } + return &(ScheduleDataChunks.back()[ChunkPos++]); +} + +bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V, + Value *OpValue) { + if (getScheduleData(V, isOneOf(OpValue, V))) return true; Instruction *I = dyn_cast<Instruction>(V); assert(I && "bundle member must be an instruction"); assert(!isa<PHINode>(I) && "phi nodes don't need to be scheduled"); + auto &&CheckSheduleForI = [this, OpValue](Instruction *I) -> bool { + ScheduleData *ISD = getScheduleData(I); + if (!ISD) + return false; + assert(isInSchedulingRegion(ISD) && + "ScheduleData not in scheduling region"); + ScheduleData *SD = allocateScheduleDataChunks(); + SD->Inst = I; + SD->init(SchedulingRegionID, OpValue); + ExtraScheduleDataMap[I][OpValue] = SD; + return true; + }; + if (CheckSheduleForI(I)) + return true; if (!ScheduleStart) { // It's the first instruction in the new region. 
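optimizeGatherSequence now takes the Function, walks it in reverse post order, and runs the O(N^2) merge of identical gather instructions only inside blocks recorded in CSEBlocks. A standalone sketch with instructions reduced to comparable strings and blocks to names; the real pass additionally requires the surviving instruction to dominate the one it replaces, which is elided here:

#include <cstdio>
#include <map>
#include <set>
#include <string>
#include <vector>

int main() {
  // Blocks of the function in reverse post order; only members of CSEBlocks
  // (blocks where gather sequences were emitted) are scanned.
  std::vector<std::string> RPOT = {"entry", "loop", "exit"};
  std::set<std::string> CSEBlocks = {"entry", "exit"};
  std::map<std::string, std::vector<std::string>> Blocks = {
      {"entry", {"gatherA", "gatherB", "gatherA"}},
      {"loop", {"work"}},
      {"exit", {"gatherB"}}};

  unsigned Merged = 0;
  std::vector<std::string> Visited; // gather sequences seen so far
  for (const std::string &BB : RPOT) {
    if (!CSEBlocks.count(BB))
      continue;
    for (const std::string &In : Blocks[BB]) {
      bool Replaced = false;
      for (const std::string &V : Visited) {
        if (V == In) { // stand-in for In->isIdenticalTo(V)
          ++Merged;    // the real pass RAUWs In with V and erases In
          Replaced = true;
          break;
        }
      }
      if (!Replaced)
        Visited.push_back(In);
    }
  }
  std::printf("merged %u duplicate gathers\n", Merged); // 2
}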
initScheduleData(I, I->getNextNode(), nullptr, nullptr); ScheduleStart = I; ScheduleEnd = I->getNextNode(); + if (isOneOf(OpValue, I) != I) + CheckSheduleForI(I); assert(ScheduleEnd && "tried to vectorize a TerminatorInst?"); DEBUG(dbgs() << "SLP: initialize schedule region to " << *I << "\n"); return true; @@ -3232,7 +3658,7 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V) { BasicBlock::reverse_iterator UpperEnd = BB->rend(); BasicBlock::iterator DownIter = ScheduleEnd->getIterator(); BasicBlock::iterator LowerEnd = BB->end(); - for (;;) { + while (true) { if (++ScheduleRegionSize > ScheduleRegionSizeLimit) { DEBUG(dbgs() << "SLP: exceeded schedule region size limit\n"); return false; @@ -3242,6 +3668,8 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V) { if (&*UpIter == I) { initScheduleData(I, ScheduleStart, nullptr, FirstLoadStoreInRegion); ScheduleStart = I; + if (isOneOf(OpValue, I) != I) + CheckSheduleForI(I); DEBUG(dbgs() << "SLP: extend schedule region start to " << *I << "\n"); return true; } @@ -3252,6 +3680,8 @@ bool BoUpSLP::BlockScheduling::extendSchedulingRegion(Value *V) { initScheduleData(ScheduleEnd, I->getNextNode(), LastLoadStoreInRegion, nullptr); ScheduleEnd = I->getNextNode(); + if (isOneOf(OpValue, I) != I) + CheckSheduleForI(I); assert(ScheduleEnd && "tried to vectorize a TerminatorInst?"); DEBUG(dbgs() << "SLP: extend schedule region end to " << *I << "\n"); return true; @@ -3272,21 +3702,17 @@ void BoUpSLP::BlockScheduling::initScheduleData(Instruction *FromI, for (Instruction *I = FromI; I != ToI; I = I->getNextNode()) { ScheduleData *SD = ScheduleDataMap[I]; if (!SD) { - // Allocate a new ScheduleData for the instruction. - if (ChunkPos >= ChunkSize) { - ScheduleDataChunks.push_back( - llvm::make_unique<ScheduleData[]>(ChunkSize)); - ChunkPos = 0; - } - SD = &(ScheduleDataChunks.back()[ChunkPos++]); + SD = allocateScheduleDataChunks(); ScheduleDataMap[I] = SD; SD->Inst = I; } assert(!isInSchedulingRegion(SD) && "new ScheduleData already in scheduling region"); - SD->init(SchedulingRegionID); + SD->init(SchedulingRegionID, I); - if (I->mayReadOrWriteMemory()) { + if (I->mayReadOrWriteMemory() && + (!isa<IntrinsicInst>(I) || + cast<IntrinsicInst>(I)->getIntrinsicID() != Intrinsic::sideeffect)) { // Update the linked list of memory accessing instructions. if (CurrentLoadStore) { CurrentLoadStore->NextLoadStore = SD; @@ -3326,23 +3752,35 @@ void BoUpSLP::BlockScheduling::calculateDependencies(ScheduleData *SD, BundleMember->resetUnscheduledDeps(); // Handle def-use chain dependencies. 
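allocateScheduleDataChunks, factored out above and reused by initScheduleData, hands out ScheduleData objects from fixed-size arrays so pointers into the pool stay valid as the scheduling region grows. A standalone sketch of that chunked pool; the chunk size and the payload field are invented:

#include <memory>
#include <vector>

struct ScheduleData {
  int SchedulingRegionID = 0; // placeholder payload
};

class ChunkedPool {
  enum { ChunkSize = 256 };
  std::vector<std::unique_ptr<ScheduleData[]>> Chunks;
  int ChunkPos = ChunkSize; // forces the first allocate() to start a chunk

public:
  // Mirrors allocateScheduleDataChunks(): start a new chunk when the current
  // one is full, then hand out the next free slot.
  ScheduleData *allocate() {
    if (ChunkPos >= ChunkSize) {
      Chunks.push_back(std::unique_ptr<ScheduleData[]>(new ScheduleData[ChunkSize]));
      ChunkPos = 0;
    }
    return &Chunks.back()[ChunkPos++];
  }
};

int main() {
  ChunkedPool Pool;
  ScheduleData *First = Pool.allocate();
  for (int I = 0; I < 300; ++I)
    Pool.allocate(); // spills over into a second chunk
  return First->SchedulingRegionID; // First stays valid: chunks never move
}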
- for (User *U : BundleMember->Inst->users()) { - if (isa<Instruction>(U)) { - ScheduleData *UseSD = getScheduleData(U); - if (UseSD && isInSchedulingRegion(UseSD->FirstInBundle)) { + if (BundleMember->OpValue != BundleMember->Inst) { + ScheduleData *UseSD = getScheduleData(BundleMember->Inst); + if (UseSD && isInSchedulingRegion(UseSD->FirstInBundle)) { + BundleMember->Dependencies++; + ScheduleData *DestBundle = UseSD->FirstInBundle; + if (!DestBundle->IsScheduled) + BundleMember->incrementUnscheduledDeps(1); + if (!DestBundle->hasValidDependencies()) + WorkList.push_back(DestBundle); + } + } else { + for (User *U : BundleMember->Inst->users()) { + if (isa<Instruction>(U)) { + ScheduleData *UseSD = getScheduleData(U); + if (UseSD && isInSchedulingRegion(UseSD->FirstInBundle)) { + BundleMember->Dependencies++; + ScheduleData *DestBundle = UseSD->FirstInBundle; + if (!DestBundle->IsScheduled) + BundleMember->incrementUnscheduledDeps(1); + if (!DestBundle->hasValidDependencies()) + WorkList.push_back(DestBundle); + } + } else { + // I'm not sure if this can ever happen. But we need to be safe. + // This lets the instruction/bundle never be scheduled and + // eventually disable vectorization. BundleMember->Dependencies++; - ScheduleData *DestBundle = UseSD->FirstInBundle; - if (!DestBundle->IsScheduled) - BundleMember->incrementUnscheduledDeps(1); - if (!DestBundle->hasValidDependencies()) - WorkList.push_back(DestBundle); + BundleMember->incrementUnscheduledDeps(1); } - } else { - // I'm not sure if this can ever happen. But we need to be safe. - // This lets the instruction/bundle never be scheduled and - // eventually disable vectorization. - BundleMember->Dependencies++; - BundleMember->incrementUnscheduledDeps(1); } } @@ -3419,16 +3857,17 @@ void BoUpSLP::BlockScheduling::resetSchedule() { assert(ScheduleStart && "tried to reset schedule on block which has not been scheduled"); for (Instruction *I = ScheduleStart; I != ScheduleEnd; I = I->getNextNode()) { - ScheduleData *SD = getScheduleData(I); - assert(isInSchedulingRegion(SD)); - SD->IsScheduled = false; - SD->resetUnscheduledDeps(); + doForAllOpcodes(I, [&](ScheduleData *SD) { + assert(isInSchedulingRegion(SD) && + "ScheduleData not in scheduling region"); + SD->IsScheduled = false; + SD->resetUnscheduledDeps(); + }); } ReadyInsts.clear(); } void BoUpSLP::scheduleBlock(BlockScheduling *BS) { - if (!BS->ScheduleStart) return; @@ -3452,15 +3891,16 @@ void BoUpSLP::scheduleBlock(BlockScheduling *BS) { int NumToSchedule = 0; for (auto *I = BS->ScheduleStart; I != BS->ScheduleEnd; I = I->getNextNode()) { - ScheduleData *SD = BS->getScheduleData(I); - assert( - SD->isPartOfBundle() == (getTreeEntry(SD->Inst) != nullptr) && - "scheduler and vectorizer have different opinion on what is a bundle"); - SD->FirstInBundle->SchedulingPriority = Idx++; - if (SD->isSchedulingEntity()) { - BS->calculateDependencies(SD, false, this); - NumToSchedule++; - } + BS->doForAllOpcodes(I, [this, &Idx, &NumToSchedule, BS](ScheduleData *SD) { + assert(SD->isPartOfBundle() == + (getTreeEntry(SD->Inst) != nullptr) && + "scheduler and vectorizer bundle mismatch"); + SD->FirstInBundle->SchedulingPriority = Idx++; + if (SD->isSchedulingEntity()) { + BS->calculateDependencies(SD, false, this); + NumToSchedule++; + } + }); } BS->initialFillReadyList(ReadyInsts); @@ -3559,7 +3999,6 @@ unsigned BoUpSLP::getVectorElementSize(Value *V) { static bool collectValuesToDemote(Value *V, SmallPtrSetImpl<Value *> &Expr, SmallVectorImpl<Value *> &ToDemote, SmallVectorImpl<Value 
*> &Roots) { - // We can always demote constants. if (isa<Constant>(V)) { ToDemote.push_back(V); @@ -3702,7 +4141,7 @@ void BoUpSLP::computeMinimumValueSizes() { // Determine if the sign bit of all the roots is known to be zero. If not, // IsKnownPositive is set to False. - IsKnownPositive = all_of(TreeRoot, [&](Value *R) { + IsKnownPositive = llvm::all_of(TreeRoot, [&](Value *R) { KnownBits Known = computeKnownBits(R, *DL); return Known.isNonNegative(); }); @@ -3710,7 +4149,7 @@ void BoUpSLP::computeMinimumValueSizes() { // Determine the maximum number of bits required to store the scalar // values. for (auto *Scalar : ToDemote) { - auto NumSignBits = ComputeNumSignBits(Scalar, *DL, 0, AC, 0, DT); + auto NumSignBits = ComputeNumSignBits(Scalar, *DL, 0, AC, nullptr, DT); auto NumTypeBits = DL->getTypeSizeInBits(Scalar->getType()); MaxBitWidth = std::max<unsigned>(NumTypeBits - NumSignBits, MaxBitWidth); } @@ -3755,6 +4194,7 @@ void BoUpSLP::computeMinimumValueSizes() { } namespace { + /// The SLPVectorizer Pass. struct SLPVectorizer : public FunctionPass { SLPVectorizerPass Impl; @@ -3766,7 +4206,6 @@ struct SLPVectorizer : public FunctionPass { initializeSLPVectorizerPass(*PassRegistry::getPassRegistry()); } - bool doInitialization(Module &M) override { return false; } @@ -3806,6 +4245,7 @@ struct SLPVectorizer : public FunctionPass { AU.setPreservesCFG(); } }; + } // end anonymous namespace PreservedAnalyses SLPVectorizerPass::run(Function &F, FunctionAnalysisManager &AM) { @@ -3893,7 +4333,7 @@ bool SLPVectorizerPass::runImpl(Function &F, ScalarEvolution *SE_, } if (Changed) { - R.optimizeGatherSequence(); + R.optimizeGatherSequence(F); DEBUG(dbgs() << "SLP: vectorized \"" << F.getName() << "\"\n"); DEBUG(verifyFunction(F)); } @@ -3952,7 +4392,9 @@ bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, DEBUG(dbgs() << "SLP: Found cost=" << Cost << " for VF=" << VF << "\n"); if (Cost < -SLPCostThreshold) { DEBUG(dbgs() << "SLP: Decided to vectorize cost=" << Cost << "\n"); + using namespace ore; + R.getORE()->emit(OptimizationRemark(SV_NAME, "StoresVectorized", cast<StoreInst>(Chain[i])) << "Stores SLP vectorized with cost " << NV("Cost", Cost) @@ -3972,7 +4414,8 @@ bool SLPVectorizerPass::vectorizeStoreChain(ArrayRef<Value *> Chain, BoUpSLP &R, bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, BoUpSLP &R) { - SetVector<StoreInst *> Heads, Tails; + SetVector<StoreInst *> Heads; + SmallDenseSet<StoreInst *> Tails; SmallDenseMap<StoreInst *, StoreInst *> ConsecutiveChain; // We may run into multiple chains that merge into a single chain. We mark the @@ -3980,45 +4423,51 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, BoUpSLP::ValueSet VectorizedStores; bool Changed = false; - // Do a quadratic search on all of the given stores and find + // Do a quadratic search on all of the given stores in reverse order and find // all of the pairs of stores that follow each other. SmallVector<unsigned, 16> IndexQueue; - for (unsigned i = 0, e = Stores.size(); i < e; ++i) { - IndexQueue.clear(); + unsigned E = Stores.size(); + IndexQueue.resize(E - 1); + for (unsigned I = E; I > 0; --I) { + unsigned Idx = I - 1; // If a store has multiple consecutive store candidates, search Stores - // array according to the sequence: from i+1 to e, then from i-1 to 0. + // array according to the sequence: Idx-1, Idx+1, Idx-2, Idx+2, ... 
// This is because usually pairing with immediate succeeding or preceding // candidate create the best chance to find slp vectorization opportunity. - unsigned j = 0; - for (j = i + 1; j < e; ++j) - IndexQueue.push_back(j); - for (j = i; j > 0; --j) - IndexQueue.push_back(j - 1); - - for (auto &k : IndexQueue) { - if (isConsecutiveAccess(Stores[i], Stores[k], *DL, *SE)) { - Tails.insert(Stores[k]); - Heads.insert(Stores[i]); - ConsecutiveChain[Stores[i]] = Stores[k]; + unsigned Offset = 1; + unsigned Cnt = 0; + for (unsigned J = 0; J < E - 1; ++J, ++Offset) { + if (Idx >= Offset) { + IndexQueue[Cnt] = Idx - Offset; + ++Cnt; + } + if (Idx + Offset < E) { + IndexQueue[Cnt] = Idx + Offset; + ++Cnt; + } + } + + for (auto K : IndexQueue) { + if (isConsecutiveAccess(Stores[K], Stores[Idx], *DL, *SE)) { + Tails.insert(Stores[Idx]); + Heads.insert(Stores[K]); + ConsecutiveChain[Stores[K]] = Stores[Idx]; break; } } } // For stores that start but don't end a link in the chain: - for (SetVector<StoreInst *>::iterator it = Heads.begin(), e = Heads.end(); - it != e; ++it) { - if (Tails.count(*it)) + for (auto *SI : llvm::reverse(Heads)) { + if (Tails.count(SI)) continue; // We found a store instr that starts a chain. Now follow the chain and try // to vectorize it. BoUpSLP::ValueList Operands; - StoreInst *I = *it; + StoreInst *I = SI; // Collect the chain into a list. - while (Tails.count(I) || Heads.count(I)) { - if (VectorizedStores.count(I)) - break; + while ((Tails.count(I) || Heads.count(I)) && !VectorizedStores.count(I)) { Operands.push_back(I); // Move to the next value in the chain. I = ConsecutiveChain[I]; @@ -4041,7 +4490,6 @@ bool SLPVectorizerPass::vectorizeStores(ArrayRef<StoreInst *> Stores, } void SLPVectorizerPass::collectSeedInstructions(BasicBlock *BB) { - // Initialize the collections. We will make a single pass over the block. Stores.clear(); GEPs.clear(); @@ -4050,7 +4498,6 @@ void SLPVectorizerPass::collectSeedInstructions(BasicBlock *BB) { // Stores and GEPs according to the underlying objects of their pointer // operands. for (Instruction &I : *BB) { - // Ignore store instructions that are volatile or have a pointer operand // that doesn't point to a scalar type. 
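vectorizeStores now scans the stores in reverse and, for each store at index Idx, probes partners in the order Idx-1, Idx+1, Idx-2, Idx+2, ..., since the nearest neighbours are the likeliest to form a consecutive chain. A standalone sketch of just that candidate ordering:

#include <cstdio>
#include <vector>

// Candidate partners for store Idx out of E stores, nearest first:
// Idx-1, Idx+1, Idx-2, Idx+2, ...
static std::vector<unsigned> buildIndexQueue(unsigned Idx, unsigned E) {
  std::vector<unsigned> IndexQueue;
  if (E < 2)
    return IndexQueue;
  for (unsigned Offset = 1, J = 0; J < E - 1; ++J, ++Offset) {
    if (Idx >= Offset)
      IndexQueue.push_back(Idx - Offset);
    if (Idx + Offset < E)
      IndexQueue.push_back(Idx + Offset);
  }
  return IndexQueue;
}

int main() {
  for (unsigned I : buildIndexQueue(3, 6))
    std::printf("%u ", I); // 2 4 1 5 0
  std::printf("\n");
}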
if (auto *SI = dyn_cast<StoreInst>(&I)) { @@ -4086,7 +4533,8 @@ bool SLPVectorizerPass::tryToVectorizePair(Value *A, Value *B, BoUpSLP &R) { bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, ArrayRef<Value *> BuildVector, - bool AllowReorder) { + bool AllowReorder, + bool NeedExtraction) { if (VL.size() < 2) return false; @@ -4103,19 +4551,51 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, unsigned Sz = R.getVectorElementSize(I0); unsigned MinVF = std::max(2U, R.getMinVecRegSize() / Sz); unsigned MaxVF = std::max<unsigned>(PowerOf2Floor(VL.size()), MinVF); - if (MaxVF < 2) - return false; + if (MaxVF < 2) { + R.getORE()->emit([&]() { + return OptimizationRemarkMissed( + SV_NAME, "SmallVF", I0) + << "Cannot SLP vectorize list: vectorization factor " + << "less than 2 is not supported"; + }); + return false; + } for (Value *V : VL) { Type *Ty = V->getType(); - if (!isValidElementType(Ty)) + if (!isValidElementType(Ty)) { + // NOTE: the following will give user internal llvm type name, which may not be useful + R.getORE()->emit([&]() { + std::string type_str; + llvm::raw_string_ostream rso(type_str); + Ty->print(rso); + return OptimizationRemarkMissed( + SV_NAME, "UnsupportedType", I0) + << "Cannot SLP vectorize list: type " + << rso.str() + " is unsupported by vectorizer"; + }); return false; + } Instruction *Inst = dyn_cast<Instruction>(V); - if (!Inst || Inst->getOpcode() != Opcode0) + + if (!Inst) + return false; + if (Inst->getOpcode() != Opcode0) { + R.getORE()->emit([&]() { + return OptimizationRemarkMissed( + SV_NAME, "InequableTypes", I0) + << "Cannot SLP vectorize list: not all of the " + << "parts of scalar instructions are of the same type: " + << ore::NV("Instruction1Opcode", I0) << " and " + << ore::NV("Instruction2Opcode", Inst); + }); return false; + } } bool Changed = false; + bool CandidateFound = false; + int MinCost = SLPCostThreshold; // Keep track of values that were deleted by vectorizing in the loop below. SmallVector<WeakTrackingVH, 8> TrackValues(VL.begin(), VL.end()); @@ -4148,11 +4628,12 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, << "\n"); ArrayRef<Value *> Ops = VL.slice(I, OpsWidth); + ArrayRef<Value *> EmptyArray; ArrayRef<Value *> BuildVectorSlice; if (!BuildVector.empty()) BuildVectorSlice = BuildVector.slice(I, OpsWidth); - R.buildTree(Ops, BuildVectorSlice); + R.buildTree(Ops, NeedExtraction ? EmptyArray : BuildVectorSlice); // TODO: check if we can allow reordering for more cases. 
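tryToVectorizeList derives the range of vectorization factors from the scalar width, MinVF = max(2, MinVecRegSize / ElementSize) and MaxVF = max(PowerOf2Floor(VL.size()), MinVF), and now emits a missed-optimization remark when MaxVF < 2 instead of failing silently. A standalone sketch of that arithmetic with assumed register and element sizes:

#include <algorithm>
#include <cstdio>

// Largest power of two not exceeding X (X > 0), like llvm::PowerOf2Floor.
static unsigned powerOf2Floor(unsigned X) {
  unsigned P = 1;
  while (P * 2 <= X)
    P *= 2;
  return P;
}

int main() {
  unsigned MinVecRegSize = 128; // bits; assumed minimum vector register width
  unsigned ElementSize = 32;    // bits; assumed result of getVectorElementSize
  unsigned NumScalars = 6;      // length of the candidate list

  unsigned MinVF = std::max(2u, MinVecRegSize / ElementSize);  // 4
  unsigned MaxVF = std::max(powerOf2Floor(NumScalars), MinVF); // 4
  std::printf("try vector factors from %u down to %u\n", MaxVF, MinVF);
  return MaxVF < 2; // the new remark-and-bail path
}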
if (AllowReorder && R.shouldReorder()) { // Conceptually, there is nothing actually preventing us from trying to @@ -4169,14 +4650,16 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, R.computeMinimumValueSizes(); int Cost = R.getTreeCost(); + CandidateFound = true; + MinCost = std::min(MinCost, Cost); if (Cost < -SLPCostThreshold) { DEBUG(dbgs() << "SLP: Vectorizing list at cost:" << Cost << ".\n"); R.getORE()->emit(OptimizationRemark(SV_NAME, "VectorizedList", - cast<Instruction>(Ops[0])) - << "SLP vectorized with cost " << ore::NV("Cost", Cost) - << " and with tree size " - << ore::NV("TreeSize", R.getTreeSize())); + cast<Instruction>(Ops[0])) + << "SLP vectorized with cost " << ore::NV("Cost", Cost) + << " and with tree size " + << ore::NV("TreeSize", R.getTreeSize())); Value *VectorizedRoot = R.vectorizeTree(); @@ -4199,8 +4682,7 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, cast<Instruction>(Builder.CreateExtractElement( VectorizedRoot, Builder.getInt32(VecIdx++))); I->setOperand(1, Extract); - I->removeFromParent(); - I->insertAfter(Extract); + I->moveAfter(Extract); InsertAfter = I; } } @@ -4212,18 +4694,37 @@ bool SLPVectorizerPass::tryToVectorizeList(ArrayRef<Value *> VL, BoUpSLP &R, } } + if (!Changed && CandidateFound) { + R.getORE()->emit([&]() { + return OptimizationRemarkMissed( + SV_NAME, "NotBeneficial", I0) + << "List vectorization was possible but not beneficial with cost " + << ore::NV("Cost", MinCost) << " >= " + << ore::NV("Treshold", -SLPCostThreshold); + }); + } else if (!Changed) { + R.getORE()->emit([&]() { + return OptimizationRemarkMissed( + SV_NAME, "NotPossible", I0) + << "Cannot SLP vectorize list: vectorization was impossible" + << " with available vectorization factors"; + }); + } return Changed; } -bool SLPVectorizerPass::tryToVectorize(BinaryOperator *V, BoUpSLP &R) { - if (!V) +bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) { + if (!I) return false; - Value *P = V->getParent(); + if (!isa<BinaryOperator>(I) && !isa<CmpInst>(I)) + return false; + + Value *P = I->getParent(); // Vectorize in current basic block only. - auto *Op0 = dyn_cast<Instruction>(V->getOperand(0)); - auto *Op1 = dyn_cast<Instruction>(V->getOperand(1)); + auto *Op0 = dyn_cast<Instruction>(I->getOperand(0)); + auto *Op1 = dyn_cast<Instruction>(I->getOperand(1)); if (!Op0 || !Op1 || Op0->getParent() != P || Op1->getParent() != P) return false; @@ -4286,6 +4787,7 @@ static Value *createRdxShuffleMask(unsigned VecLen, unsigned NumEltsToRdx, } namespace { + /// Model horizontal reductions. /// /// A horizontal reduction is a tree of reduction operations (currently add and @@ -4314,17 +4816,375 @@ namespace { /// *p = /// class HorizontalReduction { - SmallVector<Value *, 16> ReductionOps; + using ReductionOpsType = SmallVector<Value *, 16>; + using ReductionOpsListType = SmallVector<ReductionOpsType, 2>; + ReductionOpsListType ReductionOps; SmallVector<Value *, 32> ReducedVals; // Use map vector to make stable output. MapVector<Instruction *, Value *> ExtraArgs; - BinaryOperator *ReductionRoot = nullptr; + /// Kind of the reduction data. + enum ReductionKind { + RK_None, /// Not a reduction. + RK_Arithmetic, /// Binary reduction data. + RK_Min, /// Minimum reduction data. + RK_UMin, /// Unsigned minimum reduction data. + RK_Max, /// Maximum reduction data. + RK_UMax, /// Unsigned maximum reduction data. + }; + + /// Contains info about operation, like its opcode, left and right operands. 
+ class OperationData { + /// Opcode of the instruction. + unsigned Opcode = 0; + + /// Left operand of the reduction operation. + Value *LHS = nullptr; + + /// Right operand of the reduction operation. + Value *RHS = nullptr; + + /// Kind of the reduction operation. + ReductionKind Kind = RK_None; + + /// True if float point min/max reduction has no NaNs. + bool NoNaN = false; + + /// Checks if the reduction operation can be vectorized. + bool isVectorizable() const { + return LHS && RHS && + // We currently only support adds && min/max reductions. + ((Kind == RK_Arithmetic && + (Opcode == Instruction::Add || Opcode == Instruction::FAdd)) || + ((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) && + (Kind == RK_Min || Kind == RK_Max)) || + (Opcode == Instruction::ICmp && + (Kind == RK_UMin || Kind == RK_UMax))); + } + + /// Creates reduction operation with the current opcode. + Value *createOp(IRBuilder<> &Builder, const Twine &Name) const { + assert(isVectorizable() && + "Expected add|fadd or min/max reduction operation."); + Value *Cmp; + switch (Kind) { + case RK_Arithmetic: + return Builder.CreateBinOp((Instruction::BinaryOps)Opcode, LHS, RHS, + Name); + case RK_Min: + Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSLT(LHS, RHS) + : Builder.CreateFCmpOLT(LHS, RHS); + break; + case RK_Max: + Cmp = Opcode == Instruction::ICmp ? Builder.CreateICmpSGT(LHS, RHS) + : Builder.CreateFCmpOGT(LHS, RHS); + break; + case RK_UMin: + assert(Opcode == Instruction::ICmp && "Expected integer types."); + Cmp = Builder.CreateICmpULT(LHS, RHS); + break; + case RK_UMax: + assert(Opcode == Instruction::ICmp && "Expected integer types."); + Cmp = Builder.CreateICmpUGT(LHS, RHS); + break; + case RK_None: + llvm_unreachable("Unknown reduction operation."); + } + return Builder.CreateSelect(Cmp, LHS, RHS, Name); + } + + public: + explicit OperationData() = default; + + /// Construction for reduced values. They are identified by opcode only and + /// don't have associated LHS/RHS values. + explicit OperationData(Value *V) { + if (auto *I = dyn_cast<Instruction>(V)) + Opcode = I->getOpcode(); + } + + /// Constructor for reduction operations with opcode and its left and + /// right operands. + OperationData(unsigned Opcode, Value *LHS, Value *RHS, ReductionKind Kind, + bool NoNaN = false) + : Opcode(Opcode), LHS(LHS), RHS(RHS), Kind(Kind), NoNaN(NoNaN) { + assert(Kind != RK_None && "One of the reduction operations is expected."); + } + + explicit operator bool() const { return Opcode; } + + /// Get the index of the first operand. + unsigned getFirstOperandIndex() const { + assert(!!*this && "The opcode is not set."); + switch (Kind) { + case RK_Min: + case RK_UMin: + case RK_Max: + case RK_UMax: + return 1; + case RK_Arithmetic: + case RK_None: + break; + } + return 0; + } + + /// Total number of operands in the reduction operation. + unsigned getNumberOfOperands() const { + assert(Kind != RK_None && !!*this && LHS && RHS && + "Expected reduction operation."); + switch (Kind) { + case RK_Arithmetic: + return 2; + case RK_Min: + case RK_UMin: + case RK_Max: + case RK_UMax: + return 3; + case RK_None: + break; + } + llvm_unreachable("Reduction kind is not set"); + } + + /// Checks if the operation has the same parent as \p P. 
+ bool hasSameParent(Instruction *I, Value *P, bool IsRedOp) const { + assert(Kind != RK_None && !!*this && LHS && RHS && + "Expected reduction operation."); + if (!IsRedOp) + return I->getParent() == P; + switch (Kind) { + case RK_Arithmetic: + // Arithmetic reduction operation must be used once only. + return I->getParent() == P; + case RK_Min: + case RK_UMin: + case RK_Max: + case RK_UMax: { + // SelectInst must be used twice while the condition op must have single + // use only. + auto *Cmp = cast<Instruction>(cast<SelectInst>(I)->getCondition()); + return I->getParent() == P && Cmp && Cmp->getParent() == P; + } + case RK_None: + break; + } + llvm_unreachable("Reduction kind is not set"); + } + /// Expected number of uses for reduction operations/reduced values. + bool hasRequiredNumberOfUses(Instruction *I, bool IsReductionOp) const { + assert(Kind != RK_None && !!*this && LHS && RHS && + "Expected reduction operation."); + switch (Kind) { + case RK_Arithmetic: + return I->hasOneUse(); + case RK_Min: + case RK_UMin: + case RK_Max: + case RK_UMax: + return I->hasNUses(2) && + (!IsReductionOp || + cast<SelectInst>(I)->getCondition()->hasOneUse()); + case RK_None: + break; + } + llvm_unreachable("Reduction kind is not set"); + } + + /// Initializes the list of reduction operations. + void initReductionOps(ReductionOpsListType &ReductionOps) { + assert(Kind != RK_None && !!*this && LHS && RHS && + "Expected reduction operation."); + switch (Kind) { + case RK_Arithmetic: + ReductionOps.assign(1, ReductionOpsType()); + break; + case RK_Min: + case RK_UMin: + case RK_Max: + case RK_UMax: + ReductionOps.assign(2, ReductionOpsType()); + break; + case RK_None: + llvm_unreachable("Reduction kind is not set"); + } + } + /// Add all reduction operations for the reduction instruction \p I. + void addReductionOps(Instruction *I, ReductionOpsListType &ReductionOps) { + assert(Kind != RK_None && !!*this && LHS && RHS && + "Expected reduction operation."); + switch (Kind) { + case RK_Arithmetic: + ReductionOps[0].emplace_back(I); + break; + case RK_Min: + case RK_UMin: + case RK_Max: + case RK_UMax: + ReductionOps[0].emplace_back(cast<SelectInst>(I)->getCondition()); + ReductionOps[1].emplace_back(I); + break; + case RK_None: + llvm_unreachable("Reduction kind is not set"); + } + } + + /// Checks if instruction is associative and can be vectorized. + bool isAssociative(Instruction *I) const { + assert(Kind != RK_None && *this && LHS && RHS && + "Expected reduction operation."); + switch (Kind) { + case RK_Arithmetic: + return I->isAssociative(); + case RK_Min: + case RK_Max: + return Opcode == Instruction::ICmp || + cast<Instruction>(I->getOperand(0))->isFast(); + case RK_UMin: + case RK_UMax: + assert(Opcode == Instruction::ICmp && + "Only integer compare operation is expected."); + return true; + case RK_None: + break; + } + llvm_unreachable("Reduction kind is not set"); + } + + /// Checks if the reduction operation can be vectorized. + bool isVectorizable(Instruction *I) const { + return isVectorizable() && isAssociative(I); + } + + /// Checks if two operation data are both a reduction op or both a reduced + /// value. 
+ bool operator==(const OperationData &OD) { + assert(((Kind != OD.Kind) || ((!LHS == !OD.LHS) && (!RHS == !OD.RHS))) && + "One of the comparing operations is incorrect."); + return this == &OD || (Kind == OD.Kind && Opcode == OD.Opcode); + } + bool operator!=(const OperationData &OD) { return !(*this == OD); } + void clear() { + Opcode = 0; + LHS = nullptr; + RHS = nullptr; + Kind = RK_None; + NoNaN = false; + } + + /// Get the opcode of the reduction operation. + unsigned getOpcode() const { + assert(isVectorizable() && "Expected vectorizable operation."); + return Opcode; + } + + /// Get kind of reduction data. + ReductionKind getKind() const { return Kind; } + Value *getLHS() const { return LHS; } + Value *getRHS() const { return RHS; } + Type *getConditionType() const { + switch (Kind) { + case RK_Arithmetic: + return nullptr; + case RK_Min: + case RK_Max: + case RK_UMin: + case RK_UMax: + return CmpInst::makeCmpResultType(LHS->getType()); + case RK_None: + break; + } + llvm_unreachable("Reduction kind is not set"); + } + + /// Creates reduction operation with the current opcode with the IR flags + /// from \p ReductionOps. + Value *createOp(IRBuilder<> &Builder, const Twine &Name, + const ReductionOpsListType &ReductionOps) const { + assert(isVectorizable() && + "Expected add|fadd or min/max reduction operation."); + auto *Op = createOp(Builder, Name); + switch (Kind) { + case RK_Arithmetic: + propagateIRFlags(Op, ReductionOps[0]); + return Op; + case RK_Min: + case RK_Max: + case RK_UMin: + case RK_UMax: + if (auto *SI = dyn_cast<SelectInst>(Op)) + propagateIRFlags(SI->getCondition(), ReductionOps[0]); + propagateIRFlags(Op, ReductionOps[1]); + return Op; + case RK_None: + break; + } + llvm_unreachable("Unknown reduction operation."); + } + /// Creates reduction operation with the current opcode with the IR flags + /// from \p I. + Value *createOp(IRBuilder<> &Builder, const Twine &Name, + Instruction *I) const { + assert(isVectorizable() && + "Expected add|fadd or min/max reduction operation."); + auto *Op = createOp(Builder, Name); + switch (Kind) { + case RK_Arithmetic: + propagateIRFlags(Op, I); + return Op; + case RK_Min: + case RK_Max: + case RK_UMin: + case RK_UMax: + if (auto *SI = dyn_cast<SelectInst>(Op)) { + propagateIRFlags(SI->getCondition(), + cast<SelectInst>(I)->getCondition()); + } + propagateIRFlags(Op, I); + return Op; + case RK_None: + break; + } + llvm_unreachable("Unknown reduction operation."); + } + + TargetTransformInfo::ReductionFlags getFlags() const { + TargetTransformInfo::ReductionFlags Flags; + Flags.NoNaN = NoNaN; + switch (Kind) { + case RK_Arithmetic: + break; + case RK_Min: + Flags.IsSigned = Opcode == Instruction::ICmp; + Flags.IsMaxOp = false; + break; + case RK_Max: + Flags.IsSigned = Opcode == Instruction::ICmp; + Flags.IsMaxOp = true; + break; + case RK_UMin: + Flags.IsSigned = false; + Flags.IsMaxOp = false; + break; + case RK_UMax: + Flags.IsSigned = false; + Flags.IsMaxOp = true; + break; + case RK_None: + llvm_unreachable("Reduction kind is not set"); + } + return Flags; + } + }; + + Instruction *ReductionRoot = nullptr; + + /// The operation data of the reduction operation. + OperationData ReductionData; + + /// The operation data of the values we perform a reduction on. + OperationData ReducedValueData; - /// The opcode of the reduction. - Instruction::BinaryOps ReductionOpcode = Instruction::BinaryOpsEnd; - /// The opcode of the values we perform a reduction on. 
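// OperationData::createOp above lowers the new min/max kinds to a compare
// followed by a select.  A stand-alone scalar model of those semantics;
// ReductionKind mirrors the enum introduced above, while everything else here
// is a simplified illustration rather than the SLP vectorizer's own code.
#include <cstdint>
#include <iostream>

enum class ReductionKind { Arithmetic, Min, UMin, Max, UMax };

// Scalar equivalent of "Cmp = icmp/fcmp ...; select Cmp, LHS, RHS".
int32_t reduceOnce(ReductionKind Kind, int32_t LHS, int32_t RHS) {
  switch (Kind) {
  case ReductionKind::Arithmetic:
    return LHS + RHS;                                   // e.g. an add reduction
  case ReductionKind::Min:
    return LHS < RHS ? LHS : RHS;                       // icmp slt + select
  case ReductionKind::Max:
    return LHS > RHS ? LHS : RHS;                       // icmp sgt + select
  case ReductionKind::UMin:
    return uint32_t(LHS) < uint32_t(RHS) ? LHS : RHS;   // icmp ult + select
  case ReductionKind::UMax:
    return uint32_t(LHS) > uint32_t(RHS) ? LHS : RHS;   // icmp ugt + select
  }
  return 0; // unreachable
}

int main() {
  // Prints -1: compared as unsigned, 0xFFFFFFFF is the larger value.
  std::cout << reduceOnce(ReductionKind::UMax, -1, 1) << '\n';
}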
- unsigned ReducedValueOpcode = 0; /// Should we model this reduction as a pairwise reduction tree or a tree that /// splits the vector in halves and adds those halves. bool IsPairwiseReduction = false; @@ -4349,55 +5209,89 @@ class HorizontalReduction { } } + static OperationData getOperationData(Value *V) { + if (!V) + return OperationData(); + + Value *LHS; + Value *RHS; + if (m_BinOp(m_Value(LHS), m_Value(RHS)).match(V)) { + return OperationData(cast<BinaryOperator>(V)->getOpcode(), LHS, RHS, + RK_Arithmetic); + } + if (auto *Select = dyn_cast<SelectInst>(V)) { + // Look for a min/max pattern. + if (m_UMin(m_Value(LHS), m_Value(RHS)).match(Select)) { + return OperationData(Instruction::ICmp, LHS, RHS, RK_UMin); + } else if (m_SMin(m_Value(LHS), m_Value(RHS)).match(Select)) { + return OperationData(Instruction::ICmp, LHS, RHS, RK_Min); + } else if (m_OrdFMin(m_Value(LHS), m_Value(RHS)).match(Select) || + m_UnordFMin(m_Value(LHS), m_Value(RHS)).match(Select)) { + return OperationData( + Instruction::FCmp, LHS, RHS, RK_Min, + cast<Instruction>(Select->getCondition())->hasNoNaNs()); + } else if (m_UMax(m_Value(LHS), m_Value(RHS)).match(Select)) { + return OperationData(Instruction::ICmp, LHS, RHS, RK_UMax); + } else if (m_SMax(m_Value(LHS), m_Value(RHS)).match(Select)) { + return OperationData(Instruction::ICmp, LHS, RHS, RK_Max); + } else if (m_OrdFMax(m_Value(LHS), m_Value(RHS)).match(Select) || + m_UnordFMax(m_Value(LHS), m_Value(RHS)).match(Select)) { + return OperationData( + Instruction::FCmp, LHS, RHS, RK_Max, + cast<Instruction>(Select->getCondition())->hasNoNaNs()); + } + } + return OperationData(V); + } + public: HorizontalReduction() = default; /// \brief Try to find a reduction tree. - bool matchAssociativeReduction(PHINode *Phi, BinaryOperator *B) { + bool matchAssociativeReduction(PHINode *Phi, Instruction *B) { assert((!Phi || is_contained(Phi->operands(), B)) && "Thi phi needs to use the binary operator"); + ReductionData = getOperationData(B); + // We could have a initial reductions that is not an add. // r *= v1 + v2 + v3 + v4 // In such a case start looking for a tree rooted in the first '+'. if (Phi) { - if (B->getOperand(0) == Phi) { + if (ReductionData.getLHS() == Phi) { Phi = nullptr; - B = dyn_cast<BinaryOperator>(B->getOperand(1)); - } else if (B->getOperand(1) == Phi) { + B = dyn_cast<Instruction>(ReductionData.getRHS()); + ReductionData = getOperationData(B); + } else if (ReductionData.getRHS() == Phi) { Phi = nullptr; - B = dyn_cast<BinaryOperator>(B->getOperand(0)); + B = dyn_cast<Instruction>(ReductionData.getLHS()); + ReductionData = getOperationData(B); } } - if (!B) + if (!ReductionData.isVectorizable(B)) return false; Type *Ty = B->getType(); if (!isValidElementType(Ty)) return false; - ReductionOpcode = B->getOpcode(); - ReducedValueOpcode = 0; + ReducedValueData.clear(); ReductionRoot = B; - // We currently only support adds. - if ((ReductionOpcode != Instruction::Add && - ReductionOpcode != Instruction::FAdd) || - !B->isAssociative()) - return false; - // Post order traverse the reduction tree starting at B. We only handle true - // trees containing only binary operators or selects. + // trees containing only binary operators. 
SmallVector<std::pair<Instruction *, unsigned>, 32> Stack; - Stack.push_back(std::make_pair(B, 0)); + Stack.push_back(std::make_pair(B, ReductionData.getFirstOperandIndex())); + ReductionData.initReductionOps(ReductionOps); while (!Stack.empty()) { Instruction *TreeN = Stack.back().first; unsigned EdgeToVist = Stack.back().second++; - bool IsReducedValue = TreeN->getOpcode() != ReductionOpcode; + OperationData OpData = getOperationData(TreeN); + bool IsReducedValue = OpData != ReductionData; // Postorder vist. - if (EdgeToVist == 2 || IsReducedValue) { + if (IsReducedValue || EdgeToVist == OpData.getNumberOfOperands()) { if (IsReducedValue) ReducedVals.push_back(TreeN); else { @@ -4415,7 +5309,7 @@ public: markExtraArg(Stack[Stack.size() - 2], TreeN); ExtraArgs.erase(TreeN); } else - ReductionOps.push_back(TreeN); + ReductionData.addReductionOps(TreeN, ReductionOps); } // Retract. Stack.pop_back(); @@ -4426,45 +5320,50 @@ public: Value *NextV = TreeN->getOperand(EdgeToVist); if (NextV != Phi) { auto *I = dyn_cast<Instruction>(NextV); + OpData = getOperationData(I); // Continue analysis if the next operand is a reduction operation or // (possibly) a reduced value. If the reduced value opcode is not set, // the first met operation != reduction operation is considered as the // reduced value class. - if (I && (!ReducedValueOpcode || I->getOpcode() == ReducedValueOpcode || - I->getOpcode() == ReductionOpcode)) { + if (I && (!ReducedValueData || OpData == ReducedValueData || + OpData == ReductionData)) { + const bool IsReductionOperation = OpData == ReductionData; // Only handle trees in the current basic block. - if (I->getParent() != B->getParent()) { + if (!ReductionData.hasSameParent(I, B->getParent(), + IsReductionOperation)) { // I is an extra argument for TreeN (its parent operation). markExtraArg(Stack.back(), I); continue; } - // Each tree node needs to have one user except for the ultimate - // reduction. - if (!I->hasOneUse() && I != B) { + // Each tree node needs to have minimal number of users except for the + // ultimate reduction. + if (!ReductionData.hasRequiredNumberOfUses(I, + OpData == ReductionData) && + I != B) { // I is an extra argument for TreeN (its parent operation). markExtraArg(Stack.back(), I); continue; } - if (I->getOpcode() == ReductionOpcode) { + if (IsReductionOperation) { // We need to be able to reassociate the reduction operations. - if (!I->isAssociative()) { + if (!OpData.isAssociative(I)) { // I is an extra argument for TreeN (its parent operation). markExtraArg(Stack.back(), I); continue; } - } else if (ReducedValueOpcode && - ReducedValueOpcode != I->getOpcode()) { + } else if (ReducedValueData && + ReducedValueData != OpData) { // Make sure that the opcodes of the operations that we are going to // reduce match. // I is an extra argument for TreeN (its parent operation). markExtraArg(Stack.back(), I); continue; - } else if (!ReducedValueOpcode) - ReducedValueOpcode = I->getOpcode(); + } else if (!ReducedValueData) + ReducedValueData = OpData; - Stack.push_back(std::make_pair(I, 0)); + Stack.push_back(std::make_pair(I, OpData.getFirstOperandIndex())); continue; } } @@ -4492,7 +5391,7 @@ public: Value *VectorizedTree = nullptr; IRBuilder<> Builder(ReductionRoot); FastMathFlags Unsafe; - Unsafe.setUnsafeAlgebra(); + Unsafe.setFast(); Builder.setFastMathFlags(Unsafe); unsigned i = 0; @@ -4501,12 +5400,15 @@ public: // to use it. 
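// matchAssociativeReduction above walks the reduction tree iteratively with a
// stack of (instruction, next-operand-index) pairs instead of recursing.  A
// stand-alone sketch of that traversal shape over a toy expression tree; Node
// is a stand-in, and leaves play the role of reduced values.
#include <iostream>
#include <utility>
#include <vector>

struct Node {                          // stand-in for an instruction
  char Name;
  std::vector<const Node *> Operands;  // empty for leaves (reduced values)
};

void postOrder(const Node *Root) {
  std::vector<std::pair<const Node *, unsigned>> Stack{{Root, 0u}};
  while (!Stack.empty()) {
    const Node *N = Stack.back().first;
    unsigned EdgeToVisit = Stack.back().second++;
    if (EdgeToVisit == N->Operands.size()) { // post-order visit, then retract
      std::cout << N->Name << ' ';
      Stack.pop_back();
      continue;
    }
    Stack.push_back({N->Operands[EdgeToVisit], 0u}); // descend into next operand
  }
  std::cout << '\n';
}

int main() {
  Node A{'a', {}}, B{'b', {}}, C{'c', {}};
  Node Add{'+', {&A, &B}};
  Node Mul{'*', {&Add, &C}};  // (a + b) * c
  postOrder(&Mul);            // prints: a b + c *
}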
for (auto &Pair : ExtraArgs) ExternallyUsedValues[Pair.second].push_back(Pair.first); + SmallVector<Value *, 16> IgnoreList; + for (auto &V : ReductionOps) + IgnoreList.append(V.begin(), V.end()); while (i < NumReducedVals - ReduxWidth + 1 && ReduxWidth > 2) { auto VL = makeArrayRef(&ReducedVals[i], ReduxWidth); - V.buildTree(VL, ExternallyUsedValues, ReductionOps); + V.buildTree(VL, ExternallyUsedValues, IgnoreList); if (V.shouldReorder()) { SmallVector<Value *, 8> Reversed(VL.rbegin(), VL.rend()); - V.buildTree(Reversed, ExternallyUsedValues, ReductionOps); + V.buildTree(Reversed, ExternallyUsedValues, IgnoreList); } if (V.isTreeTinyAndNotFullyVectorizable()) break; @@ -4516,17 +5418,27 @@ public: // Estimate cost. int Cost = V.getTreeCost() + getReductionCost(TTI, ReducedVals[i], ReduxWidth); - if (Cost >= -SLPCostThreshold) - break; + if (Cost >= -SLPCostThreshold) { + V.getORE()->emit([&]() { + return OptimizationRemarkMissed( + SV_NAME, "HorSLPNotBeneficial", cast<Instruction>(VL[0])) + << "Vectorizing horizontal reduction is possible" + << "but not beneficial with cost " + << ore::NV("Cost", Cost) << " and threshold " + << ore::NV("Threshold", -SLPCostThreshold); + }); + break; + } DEBUG(dbgs() << "SLP: Vectorizing horizontal reduction at cost:" << Cost << ". (HorRdx)\n"); - auto *I0 = cast<Instruction>(VL[0]); - V.getORE()->emit( - OptimizationRemark(SV_NAME, "VectorizedHorizontalReduction", I0) + V.getORE()->emit([&]() { + return OptimizationRemark( + SV_NAME, "VectorizedHorizontalReduction", cast<Instruction>(VL[0])) << "Vectorized horizontal reduction with cost " << ore::NV("Cost", Cost) << " and with tree size " - << ore::NV("TreeSize", V.getTreeSize())); + << ore::NV("TreeSize", V.getTreeSize()); + }); // Vectorize a tree. DebugLoc Loc = cast<Instruction>(ReducedVals[i])->getDebugLoc(); @@ -4534,12 +5446,14 @@ public: // Emit a reduction. Value *ReducedSubTree = - emitReduction(VectorizedRoot, Builder, ReduxWidth, ReductionOps, TTI); + emitReduction(VectorizedRoot, Builder, ReduxWidth, TTI); if (VectorizedTree) { Builder.SetCurrentDebugLocation(Loc); - VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree, - ReducedSubTree, "bin.rdx"); - propagateIRFlags(VectorizedTree, ReductionOps); + OperationData VectReductionData(ReductionData.getOpcode(), + VectorizedTree, ReducedSubTree, + ReductionData.getKind()); + VectorizedTree = + VectReductionData.createOp(Builder, "op.rdx", ReductionOps); } else VectorizedTree = ReducedSubTree; i += ReduxWidth; @@ -4551,9 +5465,10 @@ public: for (; i < NumReducedVals; ++i) { auto *I = cast<Instruction>(ReducedVals[i]); Builder.SetCurrentDebugLocation(I->getDebugLoc()); - VectorizedTree = - Builder.CreateBinOp(ReductionOpcode, VectorizedTree, I); - propagateIRFlags(VectorizedTree, ReductionOps); + OperationData VectReductionData(ReductionData.getOpcode(), + VectorizedTree, I, + ReductionData.getKind()); + VectorizedTree = VectReductionData.createOp(Builder, "", ReductionOps); } for (auto &Pair : ExternallyUsedValues) { assert(!Pair.second.empty() && @@ -4561,9 +5476,10 @@ public: // Add each externally used value to the final reduction. 
for (auto *I : Pair.second) { Builder.SetCurrentDebugLocation(I->getDebugLoc()); - VectorizedTree = Builder.CreateBinOp(ReductionOpcode, VectorizedTree, - Pair.first, "bin.extra"); - propagateIRFlags(VectorizedTree, I); + OperationData VectReductionData(ReductionData.getOpcode(), + VectorizedTree, Pair.first, + ReductionData.getKind()); + VectorizedTree = VectReductionData.createOp(Builder, "op.extra", I); } } // Update users. @@ -4583,15 +5499,58 @@ private: Type *ScalarTy = FirstReducedVal->getType(); Type *VecTy = VectorType::get(ScalarTy, ReduxWidth); - int PairwiseRdxCost = TTI->getReductionCost(ReductionOpcode, VecTy, true); - int SplittingRdxCost = TTI->getReductionCost(ReductionOpcode, VecTy, false); + int PairwiseRdxCost; + int SplittingRdxCost; + switch (ReductionData.getKind()) { + case RK_Arithmetic: + PairwiseRdxCost = + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/true); + SplittingRdxCost = + TTI->getArithmeticReductionCost(ReductionData.getOpcode(), VecTy, + /*IsPairwiseForm=*/false); + break; + case RK_Min: + case RK_Max: + case RK_UMin: + case RK_UMax: { + Type *VecCondTy = CmpInst::makeCmpResultType(VecTy); + bool IsUnsigned = ReductionData.getKind() == RK_UMin || + ReductionData.getKind() == RK_UMax; + PairwiseRdxCost = + TTI->getMinMaxReductionCost(VecTy, VecCondTy, + /*IsPairwiseForm=*/true, IsUnsigned); + SplittingRdxCost = + TTI->getMinMaxReductionCost(VecTy, VecCondTy, + /*IsPairwiseForm=*/false, IsUnsigned); + break; + } + case RK_None: + llvm_unreachable("Expected arithmetic or min/max reduction operation"); + } IsPairwiseReduction = PairwiseRdxCost < SplittingRdxCost; int VecReduxCost = IsPairwiseReduction ? PairwiseRdxCost : SplittingRdxCost; - int ScalarReduxCost = - (ReduxWidth - 1) * - TTI->getArithmeticInstrCost(ReductionOpcode, ScalarTy); + int ScalarReduxCost; + switch (ReductionData.getKind()) { + case RK_Arithmetic: + ScalarReduxCost = + TTI->getArithmeticInstrCost(ReductionData.getOpcode(), ScalarTy); + break; + case RK_Min: + case RK_Max: + case RK_UMin: + case RK_UMax: + ScalarReduxCost = + TTI->getCmpSelInstrCost(ReductionData.getOpcode(), ScalarTy) + + TTI->getCmpSelInstrCost(Instruction::Select, ScalarTy, + CmpInst::makeCmpResultType(ScalarTy)); + break; + case RK_None: + llvm_unreachable("Expected arithmetic or min/max reduction operation"); + } + ScalarReduxCost *= (ReduxWidth - 1); DEBUG(dbgs() << "SLP: Adding cost " << VecReduxCost - ScalarReduxCost << " for reduction that starts with " << *FirstReducedVal @@ -4604,16 +5563,15 @@ private: /// \brief Emit a horizontal reduction of the vectorized value. 
Value *emitReduction(Value *VectorizedValue, IRBuilder<> &Builder, - unsigned ReduxWidth, ArrayRef<Value *> RedOps, - const TargetTransformInfo *TTI) { + unsigned ReduxWidth, const TargetTransformInfo *TTI) { assert(VectorizedValue && "Need to have a vectorized tree node"); assert(isPowerOf2_32(ReduxWidth) && "We only handle power-of-two reductions for now"); if (!IsPairwiseReduction) return createSimpleTargetReduction( - Builder, TTI, ReductionOpcode, VectorizedValue, - TargetTransformInfo::ReductionFlags(), RedOps); + Builder, TTI, ReductionData.getOpcode(), VectorizedValue, + ReductionData.getFlags(), ReductionOps.back()); Value *TmpVec = VectorizedValue; for (unsigned i = ReduxWidth / 2; i != 0; i >>= 1) { @@ -4627,15 +5585,16 @@ private: Value *RightShuf = Builder.CreateShuffleVector( TmpVec, UndefValue::get(TmpVec->getType()), (RightMask), "rdx.shuf.r"); - TmpVec = - Builder.CreateBinOp(ReductionOpcode, LeftShuf, RightShuf, "bin.rdx"); - propagateIRFlags(TmpVec, RedOps); + OperationData VectReductionData(ReductionData.getOpcode(), LeftShuf, + RightShuf, ReductionData.getKind()); + TmpVec = VectReductionData.createOp(Builder, "op.rdx", ReductionOps); } // The result is in the first element of the vector. return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0)); } }; + } // end anonymous namespace /// \brief Recognize construction of vectors like @@ -4643,39 +5602,29 @@ private: /// %rb = insertelement <4 x float> %ra, float %s1, i32 1 /// %rc = insertelement <4 x float> %rb, float %s2, i32 2 /// %rd = insertelement <4 x float> %rc, float %s3, i32 3 +/// starting from the last insertelement instruction. /// /// Returns true if it matches -/// -static bool findBuildVector(InsertElementInst *FirstInsertElem, +static bool findBuildVector(InsertElementInst *LastInsertElem, SmallVectorImpl<Value *> &BuildVector, SmallVectorImpl<Value *> &BuildVectorOpds) { - if (!isa<UndefValue>(FirstInsertElem->getOperand(0))) - return false; - - InsertElementInst *IE = FirstInsertElem; - while (true) { - BuildVector.push_back(IE); - BuildVectorOpds.push_back(IE->getOperand(1)); - - if (IE->use_empty()) - return false; - - InsertElementInst *NextUse = dyn_cast<InsertElementInst>(IE->user_back()); - if (!NextUse) - return true; - - // If this isn't the final use, make sure the next insertelement is the only - // use. It's OK if the final constructed vector is used multiple times - if (!IE->hasOneUse()) + Value *V = nullptr; + do { + BuildVector.push_back(LastInsertElem); + BuildVectorOpds.push_back(LastInsertElem->getOperand(1)); + V = LastInsertElem->getOperand(0); + if (isa<UndefValue>(V)) + break; + LastInsertElem = dyn_cast<InsertElementInst>(V); + if (!LastInsertElem || !LastInsertElem->hasOneUse()) return false; - - IE = NextUse; - } - - return false; + } while (true); + std::reverse(BuildVector.begin(), BuildVector.end()); + std::reverse(BuildVectorOpds.begin(), BuildVectorOpds.end()); + return true; } -/// \brief Like findBuildVector, but looks backwards for construction of aggregate. +/// \brief Like findBuildVector, but looks for construction of aggregate. /// /// \return true if it matches. 
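// The non-pairwise path of emitReduction above keeps combining the upper and
// lower halves of the vector (via shufflevector) until a single lane is left.
// A scalar model of that "split in halves" tree on a plain array, using + as
// the combining operation; the patch generalizes the combine step to min/max
// through OperationData::createOp.  ReduxWidth must be a power of two, as the
// assert above requires.
#include <cassert>
#include <cstddef>
#include <iostream>
#include <vector>

int reduceBySplitting(std::vector<int> Lanes) {
  assert(!Lanes.empty() && (Lanes.size() & (Lanes.size() - 1)) == 0 &&
         "power-of-two number of lanes expected");
  for (std::size_t Width = Lanes.size() / 2; Width != 0; Width /= 2)
    for (std::size_t I = 0; I < Width; ++I)
      Lanes[I] += Lanes[I + Width];  // combine lane I with its shuffled partner
  return Lanes[0];                   // the result ends up in the first element
}

int main() {
  std::cout << reduceBySplitting({1, 2, 3, 4, 5, 6, 7, 8}) << '\n'; // 36
}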
static bool findBuildAggregate(InsertValueInst *IV, @@ -4767,14 +5716,14 @@ static Value *getReductionValue(const DominatorTree *DT, PHINode *P, static bool tryToVectorizeHorReductionOrInstOperands( PHINode *P, Instruction *Root, BasicBlock *BB, BoUpSLP &R, TargetTransformInfo *TTI, - const function_ref<bool(BinaryOperator *, BoUpSLP &)> Vectorize) { + const function_ref<bool(Instruction *, BoUpSLP &)> Vectorize) { if (!ShouldVectorizeHor) return false; if (!Root) return false; - if (Root->getParent() != BB) + if (Root->getParent() != BB || isa<PHINode>(Root)) return false; // Start analysis starting from Root instruction. If horizontal reduction is // found, try to vectorize it. If it is not a horizontal reduction or @@ -4795,11 +5744,13 @@ static bool tryToVectorizeHorReductionOrInstOperands( if (!V) continue; auto *Inst = dyn_cast<Instruction>(V); - if (!Inst || isa<PHINode>(Inst)) + if (!Inst) continue; - if (auto *BI = dyn_cast<BinaryOperator>(Inst)) { + auto *BI = dyn_cast<BinaryOperator>(Inst); + auto *SI = dyn_cast<SelectInst>(Inst); + if (BI || SI) { HorizontalReduction HorRdx; - if (HorRdx.matchAssociativeReduction(P, BI)) { + if (HorRdx.matchAssociativeReduction(P, Inst)) { if (HorRdx.tryToReduce(R, TTI)) { Res = true; // Set P to nullptr to avoid re-analysis of phi node in @@ -4808,7 +5759,7 @@ static bool tryToVectorizeHorReductionOrInstOperands( continue; } } - if (P) { + if (P && BI) { Inst = dyn_cast<Instruction>(BI->getOperand(0)); if (Inst == P) Inst = dyn_cast<Instruction>(BI->getOperand(1)); @@ -4823,15 +5774,20 @@ static bool tryToVectorizeHorReductionOrInstOperands( // Set P to nullptr to avoid re-analysis of phi node in // matchAssociativeReduction function unless this is the root node. P = nullptr; - if (Vectorize(dyn_cast<BinaryOperator>(Inst), R)) { + if (Vectorize(Inst, R)) { Res = true; continue; } // Try to vectorize operands. + // Continue analysis for the instruction from the same basic block only to + // save compile time. if (++Level < RecursionMaxDepth) for (auto *Op : Inst->operand_values()) - Stack.emplace_back(Op, Level); + if (VisitedInstrs.insert(Op).second) + if (auto *I = dyn_cast<Instruction>(Op)) + if (!isa<PHINode>(I) && I->getParent() == BB) + Stack.emplace_back(Op, Level); } return Res; } @@ -4848,10 +5804,71 @@ bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Value *V, if (!isa<BinaryOperator>(I)) P = nullptr; // Try to match and vectorize a horizontal reduction. - return tryToVectorizeHorReductionOrInstOperands( - P, I, BB, R, TTI, [this](BinaryOperator *BI, BoUpSLP &R) -> bool { - return tryToVectorize(BI, R); - }); + auto &&ExtraVectorization = [this](Instruction *I, BoUpSLP &R) -> bool { + return tryToVectorize(I, R); + }; + return tryToVectorizeHorReductionOrInstOperands(P, I, BB, R, TTI, + ExtraVectorization); +} + +bool SLPVectorizerPass::vectorizeInsertValueInst(InsertValueInst *IVI, + BasicBlock *BB, BoUpSLP &R) { + const DataLayout &DL = BB->getModule()->getDataLayout(); + if (!R.canMapToVector(IVI->getType(), DL)) + return false; + + SmallVector<Value *, 16> BuildVector; + SmallVector<Value *, 16> BuildVectorOpds; + if (!findBuildAggregate(IVI, BuildVector, BuildVectorOpds)) + return false; + + DEBUG(dbgs() << "SLP: array mappable to vector: " << *IVI << "\n"); + // Aggregate value is unlikely to be processed in vector register, we need to + // extract scalars into scalar registers, so NeedExtraction is set true. 
+ return tryToVectorizeList(BuildVectorOpds, R, BuildVector, false, true); +} + +bool SLPVectorizerPass::vectorizeInsertElementInst(InsertElementInst *IEI, + BasicBlock *BB, BoUpSLP &R) { + SmallVector<Value *, 16> BuildVector; + SmallVector<Value *, 16> BuildVectorOpds; + if (!findBuildVector(IEI, BuildVector, BuildVectorOpds)) + return false; + + // Vectorize starting with the build vector operands ignoring the BuildVector + // instructions for the purpose of scheduling and user extraction. + return tryToVectorizeList(BuildVectorOpds, R, BuildVector); +} + +bool SLPVectorizerPass::vectorizeCmpInst(CmpInst *CI, BasicBlock *BB, + BoUpSLP &R) { + if (tryToVectorizePair(CI->getOperand(0), CI->getOperand(1), R)) + return true; + + bool OpsChanged = false; + for (int Idx = 0; Idx < 2; ++Idx) { + OpsChanged |= + vectorizeRootInstruction(nullptr, CI->getOperand(Idx), BB, R, TTI); + } + return OpsChanged; +} + +bool SLPVectorizerPass::vectorizeSimpleInstructions( + SmallVectorImpl<WeakVH> &Instructions, BasicBlock *BB, BoUpSLP &R) { + bool OpsChanged = false; + for (auto &VH : reverse(Instructions)) { + auto *I = dyn_cast_or_null<Instruction>(VH); + if (!I) + continue; + if (auto *LastInsertValue = dyn_cast<InsertValueInst>(I)) + OpsChanged |= vectorizeInsertValueInst(LastInsertValue, BB, R); + else if (auto *LastInsertElem = dyn_cast<InsertElementInst>(I)) + OpsChanged |= vectorizeInsertElementInst(LastInsertElem, BB, R); + else if (auto *CI = dyn_cast<CmpInst>(I)) + OpsChanged |= vectorizeCmpInst(CI, BB, R); + } + Instructions.clear(); + return OpsChanged; } bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { @@ -4913,10 +5930,21 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { VisitedInstrs.clear(); + SmallVector<WeakVH, 8> PostProcessInstructions; + SmallDenseSet<Instruction *, 4> KeyNodes; for (BasicBlock::iterator it = BB->begin(), e = BB->end(); it != e; it++) { // We may go through BB multiple times so skip the one we have checked. - if (!VisitedInstrs.insert(&*it).second) + if (!VisitedInstrs.insert(&*it).second) { + if (it->use_empty() && KeyNodes.count(&*it) > 0 && + vectorizeSimpleInstructions(PostProcessInstructions, BB, R)) { + // We would like to start over since some instructions are deleted + // and the iterator may become invalid value. + Changed = true; + it = BB->begin(); + e = BB->end(); + } continue; + } if (isa<DbgInfoIntrinsic>(it)) continue; @@ -4938,96 +5966,37 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { continue; } - if (ShouldStartVectorizeHorAtStore) { - if (StoreInst *SI = dyn_cast<StoreInst>(it)) { - // Try to match and vectorize a horizontal reduction. - if (vectorizeRootInstruction(nullptr, SI->getValueOperand(), BB, R, - TTI)) { - Changed = true; - it = BB->begin(); - e = BB->end(); - continue; + // Ran into an instruction without users, like terminator, or function call + // with ignored return value, store. Ignore unused instructions (basing on + // instruction type, except for CallInst and InvokeInst). + if (it->use_empty() && (it->getType()->isVoidTy() || isa<CallInst>(it) || + isa<InvokeInst>(it))) { + KeyNodes.insert(&*it); + bool OpsChanged = false; + if (ShouldStartVectorizeHorAtStore || !isa<StoreInst>(it)) { + for (auto *V : it->operand_values()) { + // Try to match and vectorize a horizontal reduction. + OpsChanged |= vectorizeRootInstruction(nullptr, V, BB, R, TTI); } } - } - - // Try to vectorize horizontal reductions feeding into a return. 
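// vectorizeChainsInBlock above now defers insertelement/insertvalue/cmp
// instructions into PostProcessInstructions and only tries to vectorize them
// once a "key node" (an instruction without users, such as a store or
// terminator) is reached, walking the deferred list in reverse so the top of
// each build-vector chain is tried first.  A stand-alone sketch of that
// deferral pattern; Instr and its fields are simplified stand-ins.
#include <iostream>
#include <string>
#include <vector>

struct Instr {           // stand-in for an instruction in a basic block
  std::string Name;
  bool HasUsers;         // key nodes are instructions without users
  bool IsCandidate;      // e.g. insertelement, insertvalue, cmp
};

void processBlock(const std::vector<Instr> &Block) {
  std::vector<const Instr *> PostProcess;
  for (const Instr &I : Block) {
    if (I.IsCandidate)
      PostProcess.push_back(&I);  // defer; do not try to vectorize yet
    if (!I.HasUsers) {            // reached a key node: flush in reverse order
      for (auto It = PostProcess.rbegin(); It != PostProcess.rend(); ++It)
        std::cout << "try to vectorize " << (*It)->Name << '\n';
      PostProcess.clear();
    }
  }
}

int main() {
  processBlock({{"%ie0", true, true},
                {"%ie1", true, true},
                {"store", false, false}}); // flushes %ie1, then %ie0
}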
- if (ReturnInst *RI = dyn_cast<ReturnInst>(it)) { - if (RI->getNumOperands() != 0) { - // Try to match and vectorize a horizontal reduction. - if (vectorizeRootInstruction(nullptr, RI->getOperand(0), BB, R, TTI)) { - Changed = true; - it = BB->begin(); - e = BB->end(); - continue; - } - } - } - - // Try to vectorize trees that start at compare instructions. - if (CmpInst *CI = dyn_cast<CmpInst>(it)) { - if (tryToVectorizePair(CI->getOperand(0), CI->getOperand(1), R)) { - Changed = true; + // Start vectorization of post-process list of instructions from the + // top-tree instructions to try to vectorize as many instructions as + // possible. + OpsChanged |= vectorizeSimpleInstructions(PostProcessInstructions, BB, R); + if (OpsChanged) { // We would like to start over since some instructions are deleted // and the iterator may become invalid value. - it = BB->begin(); - e = BB->end(); - continue; - } - - for (int I = 0; I < 2; ++I) { - if (vectorizeRootInstruction(nullptr, CI->getOperand(I), BB, R, TTI)) { - Changed = true; - // We would like to start over since some instructions are deleted - // and the iterator may become invalid value. - it = BB->begin(); - e = BB->end(); - break; - } - } - continue; - } - - // Try to vectorize trees that start at insertelement instructions. - if (InsertElementInst *FirstInsertElem = dyn_cast<InsertElementInst>(it)) { - SmallVector<Value *, 16> BuildVector; - SmallVector<Value *, 16> BuildVectorOpds; - if (!findBuildVector(FirstInsertElem, BuildVector, BuildVectorOpds)) - continue; - - // Vectorize starting with the build vector operands ignoring the - // BuildVector instructions for the purpose of scheduling and user - // extraction. - if (tryToVectorizeList(BuildVectorOpds, R, BuildVector)) { Changed = true; it = BB->begin(); e = BB->end(); + continue; } - - continue; } - // Try to vectorize trees that start at insertvalue instructions feeding into - // a store. - if (StoreInst *SI = dyn_cast<StoreInst>(it)) { - if (InsertValueInst *LastInsertValue = dyn_cast<InsertValueInst>(SI->getValueOperand())) { - const DataLayout &DL = BB->getModule()->getDataLayout(); - if (R.canMapToVector(SI->getValueOperand()->getType(), DL)) { - SmallVector<Value *, 16> BuildVector; - SmallVector<Value *, 16> BuildVectorOpds; - if (!findBuildAggregate(LastInsertValue, BuildVector, BuildVectorOpds)) - continue; + if (isa<InsertElementInst>(it) || isa<CmpInst>(it) || + isa<InsertValueInst>(it)) + PostProcessInstructions.push_back(&*it); - DEBUG(dbgs() << "SLP: store of array mappable to vector: " << *SI << "\n"); - if (tryToVectorizeList(BuildVectorOpds, R, BuildVector, false)) { - Changed = true; - it = BB->begin(); - e = BB->end(); - } - continue; - } - } - } } return Changed; @@ -5036,7 +6005,6 @@ bool SLPVectorizerPass::vectorizeChainsInBlock(BasicBlock *BB, BoUpSLP &R) { bool SLPVectorizerPass::vectorizeGEPIndices(BasicBlock *BB, BoUpSLP &R) { auto Changed = false; for (auto &Entry : GEPs) { - // If the getelementptr list has fewer than two elements, there's nothing // to do. 
if (Entry.second.size() < 2) @@ -5141,7 +6109,9 @@ bool SLPVectorizerPass::vectorizeStoreChains(BoUpSLP &R) { } char SLPVectorizer::ID = 0; + static const char lv_name[] = "SLP Vectorizer"; + INITIALIZE_PASS_BEGIN(SLPVectorizer, SV_NAME, lv_name, false, false) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) @@ -5152,6 +6122,4 @@ INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass) INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) INITIALIZE_PASS_END(SLPVectorizer, SV_NAME, lv_name, false, false) -namespace llvm { -Pass *createSLPVectorizerPass() { return new SLPVectorizer(); } -} +Pass *llvm::createSLPVectorizerPass() { return new SLPVectorizer(); } diff --git a/lib/Transforms/Vectorize/VPlan.cpp b/lib/Transforms/Vectorize/VPlan.cpp new file mode 100644 index 000000000000..4e54fc6db2a5 --- /dev/null +++ b/lib/Transforms/Vectorize/VPlan.cpp @@ -0,0 +1,557 @@ +//===- VPlan.cpp - Vectorizer Plan ----------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This is the LLVM vectorization plan. It represents a candidate for +/// vectorization, allowing to plan and optimize how to vectorize a given loop +/// before generating LLVM-IR. +/// The vectorizer uses vectorization plans to estimate the costs of potential +/// candidates and if profitable to execute the desired plan, generating vector +/// LLVM-IR code. +/// +//===----------------------------------------------------------------------===// + +#include "VPlan.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/GraphWriter.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include <cassert> +#include <iterator> +#include <string> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "vplan" + +raw_ostream &llvm::operator<<(raw_ostream &OS, const VPValue &V) { + if (const VPInstruction *Instr = dyn_cast<VPInstruction>(&V)) + Instr->print(OS); + else + V.printAsOperand(OS); + return OS; +} + +/// \return the VPBasicBlock that is the entry of Block, possibly indirectly. +const VPBasicBlock *VPBlockBase::getEntryBasicBlock() const { + const VPBlockBase *Block = this; + while (const VPRegionBlock *Region = dyn_cast<VPRegionBlock>(Block)) + Block = Region->getEntry(); + return cast<VPBasicBlock>(Block); +} + +VPBasicBlock *VPBlockBase::getEntryBasicBlock() { + VPBlockBase *Block = this; + while (VPRegionBlock *Region = dyn_cast<VPRegionBlock>(Block)) + Block = Region->getEntry(); + return cast<VPBasicBlock>(Block); +} + +/// \return the VPBasicBlock that is the exit of Block, possibly indirectly. 
+const VPBasicBlock *VPBlockBase::getExitBasicBlock() const { + const VPBlockBase *Block = this; + while (const VPRegionBlock *Region = dyn_cast<VPRegionBlock>(Block)) + Block = Region->getExit(); + return cast<VPBasicBlock>(Block); +} + +VPBasicBlock *VPBlockBase::getExitBasicBlock() { + VPBlockBase *Block = this; + while (VPRegionBlock *Region = dyn_cast<VPRegionBlock>(Block)) + Block = Region->getExit(); + return cast<VPBasicBlock>(Block); +} + +VPBlockBase *VPBlockBase::getEnclosingBlockWithSuccessors() { + if (!Successors.empty() || !Parent) + return this; + assert(Parent->getExit() == this && + "Block w/o successors not the exit of its parent."); + return Parent->getEnclosingBlockWithSuccessors(); +} + +VPBlockBase *VPBlockBase::getEnclosingBlockWithPredecessors() { + if (!Predecessors.empty() || !Parent) + return this; + assert(Parent->getEntry() == this && + "Block w/o predecessors not the entry of its parent."); + return Parent->getEnclosingBlockWithPredecessors(); +} + +void VPBlockBase::deleteCFG(VPBlockBase *Entry) { + SmallVector<VPBlockBase *, 8> Blocks; + for (VPBlockBase *Block : depth_first(Entry)) + Blocks.push_back(Block); + + for (VPBlockBase *Block : Blocks) + delete Block; +} + +BasicBlock * +VPBasicBlock::createEmptyBasicBlock(VPTransformState::CFGState &CFG) { + // BB stands for IR BasicBlocks. VPBB stands for VPlan VPBasicBlocks. + // Pred stands for Predessor. Prev stands for Previous - last visited/created. + BasicBlock *PrevBB = CFG.PrevBB; + BasicBlock *NewBB = BasicBlock::Create(PrevBB->getContext(), getName(), + PrevBB->getParent(), CFG.LastBB); + DEBUG(dbgs() << "LV: created " << NewBB->getName() << '\n'); + + // Hook up the new basic block to its predecessors. + for (VPBlockBase *PredVPBlock : getHierarchicalPredecessors()) { + VPBasicBlock *PredVPBB = PredVPBlock->getExitBasicBlock(); + auto &PredVPSuccessors = PredVPBB->getSuccessors(); + BasicBlock *PredBB = CFG.VPBB2IRBB[PredVPBB]; + assert(PredBB && "Predecessor basic-block not found building successor."); + auto *PredBBTerminator = PredBB->getTerminator(); + DEBUG(dbgs() << "LV: draw edge from" << PredBB->getName() << '\n'); + if (isa<UnreachableInst>(PredBBTerminator)) { + assert(PredVPSuccessors.size() == 1 && + "Predecessor ending w/o branch must have single successor."); + PredBBTerminator->eraseFromParent(); + BranchInst::Create(NewBB, PredBB); + } else { + assert(PredVPSuccessors.size() == 2 && + "Predecessor ending with branch must have two successors."); + unsigned idx = PredVPSuccessors.front() == this ? 0 : 1; + assert(!PredBBTerminator->getSuccessor(idx) && + "Trying to reset an existing successor block."); + PredBBTerminator->setSuccessor(idx, NewBB); + } + } + return NewBB; +} + +void VPBasicBlock::execute(VPTransformState *State) { + bool Replica = State->Instance && + !(State->Instance->Part == 0 && State->Instance->Lane == 0); + VPBasicBlock *PrevVPBB = State->CFG.PrevVPBB; + VPBlockBase *SingleHPred = nullptr; + BasicBlock *NewBB = State->CFG.PrevBB; // Reuse it if possible. + + // 1. Create an IR basic block, or reuse the last one if possible. + // The last IR basic block is reused, as an optimization, in three cases: + // A. the first VPBB reuses the loop header BB - when PrevVPBB is null; + // B. when the current VPBB has a single (hierarchical) predecessor which + // is PrevVPBB and the latter has a single (hierarchical) successor; and + // C. 
when the current VPBB is an entry of a region replica - where PrevVPBB + // is the exit of this region from a previous instance, or the predecessor + // of this region. + if (PrevVPBB && /* A */ + !((SingleHPred = getSingleHierarchicalPredecessor()) && + SingleHPred->getExitBasicBlock() == PrevVPBB && + PrevVPBB->getSingleHierarchicalSuccessor()) && /* B */ + !(Replica && getPredecessors().empty())) { /* C */ + NewBB = createEmptyBasicBlock(State->CFG); + State->Builder.SetInsertPoint(NewBB); + // Temporarily terminate with unreachable until CFG is rewired. + UnreachableInst *Terminator = State->Builder.CreateUnreachable(); + State->Builder.SetInsertPoint(Terminator); + // Register NewBB in its loop. In innermost loops its the same for all BB's. + Loop *L = State->LI->getLoopFor(State->CFG.LastBB); + L->addBasicBlockToLoop(NewBB, *State->LI); + State->CFG.PrevBB = NewBB; + } + + // 2. Fill the IR basic block with IR instructions. + DEBUG(dbgs() << "LV: vectorizing VPBB:" << getName() + << " in BB:" << NewBB->getName() << '\n'); + + State->CFG.VPBB2IRBB[this] = NewBB; + State->CFG.PrevVPBB = this; + + for (VPRecipeBase &Recipe : Recipes) + Recipe.execute(*State); + + DEBUG(dbgs() << "LV: filled BB:" << *NewBB); +} + +void VPRegionBlock::execute(VPTransformState *State) { + ReversePostOrderTraversal<VPBlockBase *> RPOT(Entry); + + if (!isReplicator()) { + // Visit the VPBlocks connected to "this", starting from it. + for (VPBlockBase *Block : RPOT) { + DEBUG(dbgs() << "LV: VPBlock in RPO " << Block->getName() << '\n'); + Block->execute(State); + } + return; + } + + assert(!State->Instance && "Replicating a Region with non-null instance."); + + // Enter replicating mode. + State->Instance = {0, 0}; + + for (unsigned Part = 0, UF = State->UF; Part < UF; ++Part) { + State->Instance->Part = Part; + for (unsigned Lane = 0, VF = State->VF; Lane < VF; ++Lane) { + State->Instance->Lane = Lane; + // Visit the VPBlocks connected to \p this, starting from it. + for (VPBlockBase *Block : RPOT) { + DEBUG(dbgs() << "LV: VPBlock in RPO " << Block->getName() << '\n'); + Block->execute(State); + } + } + } + + // Exit replicating mode. 
+ State->Instance.reset(); +} + +void VPInstruction::generateInstruction(VPTransformState &State, + unsigned Part) { + IRBuilder<> &Builder = State.Builder; + + if (Instruction::isBinaryOp(getOpcode())) { + Value *A = State.get(getOperand(0), Part); + Value *B = State.get(getOperand(1), Part); + Value *V = Builder.CreateBinOp((Instruction::BinaryOps)getOpcode(), A, B); + State.set(this, V, Part); + return; + } + + switch (getOpcode()) { + case VPInstruction::Not: { + Value *A = State.get(getOperand(0), Part); + Value *V = Builder.CreateNot(A); + State.set(this, V, Part); + break; + } + default: + llvm_unreachable("Unsupported opcode for instruction"); + } +} + +void VPInstruction::execute(VPTransformState &State) { + assert(!State.Instance && "VPInstruction executing an Instance"); + for (unsigned Part = 0; Part < State.UF; ++Part) + generateInstruction(State, Part); +} + +void VPInstruction::print(raw_ostream &O, const Twine &Indent) const { + O << " +\n" << Indent << "\"EMIT "; + print(O); + O << "\\l\""; +} + +void VPInstruction::print(raw_ostream &O) const { + printAsOperand(O); + O << " = "; + + switch (getOpcode()) { + case VPInstruction::Not: + O << "not"; + break; + default: + O << Instruction::getOpcodeName(getOpcode()); + } + + for (const VPValue *Operand : operands()) { + O << " "; + Operand->printAsOperand(O); + } +} + +/// Generate the code inside the body of the vectorized loop. Assumes a single +/// LoopVectorBody basic-block was created for this. Introduce additional +/// basic-blocks as needed, and fill them all. +void VPlan::execute(VPTransformState *State) { + // 0. Set the reverse mapping from VPValues to Values for code generation. + for (auto &Entry : Value2VPValue) + State->VPValue2Value[Entry.second] = Entry.first; + + BasicBlock *VectorPreHeaderBB = State->CFG.PrevBB; + BasicBlock *VectorHeaderBB = VectorPreHeaderBB->getSingleSuccessor(); + assert(VectorHeaderBB && "Loop preheader does not have a single successor."); + BasicBlock *VectorLatchBB = VectorHeaderBB; + + // 1. Make room to generate basic-blocks inside loop body if needed. + VectorLatchBB = VectorHeaderBB->splitBasicBlock( + VectorHeaderBB->getFirstInsertionPt(), "vector.body.latch"); + Loop *L = State->LI->getLoopFor(VectorHeaderBB); + L->addBasicBlockToLoop(VectorLatchBB, *State->LI); + // Remove the edge between Header and Latch to allow other connections. + // Temporarily terminate with unreachable until CFG is rewired. + // Note: this asserts the generated code's assumption that + // getFirstInsertionPt() can be dereferenced into an Instruction. + VectorHeaderBB->getTerminator()->eraseFromParent(); + State->Builder.SetInsertPoint(VectorHeaderBB); + UnreachableInst *Terminator = State->Builder.CreateUnreachable(); + State->Builder.SetInsertPoint(Terminator); + + // 2. Generate code in loop body. + State->CFG.PrevVPBB = nullptr; + State->CFG.PrevBB = VectorHeaderBB; + State->CFG.LastBB = VectorLatchBB; + + for (VPBlockBase *Block : depth_first(Entry)) + Block->execute(State); + + // 3. Merge the temporary latch created with the last basic-block filled. + BasicBlock *LastBB = State->CFG.PrevBB; + // Connect LastBB to VectorLatchBB to facilitate their merge. + assert(isa<UnreachableInst>(LastBB->getTerminator()) && + "Expected VPlan CFG to terminate with unreachable"); + LastBB->getTerminator()->eraseFromParent(); + BranchInst::Create(VectorLatchBB, LastBB); + + // Merge LastBB with Latch. 
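// VPRegionBlock::execute above either visits its blocks once (non-replicating
// regions) or re-runs them for every (Part, Lane) instance, i.e. UF * VF times
// in total.  A stand-alone sketch of that iteration structure; the callback
// stands in for the RPO walk over the region's blocks, and UF/VF are the
// unroll and vectorization factors.
#include <functional>
#include <iostream>

void executeRegion(bool IsReplicator, unsigned UF, unsigned VF,
                   const std::function<void(unsigned, unsigned)> &RunBlocks) {
  if (!IsReplicator) {
    RunBlocks(0, 0);  // visit the region's blocks a single time
    return;
  }
  for (unsigned Part = 0; Part < UF; ++Part)    // replicating mode
    for (unsigned Lane = 0; Lane < VF; ++Lane)
      RunBlocks(Part, Lane);                    // one scalar instance per lane
}

int main() {
  executeRegion(/*IsReplicator=*/true, /*UF=*/2, /*VF=*/4,
                [](unsigned Part, unsigned Lane) {
                  std::cout << "instance (" << Part << ", " << Lane << ")\n";
                });
}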
+ bool Merged = MergeBlockIntoPredecessor(VectorLatchBB, nullptr, State->LI); + (void)Merged; + assert(Merged && "Could not merge last basic block with latch."); + VectorLatchBB = LastBB; + + updateDominatorTree(State->DT, VectorPreHeaderBB, VectorLatchBB); +} + +void VPlan::updateDominatorTree(DominatorTree *DT, BasicBlock *LoopPreHeaderBB, + BasicBlock *LoopLatchBB) { + BasicBlock *LoopHeaderBB = LoopPreHeaderBB->getSingleSuccessor(); + assert(LoopHeaderBB && "Loop preheader does not have a single successor."); + DT->addNewBlock(LoopHeaderBB, LoopPreHeaderBB); + // The vector body may be more than a single basic-block by this point. + // Update the dominator tree information inside the vector body by propagating + // it from header to latch, expecting only triangular control-flow, if any. + BasicBlock *PostDomSucc = nullptr; + for (auto *BB = LoopHeaderBB; BB != LoopLatchBB; BB = PostDomSucc) { + // Get the list of successors of this block. + std::vector<BasicBlock *> Succs(succ_begin(BB), succ_end(BB)); + assert(Succs.size() <= 2 && + "Basic block in vector loop has more than 2 successors."); + PostDomSucc = Succs[0]; + if (Succs.size() == 1) { + assert(PostDomSucc->getSinglePredecessor() && + "PostDom successor has more than one predecessor."); + DT->addNewBlock(PostDomSucc, BB); + continue; + } + BasicBlock *InterimSucc = Succs[1]; + if (PostDomSucc->getSingleSuccessor() == InterimSucc) { + PostDomSucc = Succs[1]; + InterimSucc = Succs[0]; + } + assert(InterimSucc->getSingleSuccessor() == PostDomSucc && + "One successor of a basic block does not lead to the other."); + assert(InterimSucc->getSinglePredecessor() && + "Interim successor has more than one predecessor."); + assert(std::distance(pred_begin(PostDomSucc), pred_end(PostDomSucc)) == 2 && + "PostDom successor has more than two predecessors."); + DT->addNewBlock(InterimSucc, BB); + DT->addNewBlock(PostDomSucc, BB); + } +} + +const Twine VPlanPrinter::getUID(const VPBlockBase *Block) { + return (isa<VPRegionBlock>(Block) ? 
"cluster_N" : "N") + + Twine(getOrCreateBID(Block)); +} + +const Twine VPlanPrinter::getOrCreateName(const VPBlockBase *Block) { + const std::string &Name = Block->getName(); + if (!Name.empty()) + return Name; + return "VPB" + Twine(getOrCreateBID(Block)); +} + +void VPlanPrinter::dump() { + Depth = 1; + bumpIndent(0); + OS << "digraph VPlan {\n"; + OS << "graph [labelloc=t, fontsize=30; label=\"Vectorization Plan"; + if (!Plan.getName().empty()) + OS << "\\n" << DOT::EscapeString(Plan.getName()); + if (!Plan.Value2VPValue.empty()) { + OS << ", where:"; + for (auto Entry : Plan.Value2VPValue) { + OS << "\\n" << *Entry.second; + OS << DOT::EscapeString(" := "); + Entry.first->printAsOperand(OS, false); + } + } + OS << "\"]\n"; + OS << "node [shape=rect, fontname=Courier, fontsize=30]\n"; + OS << "edge [fontname=Courier, fontsize=30]\n"; + OS << "compound=true\n"; + + for (VPBlockBase *Block : depth_first(Plan.getEntry())) + dumpBlock(Block); + + OS << "}\n"; +} + +void VPlanPrinter::dumpBlock(const VPBlockBase *Block) { + if (const VPBasicBlock *BasicBlock = dyn_cast<VPBasicBlock>(Block)) + dumpBasicBlock(BasicBlock); + else if (const VPRegionBlock *Region = dyn_cast<VPRegionBlock>(Block)) + dumpRegion(Region); + else + llvm_unreachable("Unsupported kind of VPBlock."); +} + +void VPlanPrinter::drawEdge(const VPBlockBase *From, const VPBlockBase *To, + bool Hidden, const Twine &Label) { + // Due to "dot" we print an edge between two regions as an edge between the + // exit basic block and the entry basic of the respective regions. + const VPBlockBase *Tail = From->getExitBasicBlock(); + const VPBlockBase *Head = To->getEntryBasicBlock(); + OS << Indent << getUID(Tail) << " -> " << getUID(Head); + OS << " [ label=\"" << Label << '\"'; + if (Tail != From) + OS << " ltail=" << getUID(From); + if (Head != To) + OS << " lhead=" << getUID(To); + if (Hidden) + OS << "; splines=none"; + OS << "]\n"; +} + +void VPlanPrinter::dumpEdges(const VPBlockBase *Block) { + auto &Successors = Block->getSuccessors(); + if (Successors.size() == 1) + drawEdge(Block, Successors.front(), false, ""); + else if (Successors.size() == 2) { + drawEdge(Block, Successors.front(), false, "T"); + drawEdge(Block, Successors.back(), false, "F"); + } else { + unsigned SuccessorNumber = 0; + for (auto *Successor : Successors) + drawEdge(Block, Successor, false, Twine(SuccessorNumber++)); + } +} + +void VPlanPrinter::dumpBasicBlock(const VPBasicBlock *BasicBlock) { + OS << Indent << getUID(BasicBlock) << " [label =\n"; + bumpIndent(1); + OS << Indent << "\"" << DOT::EscapeString(BasicBlock->getName()) << ":\\n\""; + bumpIndent(1); + for (const VPRecipeBase &Recipe : *BasicBlock) + Recipe.print(OS, Indent); + bumpIndent(-2); + OS << "\n" << Indent << "]\n"; + dumpEdges(BasicBlock); +} + +void VPlanPrinter::dumpRegion(const VPRegionBlock *Region) { + OS << Indent << "subgraph " << getUID(Region) << " {\n"; + bumpIndent(1); + OS << Indent << "fontname=Courier\n" + << Indent << "label=\"" + << DOT::EscapeString(Region->isReplicator() ? "<xVFxUF> " : "<x1> ") + << DOT::EscapeString(Region->getName()) << "\"\n"; + // Dump the blocks of the region. 
+ assert(Region->getEntry() && "Region contains no inner blocks."); + for (const VPBlockBase *Block : depth_first(Region->getEntry())) + dumpBlock(Block); + bumpIndent(-1); + OS << Indent << "}\n"; + dumpEdges(Region); +} + +void VPlanPrinter::printAsIngredient(raw_ostream &O, Value *V) { + std::string IngredientString; + raw_string_ostream RSO(IngredientString); + if (auto *Inst = dyn_cast<Instruction>(V)) { + if (!Inst->getType()->isVoidTy()) { + Inst->printAsOperand(RSO, false); + RSO << " = "; + } + RSO << Inst->getOpcodeName() << " "; + unsigned E = Inst->getNumOperands(); + if (E > 0) { + Inst->getOperand(0)->printAsOperand(RSO, false); + for (unsigned I = 1; I < E; ++I) + Inst->getOperand(I)->printAsOperand(RSO << ", ", false); + } + } else // !Inst + V->printAsOperand(RSO, false); + RSO.flush(); + O << DOT::EscapeString(IngredientString); +} + +void VPWidenRecipe::print(raw_ostream &O, const Twine &Indent) const { + O << " +\n" << Indent << "\"WIDEN\\l\""; + for (auto &Instr : make_range(Begin, End)) + O << " +\n" << Indent << "\" " << VPlanIngredient(&Instr) << "\\l\""; +} + +void VPWidenIntOrFpInductionRecipe::print(raw_ostream &O, + const Twine &Indent) const { + O << " +\n" << Indent << "\"WIDEN-INDUCTION"; + if (Trunc) { + O << "\\l\""; + O << " +\n" << Indent << "\" " << VPlanIngredient(IV) << "\\l\""; + O << " +\n" << Indent << "\" " << VPlanIngredient(Trunc) << "\\l\""; + } else + O << " " << VPlanIngredient(IV) << "\\l\""; +} + +void VPWidenPHIRecipe::print(raw_ostream &O, const Twine &Indent) const { + O << " +\n" << Indent << "\"WIDEN-PHI " << VPlanIngredient(Phi) << "\\l\""; +} + +void VPBlendRecipe::print(raw_ostream &O, const Twine &Indent) const { + O << " +\n" << Indent << "\"BLEND "; + Phi->printAsOperand(O, false); + O << " ="; + if (!User) { + // Not a User of any mask: not really blending, this is a + // single-predecessor phi. + O << " "; + Phi->getIncomingValue(0)->printAsOperand(O, false); + } else { + for (unsigned I = 0, E = User->getNumOperands(); I < E; ++I) { + O << " "; + Phi->getIncomingValue(I)->printAsOperand(O, false); + O << "/"; + User->getOperand(I)->printAsOperand(O); + } + } + O << "\\l\""; +} + +void VPReplicateRecipe::print(raw_ostream &O, const Twine &Indent) const { + O << " +\n" + << Indent << "\"" << (IsUniform ? "CLONE " : "REPLICATE ") + << VPlanIngredient(Ingredient); + if (AlsoPack) + O << " (S->V)"; + O << "\\l\""; +} + +void VPPredInstPHIRecipe::print(raw_ostream &O, const Twine &Indent) const { + O << " +\n" + << Indent << "\"PHI-PREDICATED-INSTRUCTION " << VPlanIngredient(PredInst) + << "\\l\""; +} + +void VPWidenMemoryInstructionRecipe::print(raw_ostream &O, + const Twine &Indent) const { + O << " +\n" << Indent << "\"WIDEN " << VPlanIngredient(&Instr); + if (User) { + O << ", "; + User->getOperand(0)->printAsOperand(O); + } + O << "\\l\""; +} diff --git a/lib/Transforms/Vectorize/VPlan.h b/lib/Transforms/Vectorize/VPlan.h new file mode 100644 index 000000000000..2ccabfd6af25 --- /dev/null +++ b/lib/Transforms/Vectorize/VPlan.h @@ -0,0 +1,1194 @@ +//===- VPlan.h - Represent A Vectorizer Plan --------------------*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +/// \file +/// This file contains the declarations of the Vectorization Plan base classes: +/// 1. 
VPBasicBlock and VPRegionBlock that inherit from a common pure virtual +/// VPBlockBase, together implementing a Hierarchical CFG; +/// 2. Specializations of GraphTraits that allow VPBlockBase graphs to be +/// treated as proper graphs for generic algorithms; +/// 3. Pure virtual VPRecipeBase serving as the base class for recipes contained +/// within VPBasicBlocks; +/// 4. VPInstruction, a concrete Recipe and VPUser modeling a single planned +/// instruction; +/// 5. The VPlan class holding a candidate for vectorization; +/// 6. The VPlanPrinter class providing a way to print a plan in dot format; +/// These are documented in docs/VectorizationPlan.rst. +// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLAN_H +#define LLVM_TRANSFORMS_VECTORIZE_VPLAN_H + +#include "VPlanValue.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/ilist.h" +#include "llvm/ADT/ilist_node.h" +#include "llvm/IR/IRBuilder.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <map> +#include <string> + +// The (re)use of existing LoopVectorize classes is subject to future VPlan +// refactoring. +namespace { +class LoopVectorizationLegality; +class LoopVectorizationCostModel; +} // namespace + +namespace llvm { + +class BasicBlock; +class DominatorTree; +class InnerLoopVectorizer; +class InterleaveGroup; +class LoopInfo; +class raw_ostream; +class Value; +class VPBasicBlock; +class VPRegionBlock; + +/// In what follows, the term "input IR" refers to code that is fed into the +/// vectorizer whereas the term "output IR" refers to code that is generated by +/// the vectorizer. + +/// VPIteration represents a single point in the iteration space of the output +/// (vectorized and/or unrolled) IR loop. +struct VPIteration { + /// in [0..UF) + unsigned Part; + + /// in [0..VF) + unsigned Lane; +}; + +/// This is a helper struct for maintaining vectorization state. It's used for +/// mapping values from the original loop to their corresponding values in +/// the new loop. Two mappings are maintained: one for vectorized values and +/// one for scalarized values. Vectorized values are represented with UF +/// vector values in the new loop, and scalarized values are represented with +/// UF x VF scalar values in the new loop. UF and VF are the unroll and +/// vectorization factors, respectively. +/// +/// Entries can be added to either map with setVectorValue and setScalarValue, +/// which assert that an entry was not already added before. If an entry is to +/// replace an existing one, call resetVectorValue and resetScalarValue. This is +/// currently needed to modify the mapped values during "fix-up" operations that +/// occur once the first phase of widening is complete. These operations include +/// type truncation and the second phase of recurrence widening. +/// +/// Entries from either map can be retrieved using the getVectorValue and +/// getScalarValue functions, which assert that the desired value exists. +struct VectorizerValueMap { + friend struct VPTransformState; + +private: + /// The unroll factor. Each entry in the vector map contains UF vector values. + unsigned UF; + + /// The vectorization factor. Each entry in the scalar map contains UF x VF + /// scalar values. + unsigned VF; + + /// The vector and scalar map storage. 
We use std::map and not DenseMap + /// because insertions to DenseMap invalidate its iterators. + using VectorParts = SmallVector<Value *, 2>; + using ScalarParts = SmallVector<SmallVector<Value *, 4>, 2>; + std::map<Value *, VectorParts> VectorMapStorage; + std::map<Value *, ScalarParts> ScalarMapStorage; + +public: + /// Construct an empty map with the given unroll and vectorization factors. + VectorizerValueMap(unsigned UF, unsigned VF) : UF(UF), VF(VF) {} + + /// \return True if the map has any vector entry for \p Key. + bool hasAnyVectorValue(Value *Key) const { + return VectorMapStorage.count(Key); + } + + /// \return True if the map has a vector entry for \p Key and \p Part. + bool hasVectorValue(Value *Key, unsigned Part) const { + assert(Part < UF && "Queried Vector Part is too large."); + if (!hasAnyVectorValue(Key)) + return false; + const VectorParts &Entry = VectorMapStorage.find(Key)->second; + assert(Entry.size() == UF && "VectorParts has wrong dimensions."); + return Entry[Part] != nullptr; + } + + /// \return True if the map has any scalar entry for \p Key. + bool hasAnyScalarValue(Value *Key) const { + return ScalarMapStorage.count(Key); + } + + /// \return True if the map has a scalar entry for \p Key and \p Instance. + bool hasScalarValue(Value *Key, const VPIteration &Instance) const { + assert(Instance.Part < UF && "Queried Scalar Part is too large."); + assert(Instance.Lane < VF && "Queried Scalar Lane is too large."); + if (!hasAnyScalarValue(Key)) + return false; + const ScalarParts &Entry = ScalarMapStorage.find(Key)->second; + assert(Entry.size() == UF && "ScalarParts has wrong dimensions."); + assert(Entry[Instance.Part].size() == VF && + "ScalarParts has wrong dimensions."); + return Entry[Instance.Part][Instance.Lane] != nullptr; + } + + /// Retrieve the existing vector value that corresponds to \p Key and + /// \p Part. + Value *getVectorValue(Value *Key, unsigned Part) { + assert(hasVectorValue(Key, Part) && "Getting non-existent value."); + return VectorMapStorage[Key][Part]; + } + + /// Retrieve the existing scalar value that corresponds to \p Key and + /// \p Instance. + Value *getScalarValue(Value *Key, const VPIteration &Instance) { + assert(hasScalarValue(Key, Instance) && "Getting non-existent value."); + return ScalarMapStorage[Key][Instance.Part][Instance.Lane]; + } + + /// Set a vector value associated with \p Key and \p Part. Assumes such a + /// value is not already set. If it is, use resetVectorValue() instead. + void setVectorValue(Value *Key, unsigned Part, Value *Vector) { + assert(!hasVectorValue(Key, Part) && "Vector value already set for part"); + if (!VectorMapStorage.count(Key)) { + VectorParts Entry(UF); + VectorMapStorage[Key] = Entry; + } + VectorMapStorage[Key][Part] = Vector; + } + + /// Set a scalar value associated with \p Key and \p Instance. Assumes such a + /// value is not already set. + void setScalarValue(Value *Key, const VPIteration &Instance, Value *Scalar) { + assert(!hasScalarValue(Key, Instance) && "Scalar value already set"); + if (!ScalarMapStorage.count(Key)) { + ScalarParts Entry(UF); + // TODO: Consider storing uniform values only per-part, as they occupy + // lane 0 only, keeping the other VF-1 redundant entries null. + for (unsigned Part = 0; Part < UF; ++Part) + Entry[Part].resize(VF, nullptr); + ScalarMapStorage[Key] = Entry; + } + ScalarMapStorage[Key][Instance.Part][Instance.Lane] = Scalar; + } + + /// Reset the vector value associated with \p Key for the given \p Part. 
+ /// This function can be used to update values that have already been + /// vectorized. This is the case for "fix-up" operations including type + /// truncation and the second phase of recurrence vectorization. + void resetVectorValue(Value *Key, unsigned Part, Value *Vector) { + assert(hasVectorValue(Key, Part) && "Vector value not set for part"); + VectorMapStorage[Key][Part] = Vector; + } + + /// Reset the scalar value associated with \p Key for \p Part and \p Lane. + /// This function can be used to update values that have already been + /// scalarized. This is the case for "fix-up" operations including scalar phi + /// nodes for scalarized and predicated instructions. + void resetScalarValue(Value *Key, const VPIteration &Instance, + Value *Scalar) { + assert(hasScalarValue(Key, Instance) && + "Scalar value not set for part and lane"); + ScalarMapStorage[Key][Instance.Part][Instance.Lane] = Scalar; + } +}; + +/// This class is used to enable the VPlan to invoke a method of ILV. This is +/// needed until the method is refactored out of ILV and becomes reusable. +struct VPCallback { + virtual ~VPCallback() {} + virtual Value *getOrCreateVectorValues(Value *V, unsigned Part) = 0; +}; + +/// VPTransformState holds information passed down when "executing" a VPlan, +/// needed for generating the output IR. +struct VPTransformState { + VPTransformState(unsigned VF, unsigned UF, LoopInfo *LI, DominatorTree *DT, + IRBuilder<> &Builder, VectorizerValueMap &ValueMap, + InnerLoopVectorizer *ILV, VPCallback &Callback) + : VF(VF), UF(UF), Instance(), LI(LI), DT(DT), Builder(Builder), + ValueMap(ValueMap), ILV(ILV), Callback(Callback) {} + + /// The chosen Vectorization and Unroll Factors of the loop being vectorized. + unsigned VF; + unsigned UF; + + /// Hold the indices to generate specific scalar instructions. Null indicates + /// that all instances are to be generated, using either scalar or vector + /// instructions. + Optional<VPIteration> Instance; + + struct DataState { + /// A type for vectorized values in the new loop. Each value from the + /// original loop, when vectorized, is represented by UF vector values in + /// the new unrolled loop, where UF is the unroll factor. + typedef SmallVector<Value *, 2> PerPartValuesTy; + + DenseMap<VPValue *, PerPartValuesTy> PerPartOutput; + } Data; + + /// Get the generated Value for a given VPValue and a given Part. Note that + /// as some Defs are still created by ILV and managed in its ValueMap, this + /// method will delegate the call to ILV in such cases in order to provide + /// callers a consistent API. + /// \see set. + Value *get(VPValue *Def, unsigned Part) { + // If Values have been set for this Def return the one relevant for \p Part. + if (Data.PerPartOutput.count(Def)) + return Data.PerPartOutput[Def][Part]; + // Def is managed by ILV: bring the Values from ValueMap. + return Callback.getOrCreateVectorValues(VPValue2Value[Def], Part); + } + + /// Set the generated Value for a given VPValue and a given Part. + void set(VPValue *Def, Value *V, unsigned Part) { + if (!Data.PerPartOutput.count(Def)) { + DataState::PerPartValuesTy Entry(UF); + Data.PerPartOutput[Def] = Entry; + } + Data.PerPartOutput[Def][Part] = V; + } + + /// Hold state information used when constructing the CFG of the output IR, + /// traversing the VPBasicBlocks and generating corresponding IR BasicBlocks. + struct CFGState { + /// The previous VPBasicBlock visited. Initially set to null. 
+ VPBasicBlock *PrevVPBB = nullptr; + + /// The previous IR BasicBlock created or used. Initially set to the new + /// header BasicBlock. + BasicBlock *PrevBB = nullptr; + + /// The last IR BasicBlock in the output IR. Set to the new latch + /// BasicBlock, used for placing the newly created BasicBlocks. + BasicBlock *LastBB = nullptr; + + /// A mapping of each VPBasicBlock to the corresponding BasicBlock. In case + /// of replication, maps the BasicBlock of the last replica created. + SmallDenseMap<VPBasicBlock *, BasicBlock *> VPBB2IRBB; + + CFGState() = default; + } CFG; + + /// Hold a pointer to LoopInfo to register new basic blocks in the loop. + LoopInfo *LI; + + /// Hold a pointer to Dominator Tree to register new basic blocks in the loop. + DominatorTree *DT; + + /// Hold a reference to the IRBuilder used to generate output IR code. + IRBuilder<> &Builder; + + /// Hold a reference to the Value state information used when generating the + /// Values of the output IR. + VectorizerValueMap &ValueMap; + + /// Hold a reference to a mapping between VPValues in VPlan and original + /// Values they correspond to. + VPValue2ValueTy VPValue2Value; + + /// Hold a pointer to InnerLoopVectorizer to reuse its IR generation methods. + InnerLoopVectorizer *ILV; + + VPCallback &Callback; +}; + +/// VPBlockBase is the building block of the Hierarchical Control-Flow Graph. +/// A VPBlockBase can be either a VPBasicBlock or a VPRegionBlock. +class VPBlockBase { +private: + const unsigned char SubclassID; ///< Subclass identifier (for isa/dyn_cast). + + /// An optional name for the block. + std::string Name; + + /// The immediate VPRegionBlock which this VPBlockBase belongs to, or null if + /// it is a topmost VPBlockBase. + VPRegionBlock *Parent = nullptr; + + /// List of predecessor blocks. + SmallVector<VPBlockBase *, 1> Predecessors; + + /// List of successor blocks. + SmallVector<VPBlockBase *, 1> Successors; + + /// Add \p Successor as the last successor to this block. + void appendSuccessor(VPBlockBase *Successor) { + assert(Successor && "Cannot add nullptr successor!"); + Successors.push_back(Successor); + } + + /// Add \p Predecessor as the last predecessor to this block. + void appendPredecessor(VPBlockBase *Predecessor) { + assert(Predecessor && "Cannot add nullptr predecessor!"); + Predecessors.push_back(Predecessor); + } + + /// Remove \p Predecessor from the predecessors of this block. + void removePredecessor(VPBlockBase *Predecessor) { + auto Pos = std::find(Predecessors.begin(), Predecessors.end(), Predecessor); + assert(Pos && "Predecessor does not exist"); + Predecessors.erase(Pos); + } + + /// Remove \p Successor from the successors of this block. + void removeSuccessor(VPBlockBase *Successor) { + auto Pos = std::find(Successors.begin(), Successors.end(), Successor); + assert(Pos && "Successor does not exist"); + Successors.erase(Pos); + } + +protected: + VPBlockBase(const unsigned char SC, const std::string &N) + : SubclassID(SC), Name(N) {} + +public: + /// An enumeration for keeping track of the concrete subclass of VPBlockBase + /// that are actually instantiated. Values of this enumeration are kept in the + /// SubclassID field of the VPBlockBase objects. They are used for concrete + /// type identification. 
+ using VPBlockTy = enum { VPBasicBlockSC, VPRegionBlockSC }; + + using VPBlocksTy = SmallVectorImpl<VPBlockBase *>; + + virtual ~VPBlockBase() = default; + + const std::string &getName() const { return Name; } + + void setName(const Twine &newName) { Name = newName.str(); } + + /// \return an ID for the concrete type of this object. + /// This is used to implement the classof checks. This should not be used + /// for any other purpose, as the values may change as LLVM evolves. + unsigned getVPBlockID() const { return SubclassID; } + + const VPRegionBlock *getParent() const { return Parent; } + + void setParent(VPRegionBlock *P) { Parent = P; } + + /// \return the VPBasicBlock that is the entry of this VPBlockBase, + /// recursively, if the latter is a VPRegionBlock. Otherwise, if this + /// VPBlockBase is a VPBasicBlock, it is returned. + const VPBasicBlock *getEntryBasicBlock() const; + VPBasicBlock *getEntryBasicBlock(); + + /// \return the VPBasicBlock that is the exit of this VPBlockBase, + /// recursively, if the latter is a VPRegionBlock. Otherwise, if this + /// VPBlockBase is a VPBasicBlock, it is returned. + const VPBasicBlock *getExitBasicBlock() const; + VPBasicBlock *getExitBasicBlock(); + + const VPBlocksTy &getSuccessors() const { return Successors; } + VPBlocksTy &getSuccessors() { return Successors; } + + const VPBlocksTy &getPredecessors() const { return Predecessors; } + VPBlocksTy &getPredecessors() { return Predecessors; } + + /// \return the successor of this VPBlockBase if it has a single successor. + /// Otherwise return a null pointer. + VPBlockBase *getSingleSuccessor() const { + return (Successors.size() == 1 ? *Successors.begin() : nullptr); + } + + /// \return the predecessor of this VPBlockBase if it has a single + /// predecessor. Otherwise return a null pointer. + VPBlockBase *getSinglePredecessor() const { + return (Predecessors.size() == 1 ? *Predecessors.begin() : nullptr); + } + + /// An Enclosing Block of a block B is any block containing B, including B + /// itself. \return the closest enclosing block starting from "this", which + /// has successors. \return the root enclosing block if all enclosing blocks + /// have no successors. + VPBlockBase *getEnclosingBlockWithSuccessors(); + + /// \return the closest enclosing block starting from "this", which has + /// predecessors. \return the root enclosing block if all enclosing blocks + /// have no predecessors. + VPBlockBase *getEnclosingBlockWithPredecessors(); + + /// \return the successors either attached directly to this VPBlockBase or, if + /// this VPBlockBase is the exit block of a VPRegionBlock and has no + /// successors of its own, search recursively for the first enclosing + /// VPRegionBlock that has successors and return them. If no such + /// VPRegionBlock exists, return the (empty) successors of the topmost + /// VPBlockBase reached. + const VPBlocksTy &getHierarchicalSuccessors() { + return getEnclosingBlockWithSuccessors()->getSuccessors(); + } + + /// \return the hierarchical successor of this VPBlockBase if it has a single + /// hierarchical successor. Otherwise return a null pointer. + VPBlockBase *getSingleHierarchicalSuccessor() { + return getEnclosingBlockWithSuccessors()->getSingleSuccessor(); + } + + /// \return the predecessors either attached directly to this VPBlockBase or, + /// if this VPBlockBase is the entry block of a VPRegionBlock and has no + /// predecessors of its own, search recursively for the first enclosing + /// VPRegionBlock that has predecessors and return them. 
If no such + /// VPRegionBlock exists, return the (empty) predecessors of the topmost + /// VPBlockBase reached. + const VPBlocksTy &getHierarchicalPredecessors() { + return getEnclosingBlockWithPredecessors()->getPredecessors(); + } + + /// \return the hierarchical predecessor of this VPBlockBase if it has a + /// single hierarchical predecessor. Otherwise return a null pointer. + VPBlockBase *getSingleHierarchicalPredecessor() { + return getEnclosingBlockWithPredecessors()->getSinglePredecessor(); + } + + /// Sets a given VPBlockBase \p Successor as the single successor and \return + /// \p Successor. The parent of this Block is copied to be the parent of + /// \p Successor. + VPBlockBase *setOneSuccessor(VPBlockBase *Successor) { + assert(Successors.empty() && "Setting one successor when others exist."); + appendSuccessor(Successor); + Successor->appendPredecessor(this); + Successor->Parent = Parent; + return Successor; + } + + /// Sets two given VPBlockBases \p IfTrue and \p IfFalse to be the two + /// successors. The parent of this Block is copied to be the parent of both + /// \p IfTrue and \p IfFalse. + void setTwoSuccessors(VPBlockBase *IfTrue, VPBlockBase *IfFalse) { + assert(Successors.empty() && "Setting two successors when others exist."); + appendSuccessor(IfTrue); + appendSuccessor(IfFalse); + IfTrue->appendPredecessor(this); + IfFalse->appendPredecessor(this); + IfTrue->Parent = Parent; + IfFalse->Parent = Parent; + } + + void disconnectSuccessor(VPBlockBase *Successor) { + assert(Successor && "Successor to disconnect is null."); + removeSuccessor(Successor); + Successor->removePredecessor(this); + } + + /// The method which generates the output IR that correspond to this + /// VPBlockBase, thereby "executing" the VPlan. + virtual void execute(struct VPTransformState *State) = 0; + + /// Delete all blocks reachable from a given VPBlockBase, inclusive. + static void deleteCFG(VPBlockBase *Entry); +}; + +/// VPRecipeBase is a base class modeling a sequence of one or more output IR +/// instructions. +class VPRecipeBase : public ilist_node_with_parent<VPRecipeBase, VPBasicBlock> { + friend VPBasicBlock; + +private: + const unsigned char SubclassID; ///< Subclass identifier (for isa/dyn_cast). + + /// Each VPRecipe belongs to a single VPBasicBlock. + VPBasicBlock *Parent = nullptr; + +public: + /// An enumeration for keeping track of the concrete subclass of VPRecipeBase + /// that is actually instantiated. Values of this enumeration are kept in the + /// SubclassID field of the VPRecipeBase objects. They are used for concrete + /// type identification. + using VPRecipeTy = enum { + VPBlendSC, + VPBranchOnMaskSC, + VPInstructionSC, + VPInterleaveSC, + VPPredInstPHISC, + VPReplicateSC, + VPWidenIntOrFpInductionSC, + VPWidenMemoryInstructionSC, + VPWidenPHISC, + VPWidenSC, + }; + + VPRecipeBase(const unsigned char SC) : SubclassID(SC) {} + virtual ~VPRecipeBase() = default; + + /// \return an ID for the concrete type of this object. + /// This is used to implement the classof checks. This should not be used + /// for any other purpose, as the values may change as LLVM evolves. + unsigned getVPRecipeID() const { return SubclassID; } + + /// \return the VPBasicBlock which this VPRecipe belongs to. + VPBasicBlock *getParent() { return Parent; } + const VPBasicBlock *getParent() const { return Parent; } + + /// The method which generates the output IR instructions that correspond to + /// this VPRecipe, thereby "executing" the VPlan. 
+ virtual void execute(struct VPTransformState &State) = 0; + + /// Each recipe prints itself. + virtual void print(raw_ostream &O, const Twine &Indent) const = 0; +}; + +/// This is a concrete Recipe that models a single VPlan-level instruction. +/// While as any Recipe it may generate a sequence of IR instructions when +/// executed, these instructions would always form a single-def expression as +/// the VPInstruction is also a single def-use vertex. +class VPInstruction : public VPUser, public VPRecipeBase { +public: + /// VPlan opcodes, extending LLVM IR with idiomatics instructions. + enum { Not = Instruction::OtherOpsEnd + 1 }; + +private: + typedef unsigned char OpcodeTy; + OpcodeTy Opcode; + + /// Utility method serving execute(): generates a single instance of the + /// modeled instruction. + void generateInstruction(VPTransformState &State, unsigned Part); + +public: + VPInstruction(unsigned Opcode, std::initializer_list<VPValue *> Operands) + : VPUser(VPValue::VPInstructionSC, Operands), + VPRecipeBase(VPRecipeBase::VPInstructionSC), Opcode(Opcode) {} + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPValue *V) { + return V->getVPValueID() == VPValue::VPInstructionSC; + } + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPRecipeBase *R) { + return R->getVPRecipeID() == VPRecipeBase::VPInstructionSC; + } + + unsigned getOpcode() const { return Opcode; } + + /// Generate the instruction. + /// TODO: We currently execute only per-part unless a specific instance is + /// provided. + void execute(VPTransformState &State) override; + + /// Print the Recipe. + void print(raw_ostream &O, const Twine &Indent) const override; + + /// Print the VPInstruction. + void print(raw_ostream &O) const; +}; + +/// VPWidenRecipe is a recipe for producing a copy of vector type for each +/// Instruction in its ingredients independently, in order. This recipe covers +/// most of the traditional vectorization cases where each ingredient transforms +/// into a vectorized version of itself. +class VPWidenRecipe : public VPRecipeBase { +private: + /// Hold the ingredients by pointing to their original BasicBlock location. + BasicBlock::iterator Begin; + BasicBlock::iterator End; + +public: + VPWidenRecipe(Instruction *I) : VPRecipeBase(VPWidenSC) { + End = I->getIterator(); + Begin = End++; + } + + ~VPWidenRecipe() override = default; + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPRecipeBase *V) { + return V->getVPRecipeID() == VPRecipeBase::VPWidenSC; + } + + /// Produce widened copies of all Ingredients. + void execute(VPTransformState &State) override; + + /// Augment the recipe to include Instr, if it lies at its End. + bool appendInstruction(Instruction *Instr) { + if (End != Instr->getIterator()) + return false; + End++; + return true; + } + + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent) const override; +}; + +/// A recipe for handling phi nodes of integer and floating-point inductions, +/// producing their vector and scalar values. 
+class VPWidenIntOrFpInductionRecipe : public VPRecipeBase { +private: + PHINode *IV; + TruncInst *Trunc; + +public: + VPWidenIntOrFpInductionRecipe(PHINode *IV, TruncInst *Trunc = nullptr) + : VPRecipeBase(VPWidenIntOrFpInductionSC), IV(IV), Trunc(Trunc) {} + ~VPWidenIntOrFpInductionRecipe() override = default; + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPRecipeBase *V) { + return V->getVPRecipeID() == VPRecipeBase::VPWidenIntOrFpInductionSC; + } + + /// Generate the vectorized and scalarized versions of the phi node as + /// needed by their users. + void execute(VPTransformState &State) override; + + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent) const override; +}; + +/// A recipe for handling all phi nodes except for integer and FP inductions. +class VPWidenPHIRecipe : public VPRecipeBase { +private: + PHINode *Phi; + +public: + VPWidenPHIRecipe(PHINode *Phi) : VPRecipeBase(VPWidenPHISC), Phi(Phi) {} + ~VPWidenPHIRecipe() override = default; + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPRecipeBase *V) { + return V->getVPRecipeID() == VPRecipeBase::VPWidenPHISC; + } + + /// Generate the phi/select nodes. + void execute(VPTransformState &State) override; + + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent) const override; +}; + +/// A recipe for vectorizing a phi-node as a sequence of mask-based select +/// instructions. +class VPBlendRecipe : public VPRecipeBase { +private: + PHINode *Phi; + + /// The blend operation is a User of a mask, if not null. + std::unique_ptr<VPUser> User; + +public: + VPBlendRecipe(PHINode *Phi, ArrayRef<VPValue *> Masks) + : VPRecipeBase(VPBlendSC), Phi(Phi) { + assert((Phi->getNumIncomingValues() == 1 || + Phi->getNumIncomingValues() == Masks.size()) && + "Expected the same number of incoming values and masks"); + if (!Masks.empty()) + User.reset(new VPUser(Masks)); + } + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPRecipeBase *V) { + return V->getVPRecipeID() == VPRecipeBase::VPBlendSC; + } + + /// Generate the phi/select nodes. + void execute(VPTransformState &State) override; + + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent) const override; +}; + +/// VPInterleaveRecipe is a recipe for transforming an interleave group of load +/// or stores into one wide load/store and shuffles. +class VPInterleaveRecipe : public VPRecipeBase { +private: + const InterleaveGroup *IG; + +public: + VPInterleaveRecipe(const InterleaveGroup *IG) + : VPRecipeBase(VPInterleaveSC), IG(IG) {} + ~VPInterleaveRecipe() override = default; + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPRecipeBase *V) { + return V->getVPRecipeID() == VPRecipeBase::VPInterleaveSC; + } + + /// Generate the wide load or store, and shuffles. + void execute(VPTransformState &State) override; + + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent) const override; + + const InterleaveGroup *getInterleaveGroup() { return IG; } +}; + +/// VPReplicateRecipe replicates a given instruction producing multiple scalar +/// copies of the original scalar type, one per lane, instead of producing a +/// single copy of widened type for all lanes. If the instruction is known to be +/// uniform only one copy, per lane zero, will be generated. 
+class VPReplicateRecipe : public VPRecipeBase { +private: + /// The instruction being replicated. + Instruction *Ingredient; + + /// Indicator if only a single replica per lane is needed. + bool IsUniform; + + /// Indicator if the replicas are also predicated. + bool IsPredicated; + + /// Indicator if the scalar values should also be packed into a vector. + bool AlsoPack; + +public: + VPReplicateRecipe(Instruction *I, bool IsUniform, bool IsPredicated = false) + : VPRecipeBase(VPReplicateSC), Ingredient(I), IsUniform(IsUniform), + IsPredicated(IsPredicated) { + // Retain the previous behavior of predicateInstructions(), where an + // insert-element of a predicated instruction got hoisted into the + // predicated basic block iff it was its only user. This is achieved by + // having predicated instructions also pack their values into a vector by + // default unless they have a replicated user which uses their scalar value. + AlsoPack = IsPredicated && !I->use_empty(); + } + + ~VPReplicateRecipe() override = default; + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPRecipeBase *V) { + return V->getVPRecipeID() == VPRecipeBase::VPReplicateSC; + } + + /// Generate replicas of the desired Ingredient. Replicas will be generated + /// for all parts and lanes unless a specific part and lane are specified in + /// the \p State. + void execute(VPTransformState &State) override; + + void setAlsoPack(bool Pack) { AlsoPack = Pack; } + + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent) const override; +}; + +/// A recipe for generating conditional branches on the bits of a mask. +class VPBranchOnMaskRecipe : public VPRecipeBase { +private: + std::unique_ptr<VPUser> User; + +public: + VPBranchOnMaskRecipe(VPValue *BlockInMask) : VPRecipeBase(VPBranchOnMaskSC) { + if (BlockInMask) // nullptr means all-one mask. + User.reset(new VPUser({BlockInMask})); + } + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPRecipeBase *V) { + return V->getVPRecipeID() == VPRecipeBase::VPBranchOnMaskSC; + } + + /// Generate the extraction of the appropriate bit from the block mask and the + /// conditional branch. + void execute(VPTransformState &State) override; + + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent) const override { + O << " +\n" << Indent << "\"BRANCH-ON-MASK "; + if (User) + O << *User->getOperand(0); + else + O << " All-One"; + O << "\\l\""; + } +}; + +/// VPPredInstPHIRecipe is a recipe for generating the phi nodes needed when +/// control converges back from a Branch-on-Mask. The phi nodes are needed in +/// order to merge values that are set under such a branch and feed their uses. +/// The phi nodes can be scalar or vector depending on the users of the value. +/// This recipe works in concert with VPBranchOnMaskRecipe. +class VPPredInstPHIRecipe : public VPRecipeBase { +private: + Instruction *PredInst; + +public: + /// Construct a VPPredInstPHIRecipe given \p PredInst whose value needs a phi + /// nodes after merging back from a Branch-on-Mask. + VPPredInstPHIRecipe(Instruction *PredInst) + : VPRecipeBase(VPPredInstPHISC), PredInst(PredInst) {} + ~VPPredInstPHIRecipe() override = default; + + /// Method to support type inquiry through isa, cast, and dyn_cast. 
+ static inline bool classof(const VPRecipeBase *V) { + return V->getVPRecipeID() == VPRecipeBase::VPPredInstPHISC; + } + + /// Generates phi nodes for live-outs as needed to retain SSA form. + void execute(VPTransformState &State) override; + + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent) const override; +}; + +/// A Recipe for widening load/store operations. +/// TODO: We currently execute only per-part unless a specific instance is +/// provided. +class VPWidenMemoryInstructionRecipe : public VPRecipeBase { +private: + Instruction &Instr; + std::unique_ptr<VPUser> User; + +public: + VPWidenMemoryInstructionRecipe(Instruction &Instr, VPValue *Mask) + : VPRecipeBase(VPWidenMemoryInstructionSC), Instr(Instr) { + if (Mask) // Create a VPInstruction to register as a user of the mask. + User.reset(new VPUser({Mask})); + } + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPRecipeBase *V) { + return V->getVPRecipeID() == VPRecipeBase::VPWidenMemoryInstructionSC; + } + + /// Generate the wide load/store. + void execute(VPTransformState &State) override; + + /// Print the recipe. + void print(raw_ostream &O, const Twine &Indent) const override; +}; + +/// VPBasicBlock serves as the leaf of the Hierarchical Control-Flow Graph. It +/// holds a sequence of zero or more VPRecipe's each representing a sequence of +/// output IR instructions. +class VPBasicBlock : public VPBlockBase { +public: + using RecipeListTy = iplist<VPRecipeBase>; + +private: + /// The VPRecipes held in the order of output instructions to generate. + RecipeListTy Recipes; + +public: + VPBasicBlock(const Twine &Name = "", VPRecipeBase *Recipe = nullptr) + : VPBlockBase(VPBasicBlockSC, Name.str()) { + if (Recipe) + appendRecipe(Recipe); + } + + ~VPBasicBlock() override { Recipes.clear(); } + + /// Instruction iterators... + using iterator = RecipeListTy::iterator; + using const_iterator = RecipeListTy::const_iterator; + using reverse_iterator = RecipeListTy::reverse_iterator; + using const_reverse_iterator = RecipeListTy::const_reverse_iterator; + + //===--------------------------------------------------------------------===// + /// Recipe iterator methods + /// + inline iterator begin() { return Recipes.begin(); } + inline const_iterator begin() const { return Recipes.begin(); } + inline iterator end() { return Recipes.end(); } + inline const_iterator end() const { return Recipes.end(); } + + inline reverse_iterator rbegin() { return Recipes.rbegin(); } + inline const_reverse_iterator rbegin() const { return Recipes.rbegin(); } + inline reverse_iterator rend() { return Recipes.rend(); } + inline const_reverse_iterator rend() const { return Recipes.rend(); } + + inline size_t size() const { return Recipes.size(); } + inline bool empty() const { return Recipes.empty(); } + inline const VPRecipeBase &front() const { return Recipes.front(); } + inline VPRecipeBase &front() { return Recipes.front(); } + inline const VPRecipeBase &back() const { return Recipes.back(); } + inline VPRecipeBase &back() { return Recipes.back(); } + + /// \brief Returns a pointer to a member of the recipe list. + static RecipeListTy VPBasicBlock::*getSublistAccess(VPRecipeBase *) { + return &VPBasicBlock::Recipes; + } + + /// Method to support type inquiry through isa, cast, and dyn_cast. 
+ static inline bool classof(const VPBlockBase *V) { + return V->getVPBlockID() == VPBlockBase::VPBasicBlockSC; + } + + void insert(VPRecipeBase *Recipe, iterator InsertPt) { + assert(Recipe && "No recipe to append."); + assert(!Recipe->Parent && "Recipe already in VPlan"); + Recipe->Parent = this; + Recipes.insert(InsertPt, Recipe); + } + + /// Augment the existing recipes of a VPBasicBlock with an additional + /// \p Recipe as the last recipe. + void appendRecipe(VPRecipeBase *Recipe) { insert(Recipe, end()); } + + /// The method which generates the output IR instructions that correspond to + /// this VPBasicBlock, thereby "executing" the VPlan. + void execute(struct VPTransformState *State) override; + +private: + /// Create an IR BasicBlock to hold the output instructions generated by this + /// VPBasicBlock, and return it. Update the CFGState accordingly. + BasicBlock *createEmptyBasicBlock(VPTransformState::CFGState &CFG); +}; + +/// VPRegionBlock represents a collection of VPBasicBlocks and VPRegionBlocks +/// which form a Single-Entry-Single-Exit subgraph of the output IR CFG. +/// A VPRegionBlock may indicate that its contents are to be replicated several +/// times. This is designed to support predicated scalarization, in which a +/// scalar if-then code structure needs to be generated VF * UF times. Having +/// this replication indicator helps to keep a single model for multiple +/// candidate VF's. The actual replication takes place only once the desired VF +/// and UF have been determined. +class VPRegionBlock : public VPBlockBase { +private: + /// Hold the Single Entry of the SESE region modelled by the VPRegionBlock. + VPBlockBase *Entry; + + /// Hold the Single Exit of the SESE region modelled by the VPRegionBlock. + VPBlockBase *Exit; + + /// An indicator whether this region is to generate multiple replicated + /// instances of output IR corresponding to its VPBlockBases. + bool IsReplicator; + +public: + VPRegionBlock(VPBlockBase *Entry, VPBlockBase *Exit, + const std::string &Name = "", bool IsReplicator = false) + : VPBlockBase(VPRegionBlockSC, Name), Entry(Entry), Exit(Exit), + IsReplicator(IsReplicator) { + assert(Entry->getPredecessors().empty() && "Entry block has predecessors."); + assert(Exit->getSuccessors().empty() && "Exit block has successors."); + Entry->setParent(this); + Exit->setParent(this); + } + + ~VPRegionBlock() override { + if (Entry) + deleteCFG(Entry); + } + + /// Method to support type inquiry through isa, cast, and dyn_cast. + static inline bool classof(const VPBlockBase *V) { + return V->getVPBlockID() == VPBlockBase::VPRegionBlockSC; + } + + const VPBlockBase *getEntry() const { return Entry; } + VPBlockBase *getEntry() { return Entry; } + + const VPBlockBase *getExit() const { return Exit; } + VPBlockBase *getExit() { return Exit; } + + /// An indicator whether this region is to generate multiple replicated + /// instances of output IR corresponding to its VPBlockBases. + bool isReplicator() const { return IsReplicator; } + + /// The method which generates the output IR instructions that correspond to + /// this VPRegionBlock, thereby "executing" the VPlan. + void execute(struct VPTransformState *State) override; +}; + +/// VPlan models a candidate for vectorization, encoding various decisions take +/// to produce efficient output IR, including which branches, basic-blocks and +/// output IR instructions to generate, and their cost. VPlan holds a +/// Hierarchical-CFG of VPBasicBlocks and VPRegionBlocks rooted at an Entry +/// VPBlock. 
+class VPlan { + friend class VPlanPrinter; + +private: + /// Hold the single entry to the Hierarchical CFG of the VPlan. + VPBlockBase *Entry; + + /// Holds the VFs applicable to this VPlan. + SmallSet<unsigned, 2> VFs; + + /// Holds the name of the VPlan, for printing. + std::string Name; + + /// Holds a mapping between Values and their corresponding VPValue inside + /// VPlan. + Value2VPValueTy Value2VPValue; + +public: + VPlan(VPBlockBase *Entry = nullptr) : Entry(Entry) {} + + ~VPlan() { + if (Entry) + VPBlockBase::deleteCFG(Entry); + for (auto &MapEntry : Value2VPValue) + delete MapEntry.second; + } + + /// Generate the IR code for this VPlan. + void execute(struct VPTransformState *State); + + VPBlockBase *getEntry() { return Entry; } + const VPBlockBase *getEntry() const { return Entry; } + + VPBlockBase *setEntry(VPBlockBase *Block) { return Entry = Block; } + + void addVF(unsigned VF) { VFs.insert(VF); } + + bool hasVF(unsigned VF) { return VFs.count(VF); } + + const std::string &getName() const { return Name; } + + void setName(const Twine &newName) { Name = newName.str(); } + + void addVPValue(Value *V) { + assert(V && "Trying to add a null Value to VPlan"); + assert(!Value2VPValue.count(V) && "Value already exists in VPlan"); + Value2VPValue[V] = new VPValue(); + } + + VPValue *getVPValue(Value *V) { + assert(V && "Trying to get the VPValue of a null Value"); + assert(Value2VPValue.count(V) && "Value does not exist in VPlan"); + return Value2VPValue[V]; + } + +private: + /// Add to the given dominator tree the header block and every new basic block + /// that was created between it and the latch block, inclusive. + static void updateDominatorTree(DominatorTree *DT, + BasicBlock *LoopPreHeaderBB, + BasicBlock *LoopLatchBB); +}; + +/// VPlanPrinter prints a given VPlan to a given output stream. The printing is +/// indented and follows the dot format. +class VPlanPrinter { + friend inline raw_ostream &operator<<(raw_ostream &OS, VPlan &Plan); + friend inline raw_ostream &operator<<(raw_ostream &OS, + const struct VPlanIngredient &I); + +private: + raw_ostream &OS; + VPlan &Plan; + unsigned Depth; + unsigned TabWidth = 2; + std::string Indent; + unsigned BID = 0; + SmallDenseMap<const VPBlockBase *, unsigned> BlockID; + + VPlanPrinter(raw_ostream &O, VPlan &P) : OS(O), Plan(P) {} + + /// Handle indentation. + void bumpIndent(int b) { Indent = std::string((Depth += b) * TabWidth, ' '); } + + /// Print a given \p Block of the Plan. + void dumpBlock(const VPBlockBase *Block); + + /// Print the information related to the CFG edges going out of a given + /// \p Block, followed by printing the successor blocks themselves. + void dumpEdges(const VPBlockBase *Block); + + /// Print a given \p BasicBlock, including its VPRecipes, followed by printing + /// its successor blocks. + void dumpBasicBlock(const VPBasicBlock *BasicBlock); + + /// Print a given \p Region of the Plan. + void dumpRegion(const VPRegionBlock *Region); + + unsigned getOrCreateBID(const VPBlockBase *Block) { + return BlockID.count(Block) ? BlockID[Block] : BlockID[Block] = BID++; + } + + const Twine getOrCreateName(const VPBlockBase *Block); + + const Twine getUID(const VPBlockBase *Block); + + /// Print the information related to a CFG edge between two VPBlockBases. 
+ void drawEdge(const VPBlockBase *From, const VPBlockBase *To, bool Hidden, + const Twine &Label); + + void dump(); + + static void printAsIngredient(raw_ostream &O, Value *V); +}; + +struct VPlanIngredient { + Value *V; + + VPlanIngredient(Value *V) : V(V) {} +}; + +inline raw_ostream &operator<<(raw_ostream &OS, const VPlanIngredient &I) { + VPlanPrinter::printAsIngredient(OS, I.V); + return OS; +} + +inline raw_ostream &operator<<(raw_ostream &OS, VPlan &Plan) { + VPlanPrinter Printer(OS, Plan); + Printer.dump(); + return OS; +} + +//===--------------------------------------------------------------------===// +// GraphTraits specializations for VPlan/VPRegionBlock Control-Flow Graphs // +//===--------------------------------------------------------------------===// + +// Provide specializations of GraphTraits to be able to treat a VPBlockBase as a +// graph of VPBlockBase nodes... + +template <> struct GraphTraits<VPBlockBase *> { + using NodeRef = VPBlockBase *; + using ChildIteratorType = SmallVectorImpl<VPBlockBase *>::iterator; + + static NodeRef getEntryNode(NodeRef N) { return N; } + + static inline ChildIteratorType child_begin(NodeRef N) { + return N->getSuccessors().begin(); + } + + static inline ChildIteratorType child_end(NodeRef N) { + return N->getSuccessors().end(); + } +}; + +template <> struct GraphTraits<const VPBlockBase *> { + using NodeRef = const VPBlockBase *; + using ChildIteratorType = SmallVectorImpl<VPBlockBase *>::const_iterator; + + static NodeRef getEntryNode(NodeRef N) { return N; } + + static inline ChildIteratorType child_begin(NodeRef N) { + return N->getSuccessors().begin(); + } + + static inline ChildIteratorType child_end(NodeRef N) { + return N->getSuccessors().end(); + } +}; + +// Provide specializations of GraphTraits to be able to treat a VPBlockBase as a +// graph of VPBlockBase nodes... and to walk it in inverse order. Inverse order +// for a VPBlockBase is considered to be when traversing the predecessors of a +// VPBlockBase instead of its successors. +template <> struct GraphTraits<Inverse<VPBlockBase *>> { + using NodeRef = VPBlockBase *; + using ChildIteratorType = SmallVectorImpl<VPBlockBase *>::iterator; + + static Inverse<VPBlockBase *> getEntryNode(Inverse<VPBlockBase *> B) { + return B; + } + + static inline ChildIteratorType child_begin(NodeRef N) { + return N->getPredecessors().begin(); + } + + static inline ChildIteratorType child_end(NodeRef N) { + return N->getPredecessors().end(); + } +}; + +} // end namespace llvm + +#endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_H diff --git a/lib/Transforms/Vectorize/VPlanBuilder.h b/lib/Transforms/Vectorize/VPlanBuilder.h new file mode 100644 index 000000000000..d6eb3397d044 --- /dev/null +++ b/lib/Transforms/Vectorize/VPlanBuilder.h @@ -0,0 +1,61 @@ +//===- VPlanBuilder.h - A VPlan utility for constructing VPInstructions ---===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file provides a VPlan-based builder utility analogous to IRBuilder. +/// It provides an instruction-level API for generating VPInstructions while +/// abstracting away the Recipe manipulation details. 
+//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLAN_BUILDER_H +#define LLVM_TRANSFORMS_VECTORIZE_VPLAN_BUILDER_H + +#include "VPlan.h" + +namespace llvm { + +class VPBuilder { +private: + VPBasicBlock *BB = nullptr; + VPBasicBlock::iterator InsertPt = VPBasicBlock::iterator(); + + VPInstruction *createInstruction(unsigned Opcode, + std::initializer_list<VPValue *> Operands) { + VPInstruction *Instr = new VPInstruction(Opcode, Operands); + BB->insert(Instr, InsertPt); + return Instr; + } + +public: + VPBuilder() {} + + /// \brief This specifies that created VPInstructions should be appended to + /// the end of the specified block. + void setInsertPoint(VPBasicBlock *TheBB) { + assert(TheBB && "Attempting to set a null insert point"); + BB = TheBB; + InsertPt = BB->end(); + } + + VPValue *createNot(VPValue *Operand) { + return createInstruction(VPInstruction::Not, {Operand}); + } + + VPValue *createAnd(VPValue *LHS, VPValue *RHS) { + return createInstruction(Instruction::BinaryOps::And, {LHS, RHS}); + } + + VPValue *createOr(VPValue *LHS, VPValue *RHS) { + return createInstruction(Instruction::BinaryOps::Or, {LHS, RHS}); + } +}; + +} // namespace llvm + +#endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_BUILDER_H diff --git a/lib/Transforms/Vectorize/VPlanValue.h b/lib/Transforms/Vectorize/VPlanValue.h new file mode 100644 index 000000000000..50966891e0eb --- /dev/null +++ b/lib/Transforms/Vectorize/VPlanValue.h @@ -0,0 +1,146 @@ +//===- VPlanValue.h - Represent Values in Vectorizer Plan -----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// +/// \file +/// This file contains the declarations of the entities induced by Vectorization +/// Plans, e.g. the instructions the VPlan intends to generate if executed. +/// VPlan models the following entities: +/// VPValue +/// |-- VPUser +/// | |-- VPInstruction +/// These are documented in docs/VectorizationPlan.rst. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_TRANSFORMS_VECTORIZE_VPLAN_VALUE_H +#define LLVM_TRANSFORMS_VECTORIZE_VPLAN_VALUE_H + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Value.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" + +namespace llvm { + +// Forward declarations. +class VPUser; + +// This is the base class of the VPlan Def/Use graph, used for modeling the data +// flow into, within and out of the VPlan. VPValues can stand for live-ins +// coming from the input IR, instructions which VPlan will generate if executed +// and live-outs which the VPlan will need to fix accordingly. +class VPValue { +private: + const unsigned char SubclassID; ///< Subclass identifier (for isa/dyn_cast). + + SmallVector<VPUser *, 1> Users; + +protected: + VPValue(const unsigned char SC) : SubclassID(SC) {} + +public: + /// An enumeration for keeping track of the concrete subclass of VPValue that + /// are actually instantiated. Values of this enumeration are kept in the + /// SubclassID field of the VPValue objects. They are used for concrete + /// type identification. 
+ enum { VPValueSC, VPUserSC, VPInstructionSC }; + + VPValue() : SubclassID(VPValueSC) {} + VPValue(const VPValue &) = delete; + VPValue &operator=(const VPValue &) = delete; + + /// \return an ID for the concrete type of this object. + /// This is used to implement the classof checks. This should not be used + /// for any other purpose, as the values may change as LLVM evolves. + unsigned getVPValueID() const { return SubclassID; } + + void printAsOperand(raw_ostream &OS) const { + OS << "%vp" << (unsigned short)(unsigned long long)this; + } + + unsigned getNumUsers() const { return Users.size(); } + void addUser(VPUser &User) { Users.push_back(&User); } + + typedef SmallVectorImpl<VPUser *>::iterator user_iterator; + typedef SmallVectorImpl<VPUser *>::const_iterator const_user_iterator; + typedef iterator_range<user_iterator> user_range; + typedef iterator_range<const_user_iterator> const_user_range; + + user_iterator user_begin() { return Users.begin(); } + const_user_iterator user_begin() const { return Users.begin(); } + user_iterator user_end() { return Users.end(); } + const_user_iterator user_end() const { return Users.end(); } + user_range users() { return user_range(user_begin(), user_end()); } + const_user_range users() const { + return const_user_range(user_begin(), user_end()); + } +}; + +typedef DenseMap<Value *, VPValue *> Value2VPValueTy; +typedef DenseMap<VPValue *, Value *> VPValue2ValueTy; + +raw_ostream &operator<<(raw_ostream &OS, const VPValue &V); + +/// This class augments VPValue with operands which provide the inverse def-use +/// edges from VPValue's users to their defs. +class VPUser : public VPValue { +private: + SmallVector<VPValue *, 2> Operands; + + void addOperand(VPValue *Operand) { + Operands.push_back(Operand); + Operand->addUser(*this); + } + +protected: + VPUser(const unsigned char SC) : VPValue(SC) {} + VPUser(const unsigned char SC, ArrayRef<VPValue *> Operands) : VPValue(SC) { + for (VPValue *Operand : Operands) + addOperand(Operand); + } + +public: + VPUser() : VPValue(VPValue::VPUserSC) {} + VPUser(ArrayRef<VPValue *> Operands) : VPUser(VPValue::VPUserSC, Operands) {} + VPUser(std::initializer_list<VPValue *> Operands) + : VPUser(ArrayRef<VPValue *>(Operands)) {} + VPUser(const VPUser &) = delete; + VPUser &operator=(const VPUser &) = delete; + + /// Method to support type inquiry through isa, cast, and dyn_cast. 
+ static inline bool classof(const VPValue *V) { + return V->getVPValueID() >= VPUserSC && + V->getVPValueID() <= VPInstructionSC; + } + + unsigned getNumOperands() const { return Operands.size(); } + inline VPValue *getOperand(unsigned N) const { + assert(N < Operands.size() && "Operand index out of bounds"); + return Operands[N]; + } + + typedef SmallVectorImpl<VPValue *>::iterator operand_iterator; + typedef SmallVectorImpl<VPValue *>::const_iterator const_operand_iterator; + typedef iterator_range<operand_iterator> operand_range; + typedef iterator_range<const_operand_iterator> const_operand_range; + + operand_iterator op_begin() { return Operands.begin(); } + const_operand_iterator op_begin() const { return Operands.begin(); } + operand_iterator op_end() { return Operands.end(); } + const_operand_iterator op_end() const { return Operands.end(); } + operand_range operands() { return operand_range(op_begin(), op_end()); } + const_operand_range operands() const { + return const_operand_range(op_begin(), op_end()); + } +}; + +} // namespace llvm + +#endif // LLVM_TRANSFORMS_VECTORIZE_VPLAN_VALUE_H diff --git a/lib/Transforms/Vectorize/Vectorize.cpp b/lib/Transforms/Vectorize/Vectorize.cpp index fb2f509dcbaa..b04905bfc6fa 100644 --- a/lib/Transforms/Vectorize/Vectorize.cpp +++ b/lib/Transforms/Vectorize/Vectorize.cpp @@ -18,7 +18,6 @@ #include "llvm-c/Transforms/Vectorize.h" #include "llvm/Analysis/Passes.h" #include "llvm/IR/LegacyPassManager.h" -#include "llvm/IR/Verifier.h" #include "llvm/InitializePasses.h" using namespace llvm; |
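The VPBuilder added in VPlanBuilder.h above is a thin, IRBuilder-like wrapper over VPInstruction creation. As a rough illustration only, and not part of this commit, the sketch below combines the three creators it exposes to form a small mask expression; VPBB, BlockInMask and EdgeMask are placeholder names standing in for whatever block and live-in VPValues the vectorizer would actually supply.

#include "VPlan.h"
#include "VPlanBuilder.h"
using namespace llvm;

// Hypothetical helper (illustrative only): AND a block's incoming mask with
// the negation of an edge mask, appending the new VPInstructions to VPBB.
static VPValue *buildNegatedEdgeMask(VPBasicBlock *VPBB, VPValue *BlockInMask,
                                     VPValue *EdgeMask) {
  VPBuilder Builder;
  Builder.setInsertPoint(VPBB);                    // append at the end of VPBB
  VPValue *NotEdge = Builder.createNot(EdgeMask);  // VPInstruction::Not
  return Builder.createAnd(BlockInMask, NotEdge);  // Instruction::BinaryOps::And
}

Each create* call allocates a VPInstruction recipe and inserts it into the current VPBasicBlock, so the def-use edges between the produced values are recorded through the VPUser operand lists declared in VPlanValue.h.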