Diffstat (limited to 'contrib/llvm/lib/Transforms/Scalar')
67 files changed, 70232 insertions, 0 deletions
diff --git a/contrib/llvm/lib/Transforms/Scalar/ADCE.cpp b/contrib/llvm/lib/Transforms/Scalar/ADCE.cpp
new file mode 100644
index 000000000000..1e683db50206
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/ADCE.cpp
@@ -0,0 +1,730 @@
+//===- ADCE.cpp - Code to perform dead code elimination -------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the Aggressive Dead Code Elimination pass. This pass
+// optimistically assumes that all instructions are dead until proven otherwise,
+// allowing it to eliminate dead computations that other DCE passes do not
+// catch, particularly involving loop computations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/ADCE.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/GraphTraits.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/IteratedDominanceFrontier.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/DebugLoc.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/ProfileData/InstrProf.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include <cassert>
+#include <cstddef>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "adce"
+
+STATISTIC(NumRemoved, "Number of instructions removed");
+STATISTIC(NumBranchesRemoved, "Number of branch instructions removed");
+
+// This is a temporary option until we change the interface to this pass based
+// on optimization level.
+static cl::opt<bool> RemoveControlFlowFlag("adce-remove-control-flow",
+                                           cl::init(true), cl::Hidden);
+
+// This option enables removal of may-be-infinite loops which have no other
+// effect.
+static cl::opt<bool> RemoveLoops("adce-remove-loops", cl::init(false),
+                                 cl::Hidden);
+
+namespace {
+
+/// Information about Instructions
+struct InstInfoType {
+  /// True if the associated instruction is live.
+  bool Live = false;
+
+  /// Quick access to information for block containing associated Instruction.
+  struct BlockInfoType *Block = nullptr;
+};
+
+/// Information about basic blocks relevant to dead code elimination.
+struct BlockInfoType {
+  /// True when this block contains a live instruction.
+  bool Live = false;
+
+  /// True when this block ends in an unconditional branch.
+  bool UnconditionalBranch = false;
+
+  /// True when this block is known to have live PHI nodes.
+  bool HasLivePhiNodes = false;
+
+  /// Control dependence sources need to be live for this block.
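  /// (Illustration: given
  ///      A: br i1 %c, label %B, label %C
  ///  blocks B and C are control dependent on A's terminator, so once either
  ///  of them is found to contain live code, the branch in A must be kept
  ///  alive as well; this flag records that the requirement has been noted.)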
+  bool CFLive = false;
+
+  /// Quick access to the LiveInfo for the terminator,
+  /// holds the value &InstInfo[Terminator]
+  InstInfoType *TerminatorLiveInfo = nullptr;
+
+  /// Corresponding BasicBlock.
+  BasicBlock *BB = nullptr;
+
+  /// Cache of BB->getTerminator().
+  TerminatorInst *Terminator = nullptr;
+
+  /// Post-order numbering of reverse control flow graph.
+  unsigned PostOrder;
+
+  bool terminatorIsLive() const { return TerminatorLiveInfo->Live; }
+};
+
+class AggressiveDeadCodeElimination {
+  Function &F;
+
+  // ADCE does not use DominatorTree per se, but it updates it to preserve the
+  // analysis.
+  DominatorTree &DT;
+  PostDominatorTree &PDT;
+
+  /// Mapping of blocks to associated information, an element in BlockInfoVec.
+  /// Use MapVector to get deterministic iteration order.
+  MapVector<BasicBlock *, BlockInfoType> BlockInfo;
+  bool isLive(BasicBlock *BB) { return BlockInfo[BB].Live; }
+
+  /// Mapping of instructions to associated information.
+  DenseMap<Instruction *, InstInfoType> InstInfo;
+  bool isLive(Instruction *I) { return InstInfo[I].Live; }
+
+  /// Instructions known to be live where we need to mark
+  /// reaching definitions as live.
+  SmallVector<Instruction *, 128> Worklist;
+
+  /// Debug info scopes around a live instruction.
+  SmallPtrSet<const Metadata *, 32> AliveScopes;
+
+  /// Set of blocks not known to have live terminators.
+  SmallPtrSet<BasicBlock *, 16> BlocksWithDeadTerminators;
+
+  /// The set of blocks for which we have determined that the control
+  /// dependence sources must be live and which have not had
+  /// those dependences analyzed.
+  SmallPtrSet<BasicBlock *, 16> NewLiveBlocks;
+
+  /// Set up auxiliary data structures for Instructions and BasicBlocks and
+  /// initialize the Worklist to the set of must-be-live Instructions.
+  void initialize();
+
+  /// Return true for operations which are always treated as live.
+  bool isAlwaysLive(Instruction &I);
+
+  /// Return true for instrumentation instructions for value profiling.
+  bool isInstrumentsConstant(Instruction &I);
+
+  /// Propagate liveness to reaching definitions.
+  void markLiveInstructions();
+
+  /// Mark an instruction as live.
+  void markLive(Instruction *I);
+
+  /// Mark a block as live.
+  void markLive(BlockInfoType &BB);
+  void markLive(BasicBlock *BB) { markLive(BlockInfo[BB]); }
+
+  /// Mark terminators of control predecessors of a PHI node live.
+  void markPhiLive(PHINode *PN);
+
+  /// Record the Debug Scopes which surround live debug information.
+  void collectLiveScopes(const DILocalScope &LS);
+  void collectLiveScopes(const DILocation &DL);
+
+  /// Analyze dead branches to find those whose branches are the sources
+  /// of control dependences impacting a live block. Those branches are
+  /// marked live.
+  void markLiveBranchesFromControlDependences();
+
+  /// Remove instructions not marked live, return true if any instruction
+  /// was removed.
+  bool removeDeadInstructions();
+
+  /// Identify connected sections of the control flow graph which have
+  /// dead terminators and rewrite the control flow graph to remove them.
+  void updateDeadRegions();
+
+  /// Set the BlockInfo::PostOrder field based on a post-order
+  /// numbering of the reverse control flow graph.
+  void computeReversePostOrder();
+
+  /// Make the terminator of this block an unconditional branch to \p Target.
+  void makeUnconditional(BasicBlock *BB, BasicBlock *Target);
+
+public:
+  AggressiveDeadCodeElimination(Function &F, DominatorTree &DT,
+                                PostDominatorTree &PDT)
+      : F(F), DT(DT), PDT(PDT) {}
+
+  bool performDeadCodeElimination();
+};
+
+} // end anonymous namespace
+
+bool AggressiveDeadCodeElimination::performDeadCodeElimination() {
+  initialize();
+  markLiveInstructions();
+  return removeDeadInstructions();
+}
+
+static bool isUnconditionalBranch(TerminatorInst *Term) {
+  auto *BR = dyn_cast<BranchInst>(Term);
+  return BR && BR->isUnconditional();
+}
+
+void AggressiveDeadCodeElimination::initialize() {
+  auto NumBlocks = F.size();
+
+  // We will have an entry in the map for each block so we grow the
+  // structure to twice that size to keep the load factor low in the hash table.
+  BlockInfo.reserve(NumBlocks);
+  size_t NumInsts = 0;
+
+  // Iterate over blocks and initialize BlockInfoVec entries, count
+  // instructions to size the InstInfo hash table.
+  for (auto &BB : F) {
+    NumInsts += BB.size();
+    auto &Info = BlockInfo[&BB];
+    Info.BB = &BB;
+    Info.Terminator = BB.getTerminator();
+    Info.UnconditionalBranch = isUnconditionalBranch(Info.Terminator);
+  }
+
+  // Initialize instruction map and set pointers to block info.
+  InstInfo.reserve(NumInsts);
+  for (auto &BBInfo : BlockInfo)
+    for (Instruction &I : *BBInfo.second.BB)
+      InstInfo[&I].Block = &BBInfo.second;
+
+  // Since BlockInfoVec holds pointers into InstInfo and vice-versa, we may not
+  // add any more elements to either after this point.
+  for (auto &BBInfo : BlockInfo)
+    BBInfo.second.TerminatorLiveInfo = &InstInfo[BBInfo.second.Terminator];
+
+  // Collect the set of "root" instructions that are known live.
+  for (Instruction &I : instructions(F))
+    if (isAlwaysLive(I))
+      markLive(&I);
+
+  if (!RemoveControlFlowFlag)
+    return;
+
+  if (!RemoveLoops) {
+    // This stores state for the depth-first iterator. In addition
+    // to recording which nodes have been visited we also record whether
+    // a node is currently on the "stack" of active ancestors of the current
+    // node.
+    using StatusMap = DenseMap<BasicBlock *, bool>;
+
+    class DFState : public StatusMap {
+    public:
+      std::pair<StatusMap::iterator, bool> insert(BasicBlock *BB) {
+        return StatusMap::insert(std::make_pair(BB, true));
+      }
+
+      // Invoked after we have visited all children of a node.
+      void completed(BasicBlock *BB) { (*this)[BB] = false; }
+
+      // Return true if \p BB is currently on the active stack
+      // of ancestors.
+      bool onStack(BasicBlock *BB) {
+        auto Iter = find(BB);
+        return Iter != end() && Iter->second;
+      }
+    } State;
+
+    State.reserve(F.size());
+    // Iterate over blocks in depth-first pre-order and
+    // treat all edges to a block already seen as loop back edges
+    // and mark the branch live if there is a back edge.
+    for (auto *BB : depth_first_ext(&F.getEntryBlock(), State)) {
+      TerminatorInst *Term = BB->getTerminator();
+      if (isLive(Term))
+        continue;
+
+      for (auto *Succ : successors(BB))
+        if (State.onStack(Succ)) {
+          // back edge....
+          markLive(Term);
+          break;
+        }
+    }
+  }
+
+  // Mark blocks live if there is no path from the block to a
+  // return of the function.
+  // We do this by seeing which of the postdomtree root children exit the
+  // program, and for all others, mark the subtree live.
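  // (Illustration: a block such as
  //
  //    loop:
  //      br label %loop
  //
  //  has no path to a return, so it hangs off a post-dominator tree root
  //  that is not a ReturnInst; the walk below force-marks the terminators
  //  in that subtree live, which is what keeps a possibly-infinite loop
  //  from being deleted as dead.)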
+  for (auto &PDTChild : children<DomTreeNode *>(PDT.getRootNode())) {
+    auto *BB = PDTChild->getBlock();
+    auto &Info = BlockInfo[BB];
+    // Real function return
+    if (isa<ReturnInst>(Info.Terminator)) {
+      DEBUG(dbgs() << "post-dom root child is a return: " << BB->getName()
+                   << '\n';);
+      continue;
+    }
+
+    // This child is something else, like an infinite loop.
+    for (auto DFNode : depth_first(PDTChild))
+      markLive(BlockInfo[DFNode->getBlock()].Terminator);
+  }
+
+  // Treat the entry block as always live
+  auto *BB = &F.getEntryBlock();
+  auto &EntryInfo = BlockInfo[BB];
+  EntryInfo.Live = true;
+  if (EntryInfo.UnconditionalBranch)
+    markLive(EntryInfo.Terminator);
+
+  // Build initial collection of blocks with dead terminators
+  for (auto &BBInfo : BlockInfo)
+    if (!BBInfo.second.terminatorIsLive())
+      BlocksWithDeadTerminators.insert(BBInfo.second.BB);
+}
+
+bool AggressiveDeadCodeElimination::isAlwaysLive(Instruction &I) {
+  // TODO -- use llvm::isInstructionTriviallyDead
+  if (I.isEHPad() || I.mayHaveSideEffects()) {
+    // Skip any value profile instrumentation calls if they are
+    // instrumenting constants.
+    if (isInstrumentsConstant(I))
+      return false;
+    return true;
+  }
+  if (!isa<TerminatorInst>(I))
+    return false;
+  if (RemoveControlFlowFlag && (isa<BranchInst>(I) || isa<SwitchInst>(I)))
+    return false;
+  return true;
+}
+
+// Check if this instruction is a runtime call for value profiling and
+// if it's instrumenting a constant.
+bool AggressiveDeadCodeElimination::isInstrumentsConstant(Instruction &I) {
+  // TODO -- move this test into llvm::isInstructionTriviallyDead
+  if (CallInst *CI = dyn_cast<CallInst>(&I))
+    if (Function *Callee = CI->getCalledFunction())
+      if (Callee->getName().equals(getInstrProfValueProfFuncName()))
+        if (isa<Constant>(CI->getArgOperand(0)))
+          return true;
+  return false;
+}
+
+void AggressiveDeadCodeElimination::markLiveInstructions() {
+  // Propagate liveness backwards to operands.
+  do {
+    // Worklist holds newly discovered live instructions
+    // where we need to mark the inputs as live.
+    while (!Worklist.empty()) {
+      Instruction *LiveInst = Worklist.pop_back_val();
+      DEBUG(dbgs() << "work live: "; LiveInst->dump(););
+
+      for (Use &OI : LiveInst->operands())
+        if (Instruction *Inst = dyn_cast<Instruction>(OI))
+          markLive(Inst);
+
+      if (auto *PN = dyn_cast<PHINode>(LiveInst))
+        markPhiLive(PN);
+    }
+
+    // After data flow liveness has been identified, examine which branch
+    // decisions are required to ensure that live instructions are executed.
+    markLiveBranchesFromControlDependences();
+
+  } while (!Worklist.empty());
+}
+
+void AggressiveDeadCodeElimination::markLive(Instruction *I) {
+  auto &Info = InstInfo[I];
+  if (Info.Live)
+    return;
+
+  DEBUG(dbgs() << "mark live: "; I->dump());
+  Info.Live = true;
+  Worklist.push_back(I);
+
+  // Collect the live debug info scopes attached to this instruction.
+  if (const DILocation *DL = I->getDebugLoc())
+    collectLiveScopes(*DL);
+
+  // Mark the containing block live
+  auto &BBInfo = *Info.Block;
+  if (BBInfo.Terminator == I) {
+    BlocksWithDeadTerminators.erase(BBInfo.BB);
+    // For live terminators, mark destination blocks
+    // live to preserve these control flow edges.
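    // (A minimal standalone sketch of the optimistic mark phase used here,
    //  with hypothetical toy Node/Operands types rather than LLVM IR:
    //
    //    std::vector<Node *> Worklist(Roots.begin(), Roots.end());
    //    std::set<Node *> Live(Roots.begin(), Roots.end());
    //    while (!Worklist.empty()) {
    //      Node *N = Worklist.back();
    //      Worklist.pop_back();
    //      for (Node *Op : N->Operands)    // reaching definitions
    //        if (Live.insert(Op).second)   // newly discovered live node
    //          Worklist.push_back(Op);
    //    }
    //    // anything never marked live is deletable
    //  )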
+ if (!BBInfo.UnconditionalBranch) + for (auto *BB : successors(I->getParent())) + markLive(BB); + } + markLive(BBInfo); +} + +void AggressiveDeadCodeElimination::markLive(BlockInfoType &BBInfo) { + if (BBInfo.Live) + return; + DEBUG(dbgs() << "mark block live: " << BBInfo.BB->getName() << '\n'); + BBInfo.Live = true; + if (!BBInfo.CFLive) { + BBInfo.CFLive = true; + NewLiveBlocks.insert(BBInfo.BB); + } + + // Mark unconditional branches at the end of live + // blocks as live since there is no work to do for them later + if (BBInfo.UnconditionalBranch) + markLive(BBInfo.Terminator); +} + +void AggressiveDeadCodeElimination::collectLiveScopes(const DILocalScope &LS) { + if (!AliveScopes.insert(&LS).second) + return; + + if (isa<DISubprogram>(LS)) + return; + + // Tail-recurse through the scope chain. + collectLiveScopes(cast<DILocalScope>(*LS.getScope())); +} + +void AggressiveDeadCodeElimination::collectLiveScopes(const DILocation &DL) { + // Even though DILocations are not scopes, shove them into AliveScopes so we + // don't revisit them. + if (!AliveScopes.insert(&DL).second) + return; + + // Collect live scopes from the scope chain. + collectLiveScopes(*DL.getScope()); + + // Tail-recurse through the inlined-at chain. + if (const DILocation *IA = DL.getInlinedAt()) + collectLiveScopes(*IA); +} + +void AggressiveDeadCodeElimination::markPhiLive(PHINode *PN) { + auto &Info = BlockInfo[PN->getParent()]; + // Only need to check this once per block. + if (Info.HasLivePhiNodes) + return; + Info.HasLivePhiNodes = true; + + // If a predecessor block is not live, mark it as control-flow live + // which will trigger marking live branches upon which + // that block is control dependent. + for (auto *PredBB : predecessors(Info.BB)) { + auto &Info = BlockInfo[PredBB]; + if (!Info.CFLive) { + Info.CFLive = true; + NewLiveBlocks.insert(PredBB); + } + } +} + +void AggressiveDeadCodeElimination::markLiveBranchesFromControlDependences() { + if (BlocksWithDeadTerminators.empty()) + return; + + DEBUG({ + dbgs() << "new live blocks:\n"; + for (auto *BB : NewLiveBlocks) + dbgs() << "\t" << BB->getName() << '\n'; + dbgs() << "dead terminator blocks:\n"; + for (auto *BB : BlocksWithDeadTerminators) + dbgs() << "\t" << BB->getName() << '\n'; + }); + + // The dominance frontier of a live block X in the reverse + // control graph is the set of blocks upon which X is control + // dependent. The following sequence computes the set of blocks + // which currently have dead terminators that are control + // dependence sources of a block which is in NewLiveBlocks. + + SmallVector<BasicBlock *, 32> IDFBlocks; + ReverseIDFCalculator IDFs(PDT); + IDFs.setDefiningBlocks(NewLiveBlocks); + IDFs.setLiveInBlocks(BlocksWithDeadTerminators); + IDFs.calculate(IDFBlocks); + NewLiveBlocks.clear(); + + // Dead terminators which control live blocks are now marked live. + for (auto *BB : IDFBlocks) { + DEBUG(dbgs() << "live control in: " << BB->getName() << '\n'); + markLive(BB->getTerminator()); + } +} + +//===----------------------------------------------------------------------===// +// +// Routines to update the CFG and SSA information before removing dead code. +// +//===----------------------------------------------------------------------===// +bool AggressiveDeadCodeElimination::removeDeadInstructions() { + // Updates control and dataflow around dead blocks + updateDeadRegions(); + + DEBUG({ + for (Instruction &I : instructions(F)) { + // Check if the instruction is alive. 
+      if (isLive(&I))
+        continue;
+
+      if (auto *DII = dyn_cast<DbgInfoIntrinsic>(&I)) {
+        // Check if the scope of this variable location is alive.
+        if (AliveScopes.count(DII->getDebugLoc()->getScope()))
+          continue;
+
+        // If intrinsic is pointing at a live SSA value, there may be an
+        // earlier optimization bug: if we know the location of the variable,
+        // why isn't the scope of the location alive?
+        if (Value *V = DII->getVariableLocation())
+          if (Instruction *II = dyn_cast<Instruction>(V))
+            if (isLive(II))
+              dbgs() << "Dropping debug info for " << *DII << "\n";
+      }
+    }
+  });
+
+  // The inverse of the live set is the dead set. These are those instructions
+  // that have no side effects and do not influence the control flow or return
+  // value of the function, and may therefore be deleted safely.
+  // NOTE: We reuse the Worklist vector here for memory efficiency.
+  for (Instruction &I : instructions(F)) {
+    // Check if the instruction is alive.
+    if (isLive(&I))
+      continue;
+
+    if (auto *DII = dyn_cast<DbgInfoIntrinsic>(&I)) {
+      // Check if the scope of this variable location is alive.
+      if (AliveScopes.count(DII->getDebugLoc()->getScope()))
+        continue;
+
+      // Fallthrough and drop the intrinsic.
+    }
+
+    // Prepare to delete.
+    Worklist.push_back(&I);
+    I.dropAllReferences();
+  }
+
+  for (Instruction *&I : Worklist) {
+    ++NumRemoved;
+    I->eraseFromParent();
+  }
+
+  return !Worklist.empty();
+}
+
+// A dead region is the set of dead blocks with a common live post-dominator.
+void AggressiveDeadCodeElimination::updateDeadRegions() {
+  DEBUG({
+    dbgs() << "final dead terminator blocks: " << '\n';
+    for (auto *BB : BlocksWithDeadTerminators)
+      dbgs() << '\t' << BB->getName()
+             << (BlockInfo[BB].Live ? " LIVE\n" : "\n");
+  });
+
+  // Don't compute the post ordering unless we need it.
+  bool HavePostOrder = false;
+
+  for (auto *BB : BlocksWithDeadTerminators) {
+    auto &Info = BlockInfo[BB];
+    if (Info.UnconditionalBranch) {
+      InstInfo[Info.Terminator].Live = true;
+      continue;
+    }
+
+    if (!HavePostOrder) {
+      computeReversePostOrder();
+      HavePostOrder = true;
+    }
+
+    // Add an unconditional branch to the successor closest to the
+    // end of the function which ensures a path to the exit for each
+    // live edge.
+    BlockInfoType *PreferredSucc = nullptr;
+    for (auto *Succ : successors(BB)) {
+      auto *Info = &BlockInfo[Succ];
+      if (!PreferredSucc || PreferredSucc->PostOrder < Info->PostOrder)
+        PreferredSucc = Info;
+    }
+    assert((PreferredSucc && PreferredSucc->PostOrder > 0) &&
+           "Failed to find safe successor for dead branch");
+
+    // Collect removed successors to update the (Post)DominatorTrees.
+    SmallPtrSet<BasicBlock *, 4> RemovedSuccessors;
+    bool First = true;
+    for (auto *Succ : successors(BB)) {
+      if (!First || Succ != PreferredSucc->BB) {
+        Succ->removePredecessor(BB);
+        RemovedSuccessors.insert(Succ);
+      } else
+        First = false;
+    }
+    makeUnconditional(BB, PreferredSucc->BB);
+
+    // Inform the dominators about the deleted CFG edges.
+    SmallVector<DominatorTree::UpdateType, 4> DeletedEdges;
+    for (auto *Succ : RemovedSuccessors) {
+      // It might have happened that the same successor appeared multiple times
+      // and the CFG edge wasn't really removed.
+      if (Succ != PreferredSucc->BB) {
+        DEBUG(dbgs() << "ADCE: (Post)DomTree edge enqueued for deletion "
+                     << BB->getName() << " -> " << Succ->getName() << "\n");
+        DeletedEdges.push_back({DominatorTree::Delete, BB, Succ});
+      }
+    }
+
+    DT.applyUpdates(DeletedEdges);
+    PDT.applyUpdates(DeletedEdges);
+
+    NumBranchesRemoved += 1;
+  }
+}
+
+// reverse top-sort order
+void AggressiveDeadCodeElimination::computeReversePostOrder() {
+  // This provides a post-order numbering of the reverse control flow graph.
+  // Note that it is incomplete in the presence of infinite loops but we don't
+  // need to number blocks which don't reach the end of the function since
+  // all branches in those blocks are forced live.
+
+  // For each block without successors, extend the DFS from the block
+  // backward through the graph.
+  SmallPtrSet<BasicBlock *, 16> Visited;
+  unsigned PostOrder = 0;
+  for (auto &BB : F) {
+    if (succ_begin(&BB) != succ_end(&BB))
+      continue;
+    for (BasicBlock *Block : inverse_post_order_ext(&BB, Visited))
+      BlockInfo[Block].PostOrder = PostOrder++;
+  }
+}
+
+void AggressiveDeadCodeElimination::makeUnconditional(BasicBlock *BB,
+                                                      BasicBlock *Target) {
+  TerminatorInst *PredTerm = BB->getTerminator();
+  // Collect the live debug info scopes attached to this instruction.
+  if (const DILocation *DL = PredTerm->getDebugLoc())
+    collectLiveScopes(*DL);
+
+  // Just mark an existing unconditional branch live.
+  if (isUnconditionalBranch(PredTerm)) {
+    PredTerm->setSuccessor(0, Target);
+    InstInfo[PredTerm].Live = true;
+    return;
+  }
+  DEBUG(dbgs() << "making unconditional " << BB->getName() << '\n');
+  NumBranchesRemoved += 1;
+  IRBuilder<> Builder(PredTerm);
+  auto *NewTerm = Builder.CreateBr(Target);
+  InstInfo[NewTerm].Live = true;
+  if (const DILocation *DL = PredTerm->getDebugLoc())
+    NewTerm->setDebugLoc(DL);
+
+  InstInfo.erase(PredTerm);
+  PredTerm->eraseFromParent();
+}
+
+//===----------------------------------------------------------------------===//
+//
+// Pass Manager integration code
+//
+//===----------------------------------------------------------------------===//
+PreservedAnalyses ADCEPass::run(Function &F, FunctionAnalysisManager &FAM) {
+  auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
+  auto &PDT = FAM.getResult<PostDominatorTreeAnalysis>(F);
+  if (!AggressiveDeadCodeElimination(F, DT, PDT).performDeadCodeElimination())
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  PA.preserve<GlobalsAA>();
+  PA.preserve<DominatorTreeAnalysis>();
+  PA.preserve<PostDominatorTreeAnalysis>();
+  return PA;
+}
+
+namespace {
+
+struct ADCELegacyPass : public FunctionPass {
+  static char ID; // Pass identification, replacement for typeid
+
+  ADCELegacyPass() : FunctionPass(ID) {
+    initializeADCELegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    if (skipFunction(F))
+      return false;
+
+    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree();
+    return AggressiveDeadCodeElimination(F, DT, PDT)
+        .performDeadCodeElimination();
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    // We require DominatorTree here only to update and thus preserve it.
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<PostDominatorTreeWrapperPass>();
+    if (!RemoveControlFlowFlag)
+      AU.setPreservesCFG();
+    else {
+      AU.addPreserved<DominatorTreeWrapperPass>();
+      AU.addPreserved<PostDominatorTreeWrapperPass>();
+    }
+    AU.addPreserved<GlobalsAAWrapperPass>();
+  }
+};
+
+} // end anonymous namespace
+
+char ADCELegacyPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(ADCELegacyPass, "adce",
+                      "Aggressive Dead Code Elimination", false, false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass)
+INITIALIZE_PASS_END(ADCELegacyPass, "adce", "Aggressive Dead Code Elimination",
+                    false, false)
+
+FunctionPass *llvm::createAggressiveDCEPass() { return new ADCELegacyPass(); }
diff --git a/contrib/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp b/contrib/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
new file mode 100644
index 000000000000..99480f12da9e
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/AlignmentFromAssumptions.cpp
@@ -0,0 +1,450 @@
+//===----------------------- AlignmentFromAssumptions.cpp -----------------===//
+//                  Set Load/Store Alignments From Assumptions
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a ScalarEvolution-based transformation to set
+// the alignments of loads, stores and memory intrinsics based on the truth
+// expressions of assume intrinsics. The primary motivation is to handle
+// complex alignment assumptions that apply to vector loads and stores that
+// appear after vectorization and unrolling.
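// For reference, this is the kind of pattern the pass consumes; it is what
// clang emits for __builtin_assume_aligned(a, 32):
//
//   %ptrint    = ptrtoint float* %a to i64
//   %maskedptr = and i64 %ptrint, 31
//   %maskcond  = icmp eq i64 %maskedptr, 0
//   call void @llvm.assume(i1 %maskcond)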
+// +//===----------------------------------------------------------------------===// + +#define AA_NAME "alignment-from-assumptions" +#define DEBUG_TYPE AA_NAME +#include "llvm/Transforms/Scalar/AlignmentFromAssumptions.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +using namespace llvm; + +STATISTIC(NumLoadAlignChanged, + "Number of loads changed by alignment assumptions"); +STATISTIC(NumStoreAlignChanged, + "Number of stores changed by alignment assumptions"); +STATISTIC(NumMemIntAlignChanged, + "Number of memory intrinsics changed by alignment assumptions"); + +namespace { +struct AlignmentFromAssumptions : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + AlignmentFromAssumptions() : FunctionPass(ID) { + initializeAlignmentFromAssumptionsPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + + AU.setPreservesCFG(); + AU.addPreserved<AAResultsWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<ScalarEvolutionWrapperPass>(); + } + + AlignmentFromAssumptionsPass Impl; +}; +} + +char AlignmentFromAssumptions::ID = 0; +static const char aip_name[] = "Alignment from assumptions"; +INITIALIZE_PASS_BEGIN(AlignmentFromAssumptions, AA_NAME, + aip_name, false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_END(AlignmentFromAssumptions, AA_NAME, + aip_name, false, false) + +FunctionPass *llvm::createAlignmentFromAssumptionsPass() { + return new AlignmentFromAssumptions(); +} + +// Given an expression for the (constant) alignment, AlignSCEV, and an +// expression for the displacement between a pointer and the aligned address, +// DiffSCEV, compute the alignment of the displaced pointer if it can be reduced +// to a constant. Using SCEV to compute alignment handles the case where +// DiffSCEV is a recurrence with constant start such that the aligned offset +// is constant. e.g. {16,+,32} % 32 -> 16. 
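// (Expanding that example: with AlignSCEV = 32 and DiffSCEV = {16,+,32},
//  the per-iteration displacements are 16, 48, 80, ..., all congruent to
//  16 (mod 32), so the displaced pointer can be treated as 16-byte aligned
//  on every iteration even though the displacement is not a constant.)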
+static unsigned getNewAlignmentDiff(const SCEV *DiffSCEV, + const SCEV *AlignSCEV, + ScalarEvolution *SE) { + // DiffUnits = Diff % int64_t(Alignment) + const SCEV *DiffAlignDiv = SE->getUDivExpr(DiffSCEV, AlignSCEV); + const SCEV *DiffAlign = SE->getMulExpr(DiffAlignDiv, AlignSCEV); + const SCEV *DiffUnitsSCEV = SE->getMinusSCEV(DiffAlign, DiffSCEV); + + DEBUG(dbgs() << "\talignment relative to " << *AlignSCEV << " is " << + *DiffUnitsSCEV << " (diff: " << *DiffSCEV << ")\n"); + + if (const SCEVConstant *ConstDUSCEV = + dyn_cast<SCEVConstant>(DiffUnitsSCEV)) { + int64_t DiffUnits = ConstDUSCEV->getValue()->getSExtValue(); + + // If the displacement is an exact multiple of the alignment, then the + // displaced pointer has the same alignment as the aligned pointer, so + // return the alignment value. + if (!DiffUnits) + return (unsigned) + cast<SCEVConstant>(AlignSCEV)->getValue()->getSExtValue(); + + // If the displacement is not an exact multiple, but the remainder is a + // constant, then return this remainder (but only if it is a power of 2). + uint64_t DiffUnitsAbs = std::abs(DiffUnits); + if (isPowerOf2_64(DiffUnitsAbs)) + return (unsigned) DiffUnitsAbs; + } + + return 0; +} + +// There is an address given by an offset OffSCEV from AASCEV which has an +// alignment AlignSCEV. Use that information, if possible, to compute a new +// alignment for Ptr. +static unsigned getNewAlignment(const SCEV *AASCEV, const SCEV *AlignSCEV, + const SCEV *OffSCEV, Value *Ptr, + ScalarEvolution *SE) { + const SCEV *PtrSCEV = SE->getSCEV(Ptr); + const SCEV *DiffSCEV = SE->getMinusSCEV(PtrSCEV, AASCEV); + + // On 32-bit platforms, DiffSCEV might now have type i32 -- we've always + // sign-extended OffSCEV to i64, so make sure they agree again. + DiffSCEV = SE->getNoopOrSignExtend(DiffSCEV, OffSCEV->getType()); + + // What we really want to know is the overall offset to the aligned + // address. This address is displaced by the provided offset. + DiffSCEV = SE->getMinusSCEV(DiffSCEV, OffSCEV); + + DEBUG(dbgs() << "AFI: alignment of " << *Ptr << " relative to " << + *AlignSCEV << " and offset " << *OffSCEV << + " using diff " << *DiffSCEV << "\n"); + + unsigned NewAlignment = getNewAlignmentDiff(DiffSCEV, AlignSCEV, SE); + DEBUG(dbgs() << "\tnew alignment: " << NewAlignment << "\n"); + + if (NewAlignment) { + return NewAlignment; + } else if (const SCEVAddRecExpr *DiffARSCEV = + dyn_cast<SCEVAddRecExpr>(DiffSCEV)) { + // The relative offset to the alignment assumption did not yield a constant, + // but we should try harder: if we assume that a is 32-byte aligned, then in + // for (i = 0; i < 1024; i += 4) r += a[i]; not all of the loads from a are + // 32-byte aligned, but instead alternate between 32 and 16-byte alignment. + // As a result, the new alignment will not be a constant, but can still + // be improved over the default (of 4) to 16. + + const SCEV *DiffStartSCEV = DiffARSCEV->getStart(); + const SCEV *DiffIncSCEV = DiffARSCEV->getStepRecurrence(*SE); + + DEBUG(dbgs() << "\ttrying start/inc alignment using start " << + *DiffStartSCEV << " and inc " << *DiffIncSCEV << "\n"); + + // Now compute the new alignment using the displacement to the value in the + // first iteration, and also the alignment using the per-iteration delta. + // If these are the same, then use that answer. Otherwise, use the smaller + // one, but only if it divides the larger one. 
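    // (Worked instance of the rule above: with a 32-byte alignment assumption
    //  and DiffSCEV = {0,+,16}, the start yields alignment 32 and the step
    //  yields 16; 16 divides 32, so offsets 0, 16, 32, 48, ... are all at
    //  least 16-byte aligned and 16 is the answer.)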
+ NewAlignment = getNewAlignmentDiff(DiffStartSCEV, AlignSCEV, SE); + unsigned NewIncAlignment = getNewAlignmentDiff(DiffIncSCEV, AlignSCEV, SE); + + DEBUG(dbgs() << "\tnew start alignment: " << NewAlignment << "\n"); + DEBUG(dbgs() << "\tnew inc alignment: " << NewIncAlignment << "\n"); + + if (!NewAlignment || !NewIncAlignment) { + return 0; + } else if (NewAlignment > NewIncAlignment) { + if (NewAlignment % NewIncAlignment == 0) { + DEBUG(dbgs() << "\tnew start/inc alignment: " << + NewIncAlignment << "\n"); + return NewIncAlignment; + } + } else if (NewIncAlignment > NewAlignment) { + if (NewIncAlignment % NewAlignment == 0) { + DEBUG(dbgs() << "\tnew start/inc alignment: " << + NewAlignment << "\n"); + return NewAlignment; + } + } else if (NewIncAlignment == NewAlignment) { + DEBUG(dbgs() << "\tnew start/inc alignment: " << + NewAlignment << "\n"); + return NewAlignment; + } + } + + return 0; +} + +bool AlignmentFromAssumptionsPass::extractAlignmentInfo(CallInst *I, + Value *&AAPtr, + const SCEV *&AlignSCEV, + const SCEV *&OffSCEV) { + // An alignment assume must be a statement about the least-significant + // bits of the pointer being zero, possibly with some offset. + ICmpInst *ICI = dyn_cast<ICmpInst>(I->getArgOperand(0)); + if (!ICI) + return false; + + // This must be an expression of the form: x & m == 0. + if (ICI->getPredicate() != ICmpInst::ICMP_EQ) + return false; + + // Swap things around so that the RHS is 0. + Value *CmpLHS = ICI->getOperand(0); + Value *CmpRHS = ICI->getOperand(1); + const SCEV *CmpLHSSCEV = SE->getSCEV(CmpLHS); + const SCEV *CmpRHSSCEV = SE->getSCEV(CmpRHS); + if (CmpLHSSCEV->isZero()) + std::swap(CmpLHS, CmpRHS); + else if (!CmpRHSSCEV->isZero()) + return false; + + BinaryOperator *CmpBO = dyn_cast<BinaryOperator>(CmpLHS); + if (!CmpBO || CmpBO->getOpcode() != Instruction::And) + return false; + + // Swap things around so that the right operand of the and is a constant + // (the mask); we cannot deal with variable masks. + Value *AndLHS = CmpBO->getOperand(0); + Value *AndRHS = CmpBO->getOperand(1); + const SCEV *AndLHSSCEV = SE->getSCEV(AndLHS); + const SCEV *AndRHSSCEV = SE->getSCEV(AndRHS); + if (isa<SCEVConstant>(AndLHSSCEV)) { + std::swap(AndLHS, AndRHS); + std::swap(AndLHSSCEV, AndRHSSCEV); + } + + const SCEVConstant *MaskSCEV = dyn_cast<SCEVConstant>(AndRHSSCEV); + if (!MaskSCEV) + return false; + + // The mask must have some trailing ones (otherwise the condition is + // trivial and tells us nothing about the alignment of the left operand). + unsigned TrailingOnes = MaskSCEV->getAPInt().countTrailingOnes(); + if (!TrailingOnes) + return false; + + // Cap the alignment at the maximum with which LLVM can deal (and make sure + // we don't overflow the shift). + uint64_t Alignment; + TrailingOnes = std::min(TrailingOnes, + unsigned(sizeof(unsigned) * CHAR_BIT - 1)); + Alignment = std::min(1u << TrailingOnes, +Value::MaximumAlignment); + + Type *Int64Ty = Type::getInt64Ty(I->getParent()->getParent()->getContext()); + AlignSCEV = SE->getConstant(Int64Ty, Alignment); + + // The LHS might be a ptrtoint instruction, or it might be the pointer + // with an offset. + AAPtr = nullptr; + OffSCEV = nullptr; + if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(AndLHS)) { + AAPtr = PToI->getPointerOperand(); + OffSCEV = SE->getZero(Int64Ty); + } else if (const SCEVAddExpr* AndLHSAddSCEV = + dyn_cast<SCEVAddExpr>(AndLHSSCEV)) { + // Try to find the ptrtoint; subtract it and the rest is the offset. 
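    // (For example, a condition of the form ((ptrtoint %p) + 8) & 31 == 0,
    //  i.e. "%p + 8 is 32-byte aligned", takes this path: AAPtr becomes %p
    //  and OffSCEV becomes 8.)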
+ for (SCEVAddExpr::op_iterator J = AndLHSAddSCEV->op_begin(), + JE = AndLHSAddSCEV->op_end(); J != JE; ++J) + if (const SCEVUnknown *OpUnk = dyn_cast<SCEVUnknown>(*J)) + if (PtrToIntInst *PToI = dyn_cast<PtrToIntInst>(OpUnk->getValue())) { + AAPtr = PToI->getPointerOperand(); + OffSCEV = SE->getMinusSCEV(AndLHSAddSCEV, *J); + break; + } + } + + if (!AAPtr) + return false; + + // Sign extend the offset to 64 bits (so that it is like all of the other + // expressions). + unsigned OffSCEVBits = OffSCEV->getType()->getPrimitiveSizeInBits(); + if (OffSCEVBits < 64) + OffSCEV = SE->getSignExtendExpr(OffSCEV, Int64Ty); + else if (OffSCEVBits > 64) + return false; + + AAPtr = AAPtr->stripPointerCasts(); + return true; +} + +bool AlignmentFromAssumptionsPass::processAssumption(CallInst *ACall) { + Value *AAPtr; + const SCEV *AlignSCEV, *OffSCEV; + if (!extractAlignmentInfo(ACall, AAPtr, AlignSCEV, OffSCEV)) + return false; + + // Skip ConstantPointerNull and UndefValue. Assumptions on these shouldn't + // affect other users. + if (isa<ConstantData>(AAPtr)) + return false; + + const SCEV *AASCEV = SE->getSCEV(AAPtr); + + // Apply the assumption to all other users of the specified pointer. + SmallPtrSet<Instruction *, 32> Visited; + SmallVector<Instruction*, 16> WorkList; + for (User *J : AAPtr->users()) { + if (J == ACall) + continue; + + if (Instruction *K = dyn_cast<Instruction>(J)) + if (isValidAssumeForContext(ACall, K, DT)) + WorkList.push_back(K); + } + + while (!WorkList.empty()) { + Instruction *J = WorkList.pop_back_val(); + + if (LoadInst *LI = dyn_cast<LoadInst>(J)) { + unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + LI->getPointerOperand(), SE); + + if (NewAlignment > LI->getAlignment()) { + LI->setAlignment(NewAlignment); + ++NumLoadAlignChanged; + } + } else if (StoreInst *SI = dyn_cast<StoreInst>(J)) { + unsigned NewAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + SI->getPointerOperand(), SE); + + if (NewAlignment > SI->getAlignment()) { + SI->setAlignment(NewAlignment); + ++NumStoreAlignChanged; + } + } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(J)) { + unsigned NewDestAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + MI->getDest(), SE); + + // For memory transfers, we need a common alignment for both the + // source and destination. If we have a new alignment for this + // instruction, but only for one operand, save it. If we reach the + // other operand through another assumption later, then we may + // change the alignment at that point. + if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI)) { + unsigned NewSrcAlignment = getNewAlignment(AASCEV, AlignSCEV, OffSCEV, + MTI->getSource(), SE); + + DenseMap<MemTransferInst *, unsigned>::iterator DI = + NewDestAlignments.find(MTI); + unsigned AltDestAlignment = (DI == NewDestAlignments.end()) ? + 0 : DI->second; + + DenseMap<MemTransferInst *, unsigned>::iterator SI = + NewSrcAlignments.find(MTI); + unsigned AltSrcAlignment = (SI == NewSrcAlignments.end()) ? + 0 : SI->second; + + DEBUG(dbgs() << "\tmem trans: " << NewDestAlignment << " " << + AltDestAlignment << " " << NewSrcAlignment << + " " << AltSrcAlignment << "\n"); + + // Of these four alignments, pick the largest possible... 
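        // (Worked instance: with NewDestAlignment = 32, NewSrcAlignment = 16
        //  and no previously recorded alternates (both 0), 32 <= max(16, 0)
        //  fails and 32 is rejected, while 16 <= max(32, 0) holds, leaving a
        //  common alignment of 16 for the transfer.)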
+ unsigned NewAlignment = 0; + if (NewDestAlignment <= std::max(NewSrcAlignment, AltSrcAlignment)) + NewAlignment = std::max(NewAlignment, NewDestAlignment); + if (AltDestAlignment <= std::max(NewSrcAlignment, AltSrcAlignment)) + NewAlignment = std::max(NewAlignment, AltDestAlignment); + if (NewSrcAlignment <= std::max(NewDestAlignment, AltDestAlignment)) + NewAlignment = std::max(NewAlignment, NewSrcAlignment); + if (AltSrcAlignment <= std::max(NewDestAlignment, AltDestAlignment)) + NewAlignment = std::max(NewAlignment, AltSrcAlignment); + + if (NewAlignment > MI->getAlignment()) { + MI->setAlignment(ConstantInt::get(Type::getInt32Ty( + MI->getParent()->getContext()), NewAlignment)); + ++NumMemIntAlignChanged; + } + + NewDestAlignments.insert(std::make_pair(MTI, NewDestAlignment)); + NewSrcAlignments.insert(std::make_pair(MTI, NewSrcAlignment)); + } else if (NewDestAlignment > MI->getAlignment()) { + assert((!isa<MemIntrinsic>(MI) || isa<MemSetInst>(MI)) && + "Unknown memory intrinsic"); + + MI->setAlignment(ConstantInt::get(Type::getInt32Ty( + MI->getParent()->getContext()), NewDestAlignment)); + ++NumMemIntAlignChanged; + } + } + + // Now that we've updated that use of the pointer, look for other uses of + // the pointer to update. + Visited.insert(J); + for (User *UJ : J->users()) { + Instruction *K = cast<Instruction>(UJ); + if (!Visited.count(K) && isValidAssumeForContext(ACall, K, DT)) + WorkList.push_back(K); + } + } + + return true; +} + +bool AlignmentFromAssumptions::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + + return Impl.runImpl(F, AC, SE, DT); +} + +bool AlignmentFromAssumptionsPass::runImpl(Function &F, AssumptionCache &AC, + ScalarEvolution *SE_, + DominatorTree *DT_) { + SE = SE_; + DT = DT_; + + NewDestAlignments.clear(); + NewSrcAlignments.clear(); + + bool Changed = false; + for (auto &AssumeVH : AC.assumptions()) + if (AssumeVH) + Changed |= processAssumption(cast<CallInst>(AssumeVH)); + + return Changed; +} + +PreservedAnalyses +AlignmentFromAssumptionsPass::run(Function &F, FunctionAnalysisManager &AM) { + + AssumptionCache &AC = AM.getResult<AssumptionAnalysis>(F); + ScalarEvolution &SE = AM.getResult<ScalarEvolutionAnalysis>(F); + DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + if (!runImpl(F, AC, &SE, &DT)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + PA.preserve<AAManager>(); + PA.preserve<ScalarEvolutionAnalysis>(); + PA.preserve<GlobalsAA>(); + return PA; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/BDCE.cpp b/contrib/llvm/lib/Transforms/Scalar/BDCE.cpp new file mode 100644 index 000000000000..851efa000f65 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/BDCE.cpp @@ -0,0 +1,170 @@ +//===---- BDCE.cpp - Bit-tracking dead code elimination -------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Bit-Tracking Dead Code Elimination pass. Some +// instructions (shifts, some ands, ors, etc.) kill some of their input bits. 
+// We track these dead bits and remove instructions that compute only these +// dead bits. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/BDCE.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/DemandedBits.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +using namespace llvm; + +#define DEBUG_TYPE "bdce" + +STATISTIC(NumRemoved, "Number of instructions removed (unused)"); +STATISTIC(NumSimplified, "Number of instructions trivialized (dead bits)"); + +/// If an instruction is trivialized (dead), then the chain of users of that +/// instruction may need to be cleared of assumptions that can no longer be +/// guaranteed correct. +static void clearAssumptionsOfUsers(Instruction *I, DemandedBits &DB) { + assert(I->getType()->isIntegerTy() && "Trivializing a non-integer value?"); + + // Initialize the worklist with eligible direct users. + SmallVector<Instruction *, 16> WorkList; + for (User *JU : I->users()) { + // If all bits of a user are demanded, then we know that nothing below that + // in the def-use chain needs to be changed. + auto *J = dyn_cast<Instruction>(JU); + if (J && J->getType()->isSized() && + !DB.getDemandedBits(J).isAllOnesValue()) + WorkList.push_back(J); + + // Note that we need to check for unsized types above before asking for + // demanded bits. Normally, the only way to reach an instruction with an + // unsized type is via an instruction that has side effects (or otherwise + // will demand its input bits). However, if we have a readnone function + // that returns an unsized type (e.g., void), we must avoid asking for the + // demanded bits of the function call's return value. A void-returning + // readnone function is always dead (and so we can stop walking the use/def + // chain here), but the check is necessary to avoid asserting. + } + + // DFS through subsequent users while tracking visits to avoid cycles. + SmallPtrSet<Instruction *, 16> Visited; + while (!WorkList.empty()) { + Instruction *J = WorkList.pop_back_val(); + + // NSW, NUW, and exact are based on operands that might have changed. + J->dropPoisonGeneratingFlags(); + + // We do not have to worry about llvm.assume or range metadata: + // 1. llvm.assume demands its operand, so trivializing can't change it. + // 2. range metadata only applies to memory accesses which demand all bits. + + Visited.insert(J); + + for (User *KU : J->users()) { + // If all bits of a user are demanded, then we know that nothing below + // that in the def-use chain needs to be changed. + auto *K = dyn_cast<Instruction>(KU); + if (K && !Visited.count(K) && K->getType()->isSized() && + !DB.getDemandedBits(K).isAllOnesValue()) + WorkList.push_back(K); + } + } +} + +static bool bitTrackingDCE(Function &F, DemandedBits &DB) { + SmallVector<Instruction*, 128> Worklist; + bool Changed = false; + for (Instruction &I : instructions(F)) { + // If the instruction has side effects and no non-dbg uses, + // skip it. This way we avoid computing known bits on an instruction + // that will not help us. 
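    // (A deliberately tiny case of the trivialization below:
    //
    //    %a = or i32 %x, %y
    //    %b = and i32 %a, 0    ; no bits of %a are demanded
    //
    //  DemandedBits reports zero demanded bits for %a, so %a's uses are
    //  replaced with the constant 0 and %a itself is then removed as dead.)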
+    if (I.mayHaveSideEffects() && I.use_empty())
+      continue;
+
+    if (I.getType()->isIntegerTy() &&
+        !DB.getDemandedBits(&I).getBoolValue()) {
+      // For live instructions that have all dead bits, first make them dead by
+      // replacing all uses with something else. Then, if they don't need to
+      // remain live (because they have side effects, etc.) we can remove them.
+      DEBUG(dbgs() << "BDCE: Trivializing: " << I << " (all bits dead)\n");
+
+      clearAssumptionsOfUsers(&I, DB);
+
+      // FIXME: In theory we could substitute undef here instead of zero.
+      // This should be reconsidered once we settle on the semantics of
+      // undef, poison, etc.
+      Value *Zero = ConstantInt::get(I.getType(), 0);
+      ++NumSimplified;
+      I.replaceNonMetadataUsesWith(Zero);
+      Changed = true;
+    }
+    if (!DB.isInstructionDead(&I))
+      continue;
+
+    Worklist.push_back(&I);
+    I.dropAllReferences();
+    Changed = true;
+  }
+
+  for (Instruction *&I : Worklist) {
+    ++NumRemoved;
+    I->eraseFromParent();
+  }
+
+  return Changed;
+}
+
+PreservedAnalyses BDCEPass::run(Function &F, FunctionAnalysisManager &AM) {
+  auto &DB = AM.getResult<DemandedBitsAnalysis>(F);
+  if (!bitTrackingDCE(F, DB))
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  PA.preserve<GlobalsAA>();
+  return PA;
+}
+
+namespace {
+struct BDCELegacyPass : public FunctionPass {
+  static char ID; // Pass identification, replacement for typeid
+  BDCELegacyPass() : FunctionPass(ID) {
+    initializeBDCELegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    if (skipFunction(F))
+      return false;
+    auto &DB = getAnalysis<DemandedBitsWrapperPass>().getDemandedBits();
+    return bitTrackingDCE(F, DB);
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    AU.addRequired<DemandedBitsWrapperPass>();
+    AU.addPreserved<GlobalsAAWrapperPass>();
+  }
+};
+}
+
+char BDCELegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(BDCELegacyPass, "bdce",
+                      "Bit-Tracking Dead Code Elimination", false, false)
+INITIALIZE_PASS_DEPENDENCY(DemandedBitsWrapperPass)
+INITIALIZE_PASS_END(BDCELegacyPass, "bdce",
+                    "Bit-Tracking Dead Code Elimination", false, false)
+
+FunctionPass *llvm::createBitTrackingDCEPass() { return new BDCELegacyPass(); }
diff --git a/contrib/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/contrib/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
new file mode 100644
index 000000000000..4edea7cc3c82
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -0,0 +1,418 @@
+//===- CallSiteSplitting.cpp ----------------------------------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a transformation that tries to split a call-site to
+// pass more constrained arguments if one of its arguments is predicated in
+// the control flow, so that we can expose better context to the later passes
+// (e.g., inliner, jump threading, or IPA-CP based function cloning, etc.).
+// As of now we support two cases :
+//
+// 1) Try to split a call-site with constrained arguments, if any constraints
+// on any argument can be found by following the single predecessors of all
+// the site's predecessors. Currently this pass only handles call-sites with 2
+// predecessors.
+// For example, in the code below, we try to split the call-site
+// since we can predicate the argument (ptr) based on the OR condition.
+//
+// Split from :
+//   if (!ptr || c)
+//     callee(ptr);
+// to :
+//   if (!ptr)
+//     callee(null)         // set the known constant value
+//   else if (c)
+//     callee(nonnull ptr)  // set non-null attribute in the argument
+//
+// 2) We can also split a call-site based on constant incoming values of a PHI
+// For example,
+// from :
+//   Header:
+//     %c = icmp eq i32 %i1, %i2
+//     br i1 %c, label %Tail, label %TBB
+//   TBB:
+//     br label %Tail
+//   Tail:
+//     %p = phi i32 [ 0, %Header], [ 1, %TBB]
+//     call void @bar(i32 %p)
+// to
+//   Header:
+//     %c = icmp eq i32 %i1, %i2
+//     br i1 %c, label %Tail-split0, label %TBB
+//   TBB:
+//     br label %Tail-split1
+//   Tail-split0:
+//     call void @bar(i32 0)
+//     br label %Tail
+//   Tail-split1:
+//     call void @bar(i32 1)
+//     br label %Tail
+//   Tail:
+//     %p = phi i32 [ 0, %Tail-split0 ], [ 1, %Tail-split1 ]
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/CallSiteSplitting.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "callsite-splitting"
+
+STATISTIC(NumCallSiteSplit, "Number of call-sites split");
+
+static void addNonNullAttribute(Instruction *CallI, Instruction *NewCallI,
+                                Value *Op) {
+  CallSite CS(NewCallI);
+  unsigned ArgNo = 0;
+  for (auto &I : CS.args()) {
+    if (&*I == Op)
+      CS.addParamAttr(ArgNo, Attribute::NonNull);
+    ++ArgNo;
+  }
+}
+
+static void setConstantInArgument(Instruction *CallI, Instruction *NewCallI,
+                                  Value *Op, Constant *ConstValue) {
+  CallSite CS(NewCallI);
+  unsigned ArgNo = 0;
+  for (auto &I : CS.args()) {
+    if (&*I == Op)
+      CS.setArgument(ArgNo, ConstValue);
+    ++ArgNo;
+  }
+}
+
+static bool isCondRelevantToAnyCallArgument(ICmpInst *Cmp, CallSite CS) {
+  assert(isa<Constant>(Cmp->getOperand(1)) && "Expected a constant operand.");
+  Value *Op0 = Cmp->getOperand(0);
+  unsigned ArgNo = 0;
+  for (CallSite::arg_iterator I = CS.arg_begin(), E = CS.arg_end(); I != E;
+       ++I, ++ArgNo) {
+    // Don't consider constant or arguments that are already known non-null.
+    if (isa<Constant>(*I) || CS.paramHasAttr(ArgNo, Attribute::NonNull))
+      continue;
+
+    if (*I == Op0)
+      return true;
+  }
+  return false;
+}
+
+/// If From has a conditional jump to To, add the condition to Conditions,
+/// if it is relevant to any argument at CS.
+static void
+recordCondition(const CallSite &CS, BasicBlock *From, BasicBlock *To,
+                SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) {
+  auto *BI = dyn_cast<BranchInst>(From->getTerminator());
+  if (!BI || !BI->isConditional())
+    return;
+
+  CmpInst::Predicate Pred;
+  Value *Cond = BI->getCondition();
+  if (!match(Cond, m_ICmp(Pred, m_Value(), m_Constant())))
+    return;
+
+  ICmpInst *Cmp = cast<ICmpInst>(Cond);
+  if (Pred == ICmpInst::ICMP_EQ || Pred == ICmpInst::ICMP_NE)
+    if (isCondRelevantToAnyCallArgument(Cmp, CS))
+      Conditions.push_back({Cmp, From->getTerminator()->getSuccessor(0) == To
+                                     ? Pred
+                                     : Cmp->getInversePredicate()});
+}
+
+/// Record ICmp conditions relevant to any argument in CS following Pred's
+/// single predecessors.
If there are conflicting conditions along a path, like +/// x == 1 and x == 0, the first condition will be used. +static void +recordConditions(const CallSite &CS, BasicBlock *Pred, + SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) { + recordCondition(CS, Pred, CS.getInstruction()->getParent(), Conditions); + BasicBlock *From = Pred; + BasicBlock *To = Pred; + SmallPtrSet<BasicBlock *, 4> Visited; + while (!Visited.count(From->getSinglePredecessor()) && + (From = From->getSinglePredecessor())) { + recordCondition(CS, From, To, Conditions); + Visited.insert(From); + To = From; + } +} + +static Instruction * +addConditions(CallSite &CS, + SmallVectorImpl<std::pair<ICmpInst *, unsigned>> &Conditions) { + if (Conditions.empty()) + return nullptr; + + Instruction *NewCI = CS.getInstruction()->clone(); + for (auto &Cond : Conditions) { + Value *Arg = Cond.first->getOperand(0); + Constant *ConstVal = cast<Constant>(Cond.first->getOperand(1)); + if (Cond.second == ICmpInst::ICMP_EQ) + setConstantInArgument(CS.getInstruction(), NewCI, Arg, ConstVal); + else if (ConstVal->getType()->isPointerTy() && ConstVal->isNullValue()) { + assert(Cond.second == ICmpInst::ICMP_NE); + addNonNullAttribute(CS.getInstruction(), NewCI, Arg); + } + } + return NewCI; +} + +static SmallVector<BasicBlock *, 2> getTwoPredecessors(BasicBlock *BB) { + SmallVector<BasicBlock *, 2> Preds(predecessors((BB))); + assert(Preds.size() == 2 && "Expected exactly 2 predecessors!"); + return Preds; +} + +static bool canSplitCallSite(CallSite CS) { + // FIXME: As of now we handle only CallInst. InvokeInst could be handled + // without too much effort. + Instruction *Instr = CS.getInstruction(); + if (!isa<CallInst>(Instr)) + return false; + + // Allow splitting a call-site only when there is no instruction before the + // call-site in the basic block. Based on this constraint, we only clone the + // call instruction, and we do not move a call-site across any other + // instruction. + BasicBlock *CallSiteBB = Instr->getParent(); + if (Instr != CallSiteBB->getFirstNonPHIOrDbg()) + return false; + + // Need 2 predecessors and cannot split an edge from an IndirectBrInst. + SmallVector<BasicBlock *, 2> Preds(predecessors(CallSiteBB)); + if (Preds.size() != 2 || isa<IndirectBrInst>(Preds[0]->getTerminator()) || + isa<IndirectBrInst>(Preds[1]->getTerminator())) + return false; + + return CallSiteBB->canSplitPredecessors(); +} + +/// Return true if the CS is split into its new predecessors which are directly +/// hooked to each of its original predecessors pointed by PredBB1 and PredBB2. +/// CallInst1 and CallInst2 will be the new call-sites placed in the new +/// predecessors split for PredBB1 and PredBB2, respectively. +/// For example, in the IR below with an OR condition, the call-site can +/// be split. Assuming PredBB1=Header and PredBB2=TBB, CallInst1 will be the +/// call-site placed between Header and Tail, and CallInst2 will be the +/// call-site between TBB and Tail. 
+///
+/// From :
+///
+///   Header:
+///     %c = icmp eq i32* %a, null
+///     br i1 %c %Tail, %TBB
+///   TBB:
+///     %c2 = icmp eq i32* %b, null
+///     br i1 %c2 %Tail, %End
+///   Tail:
+///     %ca = call i1 @callee (i32* %a, i32* %b)
+///
+/// to :
+///
+///   Header:                          // PredBB1 is Header
+///     %c = icmp eq i32* %a, null
+///     br i1 %c %Tail-split1, %TBB
+///   TBB:                             // PredBB2 is TBB
+///     %c2 = icmp eq i32* %b, null
+///     br i1 %c2 %Tail-split2, %End
+///   Tail-split1:
+///     %ca1 = call @callee (i32* null, i32* %b)          // CallInst1
+///     br %Tail
+///   Tail-split2:
+///     %ca2 = call @callee (i32* nonnull %a, i32* null)  // CallInst2
+///     br %Tail
+///   Tail:
+///     %p = phi i1 [%ca1, %Tail-split1],[%ca2, %Tail-split2]
+///
+/// Note that in case any arguments at the call-site are constrained by its
+/// predecessors, new call-sites with more constrained arguments will be
+/// created in createCallSitesOnPredicatedArgument().
+static void splitCallSite(CallSite CS, BasicBlock *PredBB1, BasicBlock *PredBB2,
+                          Instruction *CallInst1, Instruction *CallInst2) {
+  Instruction *Instr = CS.getInstruction();
+  BasicBlock *TailBB = Instr->getParent();
+  assert(Instr == (TailBB->getFirstNonPHIOrDbg()) && "Unexpected call-site");
+
+  BasicBlock *SplitBlock1 =
+      SplitBlockPredecessors(TailBB, PredBB1, ".predBB1.split");
+  BasicBlock *SplitBlock2 =
+      SplitBlockPredecessors(TailBB, PredBB2, ".predBB2.split");
+
+  assert((SplitBlock1 && SplitBlock2) && "Unexpected new basic block split.");
+
+  if (!CallInst1)
+    CallInst1 = Instr->clone();
+  if (!CallInst2)
+    CallInst2 = Instr->clone();
+
+  CallInst1->insertBefore(&*SplitBlock1->getFirstInsertionPt());
+  CallInst2->insertBefore(&*SplitBlock2->getFirstInsertionPt());
+
+  CallSite CS1(CallInst1);
+  CallSite CS2(CallInst2);
+
+  // Handle PHIs used as arguments in the call-site.
+  for (PHINode &PN : TailBB->phis()) {
+    unsigned ArgNo = 0;
+    for (auto &CI : CS.args()) {
+      if (&*CI == &PN) {
+        CS1.setArgument(ArgNo, PN.getIncomingValueForBlock(SplitBlock1));
+        CS2.setArgument(ArgNo, PN.getIncomingValueForBlock(SplitBlock2));
+      }
+      ++ArgNo;
+    }
+  }
+
+  // Replace users of the original call with a PHI merging the split
+  // call-sites.
+  if (Instr->getNumUses()) {
+    PHINode *PN = PHINode::Create(Instr->getType(), 2, "phi.call",
+                                  TailBB->getFirstNonPHI());
+    PN->addIncoming(CallInst1, SplitBlock1);
+    PN->addIncoming(CallInst2, SplitBlock2);
+    Instr->replaceAllUsesWith(PN);
+  }
+  DEBUG(dbgs() << "split call-site : " << *Instr << " into \n");
+  DEBUG(dbgs() << "    " << *CallInst1 << " in " << SplitBlock1->getName()
+               << "\n");
+  DEBUG(dbgs() << "    " << *CallInst2 << " in " << SplitBlock2->getName()
+               << "\n");
+  Instr->eraseFromParent();
+  NumCallSiteSplit++;
+}
+
+// Return true if the call-site has an argument which is a PHI with only
+// constant incoming values.
+static bool isPredicatedOnPHI(CallSite CS) { + Instruction *Instr = CS.getInstruction(); + BasicBlock *Parent = Instr->getParent(); + if (Instr != Parent->getFirstNonPHIOrDbg()) + return false; + + for (auto &BI : *Parent) { + if (PHINode *PN = dyn_cast<PHINode>(&BI)) { + for (auto &I : CS.args()) + if (&*I == PN) { + assert(PN->getNumIncomingValues() == 2 && + "Unexpected number of incoming values"); + if (PN->getIncomingBlock(0) == PN->getIncomingBlock(1)) + return false; + if (PN->getIncomingValue(0) == PN->getIncomingValue(1)) + continue; + if (isa<Constant>(PN->getIncomingValue(0)) && + isa<Constant>(PN->getIncomingValue(1))) + return true; + } + } + break; + } + return false; +} + +static bool tryToSplitOnPHIPredicatedArgument(CallSite CS) { + if (!isPredicatedOnPHI(CS)) + return false; + + auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); + splitCallSite(CS, Preds[0], Preds[1], nullptr, nullptr); + return true; +} + +static bool tryToSplitOnPredicatedArgument(CallSite CS) { + auto Preds = getTwoPredecessors(CS.getInstruction()->getParent()); + if (Preds[0] == Preds[1]) + return false; + + SmallVector<std::pair<ICmpInst *, unsigned>, 2> C1, C2; + recordConditions(CS, Preds[0], C1); + recordConditions(CS, Preds[1], C2); + + Instruction *CallInst1 = addConditions(CS, C1); + Instruction *CallInst2 = addConditions(CS, C2); + if (!CallInst1 && !CallInst2) + return false; + + splitCallSite(CS, Preds[1], Preds[0], CallInst2, CallInst1); + return true; +} + +static bool tryToSplitCallSite(CallSite CS) { + if (!CS.arg_size() || !canSplitCallSite(CS)) + return false; + return tryToSplitOnPredicatedArgument(CS) || + tryToSplitOnPHIPredicatedArgument(CS); +} + +static bool doCallSiteSplitting(Function &F, TargetLibraryInfo &TLI) { + bool Changed = false; + for (Function::iterator BI = F.begin(), BE = F.end(); BI != BE;) { + BasicBlock &BB = *BI++; + for (BasicBlock::iterator II = BB.begin(), IE = BB.end(); II != IE;) { + Instruction *I = &*II++; + CallSite CS(cast<Value>(I)); + if (!CS || isa<IntrinsicInst>(I) || isInstructionTriviallyDead(I, &TLI)) + continue; + + Function *Callee = CS.getCalledFunction(); + if (!Callee || Callee->isDeclaration()) + continue; + Changed |= tryToSplitCallSite(CS); + } + } + return Changed; +} + +namespace { +struct CallSiteSplittingLegacyPass : public FunctionPass { + static char ID; + CallSiteSplittingLegacyPass() : FunctionPass(ID) { + initializeCallSiteSplittingLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + FunctionPass::getAnalysisUsage(AU); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + return doCallSiteSplitting(F, TLI); + } +}; +} // namespace + +char CallSiteSplittingLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(CallSiteSplittingLegacyPass, "callsite-splitting", + "Call-site splitting", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(CallSiteSplittingLegacyPass, "callsite-splitting", + "Call-site splitting", false, false) +FunctionPass *llvm::createCallSiteSplittingPass() { + return new CallSiteSplittingLegacyPass(); +} + +PreservedAnalyses CallSiteSplittingPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + + if (!doCallSiteSplitting(F, TLI)) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + 
return PA;
+}
diff --git a/contrib/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp b/contrib/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
new file mode 100644
index 000000000000..e4b08c5ed305
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/ConstantHoisting.cpp
@@ -0,0 +1,822 @@
+//===- ConstantHoisting.cpp - Prepare code for expensive constants --------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass identifies expensive constants to hoist and coalesces them to
+// better prepare the function for SelectionDAG-based code generation. This
+// works around the limitations of the basic-block-at-a-time approach.
+//
+// First it scans all instructions for integer constants and calculates their
+// cost. If the constant can be folded into the instruction (the cost is
+// TCC_Free) or the cost is just a simple operation (TCC_Basic), then we don't
+// consider it expensive and leave it alone. This is the default behavior and
+// the default implementation of getIntImmCost will always return TCC_Free.
+//
+// If the cost is more than TCC_Basic, then the integer constant can't be
+// folded into the instruction and it might be beneficial to hoist the constant.
+// Similar constants are coalesced to reduce register pressure and
+// materialization code.
+//
+// When a constant is hoisted, it is also hidden behind a bitcast to force it to
+// be live-out of the basic block. Otherwise the constant would be just
+// duplicated and each basic block would have its own copy in the SelectionDAG.
+// The SelectionDAG recognizes such constants as opaque and doesn't perform
+// certain transformations on them, which would create a new expensive constant.
+//
+// This optimization is only applied to integer constants in instructions and
+// simple (this means not nested) constant cast expressions.
For example: +// %0 = load i64* inttoptr (i64 big_constant to i64*) +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/ConstantHoisting.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/BlockFrequency.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <iterator> +#include <tuple> +#include <utility> + +using namespace llvm; +using namespace consthoist; + +#define DEBUG_TYPE "consthoist" + +STATISTIC(NumConstantsHoisted, "Number of constants hoisted"); +STATISTIC(NumConstantsRebased, "Number of constants rebased"); + +static cl::opt<bool> ConstHoistWithBlockFrequency( + "consthoist-with-block-frequency", cl::init(true), cl::Hidden, + cl::desc("Enable the use of the block frequency analysis to reduce the " + "chance to execute const materialization more frequently than " + "without hoisting.")); + +namespace { + +/// \brief The constant hoisting pass. +class ConstantHoistingLegacyPass : public FunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + + ConstantHoistingLegacyPass() : FunctionPass(ID) { + initializeConstantHoistingLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &Fn) override; + + StringRef getPassName() const override { return "Constant Hoisting"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + if (ConstHoistWithBlockFrequency) + AU.addRequired<BlockFrequencyInfoWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + } + + void releaseMemory() override { Impl.releaseMemory(); } + +private: + ConstantHoistingPass Impl; +}; + +} // end anonymous namespace + +char ConstantHoistingLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(ConstantHoistingLegacyPass, "consthoist", + "Constant Hoisting", false, false) +INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(ConstantHoistingLegacyPass, "consthoist", + "Constant Hoisting", false, false) + +FunctionPass *llvm::createConstantHoistingPass() { + return new ConstantHoistingLegacyPass(); +} + +/// \brief Perform the constant hoisting optimization for the given function. 
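+///
+/// A minimal sketch of the overall effect, assuming a hypothetical target
+/// where the wide immediates below are expensive to materialize:
+///
+///   %add1 = add i64 %a, 68719476736
+///   %add2 = add i64 %b, 68719476737
+///
+/// becomes roughly
+///
+///   %const = bitcast i64 68719476736 to i64
+///   %add1  = add i64 %a, %const
+///   %mat   = add i64 %const, 1
+///   %add2  = add i64 %b, %mat
+///
+/// so the expensive constant is materialized once and the second use is
+/// rebased from it with a cheap add.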
+bool ConstantHoistingLegacyPass::runOnFunction(Function &Fn) {
+  if (skipFunction(Fn))
+    return false;
+
+  DEBUG(dbgs() << "********** Begin Constant Hoisting **********\n");
+  DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n');
+
+  bool MadeChange =
+      Impl.runImpl(Fn, getAnalysis<TargetTransformInfoWrapperPass>().getTTI(Fn),
+                   getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
+                   ConstHoistWithBlockFrequency
+                       ? &getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI()
+                       : nullptr,
+                   Fn.getEntryBlock());
+
+  if (MadeChange) {
+    DEBUG(dbgs() << "********** Function after Constant Hoisting: "
+                 << Fn.getName() << '\n');
+    DEBUG(dbgs() << Fn);
+  }
+  DEBUG(dbgs() << "********** End Constant Hoisting **********\n");
+
+  return MadeChange;
+}
+
+/// \brief Find the constant materialization insertion point.
+Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst,
+                                                   unsigned Idx) const {
+  // If the operand is a cast instruction, then we have to materialize the
+  // constant before the cast instruction.
+  if (Idx != ~0U) {
+    Value *Opnd = Inst->getOperand(Idx);
+    if (auto CastInst = dyn_cast<Instruction>(Opnd))
+      if (CastInst->isCast())
+        return CastInst;
+  }
+
+  // The simple and common case. This also includes constant expressions.
+  if (!isa<PHINode>(Inst) && !Inst->isEHPad())
+    return Inst;
+
+  // We can't insert directly before a phi node or an eh pad. Insert before
+  // the terminator of the incoming or dominating block.
+  assert(Entry != Inst->getParent() && "PHI or landing pad in entry block!");
+  if (Idx != ~0U && isa<PHINode>(Inst))
+    return cast<PHINode>(Inst)->getIncomingBlock(Idx)->getTerminator();
+
+  // This must be an EH pad. Iterate over immediate dominators until we find a
+  // non-EH pad. We need to skip over catchswitch blocks, which are both EH
+  // pads and terminators.
+  auto IDom = DT->getNode(Inst->getParent())->getIDom();
+  while (IDom->getBlock()->isEHPad()) {
+    assert(Entry != IDom->getBlock() && "eh pad in entry block");
+    IDom = IDom->getIDom();
+  }
+
+  return IDom->getBlock()->getTerminator();
+}
+
+/// \brief Given \p BBs as input, find another set of BBs which collectively
+/// dominates \p BBs and has the minimal sum of frequencies. Return the BB
+/// set found in \p BBs.
+static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI,
+                                 BasicBlock *Entry,
+                                 SmallPtrSet<BasicBlock *, 8> &BBs) {
+  assert(!BBs.count(Entry) && "Assume Entry is not in BBs");
+  // Nodes on the current path to the root.
+  SmallPtrSet<BasicBlock *, 8> Path;
+  // Candidates includes any block 'BB' in set 'BBs' that is not strictly
+  // dominated by any other blocks in set 'BBs', and all nodes on the path
+  // in the dominator tree from Entry to 'BB'.
+  SmallPtrSet<BasicBlock *, 16> Candidates;
+  for (auto BB : BBs) {
+    Path.clear();
+    // Walk up the dominator tree until Entry or another BB in BBs
+    // is reached. Insert the nodes visited on the way into Path.
+    BasicBlock *Node = BB;
+    // The "Path" is a candidate path to be added into Candidates set.
+    bool isCandidate = false;
+    do {
+      Path.insert(Node);
+      if (Node == Entry || Candidates.count(Node)) {
+        isCandidate = true;
+        break;
+      }
+      assert(DT.getNode(Node)->getIDom() &&
+             "Entry doesn't dominate current Node");
+      Node = DT.getNode(Node)->getIDom()->getBlock();
+    } while (!BBs.count(Node));
+
+    // If isCandidate is false, Node is another Block in BBs dominating
+    // current 'BB'. Drop the nodes on the Path.
+    if (!isCandidate)
+      continue;
+
+    // Add nodes on the Path into Candidates.
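+    // (Editor's illustration with made-up frequencies: if Entry has
+    // frequency 10 and BBs holds two sibling blocks of frequency 4 each,
+    // their sum 8 is cheaper than 10, so both blocks survive as insertion
+    // points; at frequency 6 each the sum 12 would exceed 10 and the
+    // bottom-up pass below would hoist the insertion point up to Entry.)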
+    Candidates.insert(Path.begin(), Path.end());
+  }
+
+  // Sort the nodes in Candidates in top-down order and save the nodes
+  // in Orders.
+  unsigned Idx = 0;
+  SmallVector<BasicBlock *, 16> Orders;
+  Orders.push_back(Entry);
+  while (Idx != Orders.size()) {
+    BasicBlock *Node = Orders[Idx++];
+    for (auto ChildDomNode : DT.getNode(Node)->getChildren()) {
+      if (Candidates.count(ChildDomNode->getBlock()))
+        Orders.push_back(ChildDomNode->getBlock());
+    }
+  }
+
+  // Visit Orders in bottom-up order.
+  using InsertPtsCostPair =
+      std::pair<SmallPtrSet<BasicBlock *, 16>, BlockFrequency>;
+
+  // InsertPtsMap is a map from a BB to the best insertion points for the
+  // subtree of BB (subtree not including the BB itself).
+  DenseMap<BasicBlock *, InsertPtsCostPair> InsertPtsMap;
+  InsertPtsMap.reserve(Orders.size() + 1);
+  for (auto RIt = Orders.rbegin(); RIt != Orders.rend(); RIt++) {
+    BasicBlock *Node = *RIt;
+    bool NodeInBBs = BBs.count(Node);
+    SmallPtrSet<BasicBlock *, 16> &InsertPts = InsertPtsMap[Node].first;
+    BlockFrequency &InsertPtsFreq = InsertPtsMap[Node].second;
+
+    // Return the optimal insert points in BBs.
+    if (Node == Entry) {
+      BBs.clear();
+      if (InsertPtsFreq > BFI.getBlockFreq(Node) ||
+          (InsertPtsFreq == BFI.getBlockFreq(Node) && InsertPts.size() > 1))
+        BBs.insert(Entry);
+      else
+        BBs.insert(InsertPts.begin(), InsertPts.end());
+      break;
+    }
+
+    BasicBlock *Parent = DT.getNode(Node)->getIDom()->getBlock();
+    // Initially, ParentInsertPts is empty and ParentPtsFreq is 0. Every child
+    // will update its parent's ParentInsertPts and ParentPtsFreq.
+    SmallPtrSet<BasicBlock *, 16> &ParentInsertPts = InsertPtsMap[Parent].first;
+    BlockFrequency &ParentPtsFreq = InsertPtsMap[Parent].second;
+    // Choose to insert in Node or in the subtree of Node.
+    // Don't hoist to an EHPad because we may not find a proper place to
+    // insert in an EHPad.
+    // If the total frequency of InsertPts is the same as the frequency of the
+    // target Node, and InsertPts contains more than one node, choose hoisting
+    // to reduce code size.
+    if (NodeInBBs ||
+        (!Node->isEHPad() &&
+         (InsertPtsFreq > BFI.getBlockFreq(Node) ||
+          (InsertPtsFreq == BFI.getBlockFreq(Node) && InsertPts.size() > 1)))) {
+      ParentInsertPts.insert(Node);
+      ParentPtsFreq += BFI.getBlockFreq(Node);
+    } else {
+      ParentInsertPts.insert(InsertPts.begin(), InsertPts.end());
+      ParentPtsFreq += InsertPtsFreq;
+    }
+  }
+}
+
+/// \brief Find an insertion point that dominates all uses.
+SmallPtrSet<Instruction *, 8> ConstantHoistingPass::findConstantInsertionPoint(
+    const ConstantInfo &ConstInfo) const {
+  assert(!ConstInfo.RebasedConstants.empty() && "Invalid constant info entry.");
+  // Collect all basic blocks.
+ SmallPtrSet<BasicBlock *, 8> BBs; + SmallPtrSet<Instruction *, 8> InsertPts; + for (auto const &RCI : ConstInfo.RebasedConstants) + for (auto const &U : RCI.Uses) + BBs.insert(findMatInsertPt(U.Inst, U.OpndIdx)->getParent()); + + if (BBs.count(Entry)) { + InsertPts.insert(&Entry->front()); + return InsertPts; + } + + if (BFI) { + findBestInsertionSet(*DT, *BFI, Entry, BBs); + for (auto BB : BBs) { + BasicBlock::iterator InsertPt = BB->begin(); + for (; isa<PHINode>(InsertPt) || InsertPt->isEHPad(); ++InsertPt) + ; + InsertPts.insert(&*InsertPt); + } + return InsertPts; + } + + while (BBs.size() >= 2) { + BasicBlock *BB, *BB1, *BB2; + BB1 = *BBs.begin(); + BB2 = *std::next(BBs.begin()); + BB = DT->findNearestCommonDominator(BB1, BB2); + if (BB == Entry) { + InsertPts.insert(&Entry->front()); + return InsertPts; + } + BBs.erase(BB1); + BBs.erase(BB2); + BBs.insert(BB); + } + assert((BBs.size() == 1) && "Expected only one element."); + Instruction &FirstInst = (*BBs.begin())->front(); + InsertPts.insert(findMatInsertPt(&FirstInst)); + return InsertPts; +} + +/// \brief Record constant integer ConstInt for instruction Inst at operand +/// index Idx. +/// +/// The operand at index Idx is not necessarily the constant integer itself. It +/// could also be a cast instruction or a constant expression that uses the +// constant integer. +void ConstantHoistingPass::collectConstantCandidates( + ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx, + ConstantInt *ConstInt) { + unsigned Cost; + // Ask the target about the cost of materializing the constant for the given + // instruction and operand index. + if (auto IntrInst = dyn_cast<IntrinsicInst>(Inst)) + Cost = TTI->getIntImmCost(IntrInst->getIntrinsicID(), Idx, + ConstInt->getValue(), ConstInt->getType()); + else + Cost = TTI->getIntImmCost(Inst->getOpcode(), Idx, ConstInt->getValue(), + ConstInt->getType()); + + // Ignore cheap integer constants. + if (Cost > TargetTransformInfo::TCC_Basic) { + ConstCandMapType::iterator Itr; + bool Inserted; + std::tie(Itr, Inserted) = ConstCandMap.insert(std::make_pair(ConstInt, 0)); + if (Inserted) { + ConstCandVec.push_back(ConstantCandidate(ConstInt)); + Itr->second = ConstCandVec.size() - 1; + } + ConstCandVec[Itr->second].addUser(Inst, Idx, Cost); + DEBUG(if (isa<ConstantInt>(Inst->getOperand(Idx))) + dbgs() << "Collect constant " << *ConstInt << " from " << *Inst + << " with cost " << Cost << '\n'; + else + dbgs() << "Collect constant " << *ConstInt << " indirectly from " + << *Inst << " via " << *Inst->getOperand(Idx) << " with cost " + << Cost << '\n'; + ); + } +} + +/// \brief Check the operand for instruction Inst at index Idx. +void ConstantHoistingPass::collectConstantCandidates( + ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx) { + Value *Opnd = Inst->getOperand(Idx); + + // Visit constant integers. + if (auto ConstInt = dyn_cast<ConstantInt>(Opnd)) { + collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt); + return; + } + + // Visit cast instructions that have constant integers. + if (auto CastInst = dyn_cast<Instruction>(Opnd)) { + // Only visit cast instructions, which have been skipped. All other + // instructions should have already been visited. + if (!CastInst->isCast()) + return; + + if (auto *ConstInt = dyn_cast<ConstantInt>(CastInst->getOperand(0))) { + // Pretend the constant is directly used by the instruction and ignore + // the cast instruction. 
+      collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt);
+      return;
+    }
+  }
+
+  // Visit constant expressions that have constant integers.
+  if (auto ConstExpr = dyn_cast<ConstantExpr>(Opnd)) {
+    // Only visit constant cast expressions.
+    if (!ConstExpr->isCast())
+      return;
+
+    if (auto ConstInt = dyn_cast<ConstantInt>(ConstExpr->getOperand(0))) {
+      // Pretend the constant is directly used by the instruction and ignore
+      // the constant expression.
+      collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt);
+      return;
+    }
+  }
+}
+
+/// \brief Scan the instruction for expensive integer constants and record them
+/// in the constant candidate vector.
+void ConstantHoistingPass::collectConstantCandidates(
+    ConstCandMapType &ConstCandMap, Instruction *Inst) {
+  // Skip all cast instructions. They are visited indirectly later on.
+  if (Inst->isCast())
+    return;
+
+  // Scan all operands.
+  for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) {
+    // The cost of materializing the constants (defined in
+    // `TargetTransformInfo::getIntImmCost`) for instructions which only take
+    // constant variables is lower than `TargetTransformInfo::TCC_Basic`. So
+    // it's safe for us to collect constant candidates from all IntrinsicInsts.
+    if (canReplaceOperandWithVariable(Inst, Idx) || isa<IntrinsicInst>(Inst)) {
+      collectConstantCandidates(ConstCandMap, Inst, Idx);
+    }
+  } // end of for all operands
+}
+
+/// \brief Collect all integer constants in the function that cannot be folded
+/// into an instruction itself.
+void ConstantHoistingPass::collectConstantCandidates(Function &Fn) {
+  ConstCandMapType ConstCandMap;
+  for (BasicBlock &BB : Fn)
+    for (Instruction &Inst : BB)
+      collectConstantCandidates(ConstCandMap, &Inst);
+}
+
+// This helper function is necessary to deal with values that have different
+// bit widths (APInt operator- does not like that). If the value cannot be
+// represented in uint64 we return an "empty" APInt. This is then interpreted
+// as meaning the value is not in range.
+static Optional<APInt> calculateOffsetDiff(const APInt &V1, const APInt &V2) {
+  Optional<APInt> Res = None;
+  unsigned BW = V1.getBitWidth() > V2.getBitWidth() ?
+                V1.getBitWidth() : V2.getBitWidth();
+  uint64_t LimVal1 = V1.getLimitedValue();
+  uint64_t LimVal2 = V2.getLimitedValue();
+
+  if (LimVal1 == ~0ULL || LimVal2 == ~0ULL)
+    return Res;
+
+  uint64_t Diff = LimVal1 - LimVal2;
+  return APInt(BW, Diff, true);
+}
+
+// From a list of constants, one needs to be picked as the base and the other
+// constants will be transformed into an offset from that base constant. The
+// question is: which one is the best pick? For example, consider these
+// constants and their number of uses:
+//
+//   Constants | 2 | 4 | 12 | 42 |
+//   NumUses   | 3 | 2 |  8 |  7 |
+//
+// Selecting constant 12 because it has the most uses will generate negative
+// offsets for constants 2 and 4 (i.e. -10 and -8 respectively). If negative
+// offsets lead to less optimal code generation, then there might be better
+// solutions. Suppose immediates in the range of 0..35 are most optimally
+// supported by the architecture, then selecting constant 2 is most optimal
+// because this will generate offsets: 0, 2, 10, 40. Offsets 0, 2 and 10 are
+// in range 0..35, and thus 3 + 2 + 8 = 13 uses are in range. Selecting 12
+// would have only 8 uses in range, so choosing 2 as a base is more optimal.
+// Thus, when selecting the base constant, the range of the offsets is an
+// important factor that we also take into account here.
+// This algorithm calculates a total cost for selecting each constant as the
+// base and subtracts a penalty when immediates are out of range. It has
+// quadratic complexity, so we only run it when optimising for size and there
+// are fewer than 100 constants; otherwise we fall back to the straightforward
+// algorithm, which does not do all the offset calculations.
+unsigned
+ConstantHoistingPass::maximizeConstantsInRange(ConstCandVecType::iterator S,
+                                               ConstCandVecType::iterator E,
+                                               ConstCandVecType::iterator &MaxCostItr) {
+  unsigned NumUses = 0;
+
+  if (!Entry->getParent()->optForSize() || std::distance(S, E) > 100) {
+    for (auto ConstCand = S; ConstCand != E; ++ConstCand) {
+      NumUses += ConstCand->Uses.size();
+      if (ConstCand->CumulativeCost > MaxCostItr->CumulativeCost)
+        MaxCostItr = ConstCand;
+    }
+    return NumUses;
+  }
+
+  DEBUG(dbgs() << "== Maximize constants in range ==\n");
+  int MaxCost = -1;
+  for (auto ConstCand = S; ConstCand != E; ++ConstCand) {
+    auto Value = ConstCand->ConstInt->getValue();
+    Type *Ty = ConstCand->ConstInt->getType();
+    int Cost = 0;
+    NumUses += ConstCand->Uses.size();
+    DEBUG(dbgs() << "= Constant: " << ConstCand->ConstInt->getValue() << "\n");
+
+    for (auto User : ConstCand->Uses) {
+      unsigned Opcode = User.Inst->getOpcode();
+      unsigned OpndIdx = User.OpndIdx;
+      Cost += TTI->getIntImmCost(Opcode, OpndIdx, Value, Ty);
+      DEBUG(dbgs() << "Cost: " << Cost << "\n");
+
+      for (auto C2 = S; C2 != E; ++C2) {
+        Optional<APInt> Diff = calculateOffsetDiff(
+            C2->ConstInt->getValue(),
+            ConstCand->ConstInt->getValue());
+        if (Diff) {
+          const int ImmCosts =
+              TTI->getIntImmCodeSizeCost(Opcode, OpndIdx, Diff.getValue(), Ty);
+          Cost -= ImmCosts;
+          DEBUG(dbgs() << "Offset " << Diff.getValue() << " "
+                       << "has penalty: " << ImmCosts << "\n"
+                       << "Adjusted cost: " << Cost << "\n");
+        }
+      }
+    }
+    DEBUG(dbgs() << "Cumulative cost: " << Cost << "\n");
+    if (Cost > MaxCost) {
+      MaxCost = Cost;
+      MaxCostItr = ConstCand;
+      DEBUG(dbgs() << "New candidate: " << MaxCostItr->ConstInt->getValue()
+                   << "\n");
+    }
+  }
+  return NumUses;
+}
+
+/// \brief Find the base constant within the given range and rebase all other
+/// constants with respect to the base constant.
+void ConstantHoistingPass::findAndMakeBaseConstant(
+    ConstCandVecType::iterator S, ConstCandVecType::iterator E) {
+  auto MaxCostItr = S;
+  unsigned NumUses = maximizeConstantsInRange(S, E, MaxCostItr);
+
+  // Don't hoist constants that have only one use.
+  if (NumUses <= 1)
+    return;
+
+  ConstantInfo ConstInfo;
+  ConstInfo.BaseConstant = MaxCostItr->ConstInt;
+  Type *Ty = ConstInfo.BaseConstant->getType();
+
+  // Rebase the constants with respect to the base constant.
+  for (auto ConstCand = S; ConstCand != E; ++ConstCand) {
+    APInt Diff = ConstCand->ConstInt->getValue() -
+                 ConstInfo.BaseConstant->getValue();
+    Constant *Offset = Diff == 0 ? nullptr : ConstantInt::get(Ty, Diff);
+    ConstInfo.RebasedConstants.push_back(
+        RebasedConstantInfo(std::move(ConstCand->Uses), Offset));
+  }
+  ConstantVec.push_back(std::move(ConstInfo));
+}
+
+/// \brief Finds and combines constant candidates that can be easily
+/// rematerialized with an add from a common base constant.
+void ConstantHoistingPass::findBaseConstants() {
+  // Sort the constants by value and type. This invalidates the mapping!
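+  // (For instance, candidates {i64 4096, i32 5, i32 4660} end up ordered as
+  // [i32 5, i32 4660, i64 4096]: narrower bit width first, then increasing
+  // unsigned value within a type.)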
+  std::sort(ConstCandVec.begin(), ConstCandVec.end(),
+            [](const ConstantCandidate &LHS, const ConstantCandidate &RHS) {
+              if (LHS.ConstInt->getType() != RHS.ConstInt->getType())
+                return LHS.ConstInt->getType()->getBitWidth() <
+                       RHS.ConstInt->getType()->getBitWidth();
+              return LHS.ConstInt->getValue().ult(RHS.ConstInt->getValue());
+            });
+
+  // Simple linear scan through the sorted constant candidate vector for viable
+  // merge candidates.
+  auto MinValItr = ConstCandVec.begin();
+  for (auto CC = std::next(ConstCandVec.begin()), E = ConstCandVec.end();
+       CC != E; ++CC) {
+    if (MinValItr->ConstInt->getType() == CC->ConstInt->getType()) {
+      // Check if the constant is in range of an add with immediate.
+      APInt Diff = CC->ConstInt->getValue() - MinValItr->ConstInt->getValue();
+      if ((Diff.getBitWidth() <= 64) &&
+          TTI->isLegalAddImmediate(Diff.getSExtValue()))
+        continue;
+    }
+    // We now either have a different constant type or the constant is no
+    // longer in range of an add with immediate.
+    findAndMakeBaseConstant(MinValItr, CC);
+    // Start a new base constant search.
+    MinValItr = CC;
+  }
+  // Finalize the last base constant search.
+  findAndMakeBaseConstant(MinValItr, ConstCandVec.end());
+}
+
+/// \brief Updates the operand at Idx in instruction Inst with the result of
+/// instruction Mat. If the instruction is a PHI node then special
+/// handling for duplicate values from the same incoming basic block is
+/// required.
+/// \return The update will always succeed, but the return value indicates
+/// whether Mat was used for the update or not.
+static bool updateOperand(Instruction *Inst, unsigned Idx, Instruction *Mat) {
+  if (auto PHI = dyn_cast<PHINode>(Inst)) {
+    // Check if any previous operand of the PHI node has the same incoming
+    // basic block. This is a very odd case that happens when the incoming
+    // basic block has a switch statement. In this case use the same value as
+    // the previous operand(s), otherwise we will fail verification due to
+    // different values. The values are actually the same, but the variable
+    // names are different and the verifier doesn't like that.
+    BasicBlock *IncomingBB = PHI->getIncomingBlock(Idx);
+    for (unsigned i = 0; i < Idx; ++i) {
+      if (PHI->getIncomingBlock(i) == IncomingBB) {
+        Value *IncomingVal = PHI->getIncomingValue(i);
+        Inst->setOperand(Idx, IncomingVal);
+        return false;
+      }
+    }
+  }
+
+  Inst->setOperand(Idx, Mat);
+  return true;
+}
+
+/// \brief Emit materialization code for all rebased constants and update their
+/// users.
+void ConstantHoistingPass::emitBaseConstants(Instruction *Base,
+                                             Constant *Offset,
+                                             const ConstantUser &ConstUser) {
+  Instruction *Mat = Base;
+  if (Offset) {
+    Instruction *InsertionPt = findMatInsertPt(ConstUser.Inst,
+                                               ConstUser.OpndIdx);
+    Mat = BinaryOperator::Create(Instruction::Add, Base, Offset,
+                                 "const_mat", InsertionPt);
+
+    DEBUG(dbgs() << "Materialize constant (" << *Base->getOperand(0)
+                 << " + " << *Offset << ") in BB "
+                 << Mat->getParent()->getName() << '\n' << *Mat << '\n');
+    Mat->setDebugLoc(ConstUser.Inst->getDebugLoc());
+  }
+  Value *Opnd = ConstUser.Inst->getOperand(ConstUser.OpndIdx);
+
+  // Visit constant integer.
+  if (isa<ConstantInt>(Opnd)) {
+    DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n');
+    if (!updateOperand(ConstUser.Inst, ConstUser.OpndIdx, Mat) && Offset)
+      Mat->eraseFromParent();
+    DEBUG(dbgs() << "To    : " << *ConstUser.Inst << '\n');
+    return;
+  }
+
+  // Visit cast instruction.
+  if (auto CastInst = dyn_cast<Instruction>(Opnd)) {
+    assert(CastInst->isCast() && "Expected a cast instruction!");
+    // Check if we have already visited this cast instruction to avoid
+    // unnecessary cloning.
+    Instruction *&ClonedCastInst = ClonedCastMap[CastInst];
+    if (!ClonedCastInst) {
+      ClonedCastInst = CastInst->clone();
+      ClonedCastInst->setOperand(0, Mat);
+      ClonedCastInst->insertAfter(CastInst);
+      // Use the same debug location as the original cast instruction.
+      ClonedCastInst->setDebugLoc(CastInst->getDebugLoc());
+      DEBUG(dbgs() << "Clone instruction: " << *CastInst << '\n'
+                   << "To               : " << *ClonedCastInst << '\n');
+    }
+
+    DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n');
+    updateOperand(ConstUser.Inst, ConstUser.OpndIdx, ClonedCastInst);
+    DEBUG(dbgs() << "To    : " << *ConstUser.Inst << '\n');
+    return;
+  }
+
+  // Visit constant expression.
+  if (auto ConstExpr = dyn_cast<ConstantExpr>(Opnd)) {
+    Instruction *ConstExprInst = ConstExpr->getAsInstruction();
+    ConstExprInst->setOperand(0, Mat);
+    ConstExprInst->insertBefore(findMatInsertPt(ConstUser.Inst,
+                                                ConstUser.OpndIdx));
+
+    // Use the same debug location as the instruction we are about to update.
+    ConstExprInst->setDebugLoc(ConstUser.Inst->getDebugLoc());
+
+    DEBUG(dbgs() << "Create instruction: " << *ConstExprInst << '\n'
+                 << "From              : " << *ConstExpr << '\n');
+    DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n');
+    if (!updateOperand(ConstUser.Inst, ConstUser.OpndIdx, ConstExprInst)) {
+      ConstExprInst->eraseFromParent();
+      if (Offset)
+        Mat->eraseFromParent();
+    }
+    DEBUG(dbgs() << "To    : " << *ConstUser.Inst << '\n');
+    return;
+  }
+}
+
+/// \brief Hoist and hide the base constant behind a bitcast and emit
+/// materialization code for derived constants.
+bool ConstantHoistingPass::emitBaseConstants() {
+  bool MadeChange = false;
+  for (auto const &ConstInfo : ConstantVec) {
+    // Hoist and hide the base constant behind a bitcast.
+    SmallPtrSet<Instruction *, 8> IPSet = findConstantInsertionPoint(ConstInfo);
+    assert(!IPSet.empty() && "IPSet is empty");
+
+    unsigned UsesNum = 0;
+    unsigned ReBasesNum = 0;
+    for (Instruction *IP : IPSet) {
+      IntegerType *Ty = ConstInfo.BaseConstant->getType();
+      Instruction *Base =
+          new BitCastInst(ConstInfo.BaseConstant, Ty, "const", IP);
+
+      Base->setDebugLoc(IP->getDebugLoc());
+
+      DEBUG(dbgs() << "Hoist constant (" << *ConstInfo.BaseConstant
+                   << ") to BB " << IP->getParent()->getName() << '\n'
+                   << *Base << '\n');
+
+      // Emit materialization code for all rebased constants.
+      unsigned Uses = 0;
+      for (auto const &RCI : ConstInfo.RebasedConstants) {
+        for (auto const &U : RCI.Uses) {
+          Uses++;
+          BasicBlock *OrigMatInsertBB =
+              findMatInsertPt(U.Inst, U.OpndIdx)->getParent();
+          // If the base constant is to be inserted in multiple places,
+          // generate the rebase for U using the Base dominating U.
+          if (IPSet.size() == 1 ||
+              DT->dominates(Base->getParent(), OrigMatInsertBB)) {
+            emitBaseConstants(Base, RCI.Offset, U);
+            ReBasesNum++;
+          }
+
+          Base->setDebugLoc(DILocation::getMergedLocation(
+              Base->getDebugLoc(), U.Inst->getDebugLoc()));
+        }
+      }
+      UsesNum = Uses;
+
+      // Use the same debug location as the last user of the constant.
+      assert(!Base->use_empty() && "The use list is empty!?");
+      assert(isa<Instruction>(Base->user_back()) &&
+             "All uses should be instructions.");
+    }
+    (void)UsesNum;
+    (void)ReBasesNum;
+    // Expect all uses are rebased after rebase is done.
+ assert(UsesNum == ReBasesNum && "Not all uses are rebased"); + + NumConstantsHoisted++; + + // Base constant is also included in ConstInfo.RebasedConstants, so + // deduct 1 from ConstInfo.RebasedConstants.size(). + NumConstantsRebased = ConstInfo.RebasedConstants.size() - 1; + + MadeChange = true; + } + return MadeChange; +} + +/// \brief Check all cast instructions we made a copy of and remove them if they +/// have no more users. +void ConstantHoistingPass::deleteDeadCastInst() const { + for (auto const &I : ClonedCastMap) + if (I.first->use_empty()) + I.first->eraseFromParent(); +} + +/// \brief Optimize expensive integer constants in the given function. +bool ConstantHoistingPass::runImpl(Function &Fn, TargetTransformInfo &TTI, + DominatorTree &DT, BlockFrequencyInfo *BFI, + BasicBlock &Entry) { + this->TTI = &TTI; + this->DT = &DT; + this->BFI = BFI; + this->Entry = &Entry; + // Collect all constant candidates. + collectConstantCandidates(Fn); + + // There are no constant candidates to worry about. + if (ConstCandVec.empty()) + return false; + + // Combine constants that can be easily materialized with an add from a common + // base constant. + findBaseConstants(); + + // There are no constants to emit. + if (ConstantVec.empty()) + return false; + + // Finally hoist the base constant and emit materialization code for dependent + // constants. + bool MadeChange = emitBaseConstants(); + + // Cleanup dead instructions. + deleteDeadCastInst(); + + return MadeChange; +} + +PreservedAnalyses ConstantHoistingPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + auto BFI = ConstHoistWithBlockFrequency + ? &AM.getResult<BlockFrequencyAnalysis>(F) + : nullptr; + if (!runImpl(F, TTI, DT, BFI, F.getEntryBlock())) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/ConstantProp.cpp b/contrib/llvm/lib/Transforms/Scalar/ConstantProp.cpp new file mode 100644 index 000000000000..4fa27891a974 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/ConstantProp.cpp @@ -0,0 +1,104 @@ +//===- ConstantProp.cpp - Code to perform Simple Constant Propagation -----===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements constant propagation and merging: +// +// Specifically, this: +// * Converts instructions like "add int 1, 2" into 3 +// +// Notice that: +// * This pass has a habit of making definitions be dead. It is a good idea +// to run a DIE pass sometime after running this pass. 
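+//
+// For example (a small sketch), given
+//
+//    %a = add i32 1, 2
+//    %b = mul i32 %a, 4
+//
+// the pass folds %a to 3, re-queues %a's user, folds %b to 12 on the next
+// worklist visit, and erases each instruction once it is trivially dead.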
+// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include <set> +using namespace llvm; + +#define DEBUG_TYPE "constprop" + +STATISTIC(NumInstKilled, "Number of instructions killed"); + +namespace { + struct ConstantPropagation : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + ConstantPropagation() : FunctionPass(ID) { + initializeConstantPropagationPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + } + }; +} + +char ConstantPropagation::ID = 0; +INITIALIZE_PASS_BEGIN(ConstantPropagation, "constprop", + "Simple constant propagation", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(ConstantPropagation, "constprop", + "Simple constant propagation", false, false) + +FunctionPass *llvm::createConstantPropagationPass() { + return new ConstantPropagation(); +} + +bool ConstantPropagation::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + + // Initialize the worklist to all of the instructions ready to process... + std::set<Instruction*> WorkList; + for (Instruction &I: instructions(&F)) + WorkList.insert(&I); + + bool Changed = false; + const DataLayout &DL = F.getParent()->getDataLayout(); + TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + + while (!WorkList.empty()) { + Instruction *I = *WorkList.begin(); + WorkList.erase(WorkList.begin()); // Get an element from the worklist... + + if (!I->use_empty()) // Don't muck with dead instructions... + if (Constant *C = ConstantFoldInstruction(I, DL, TLI)) { + // Add all of the users of this instruction to the worklist, they might + // be constant propagatable now... + for (User *U : I->users()) + WorkList.insert(cast<Instruction>(U)); + + // Replace all of the uses of a variable with uses of the constant. + I->replaceAllUsesWith(C); + + // Remove the dead instruction. + WorkList.erase(I); + if (isInstructionTriviallyDead(I, TLI)) { + I->eraseFromParent(); + ++NumInstKilled; + } + + // We made a change to the function... + Changed = true; + } + } + return Changed; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp b/contrib/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp new file mode 100644 index 000000000000..8f468ebf8949 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/CorrelatedValuePropagation.cpp @@ -0,0 +1,653 @@ +//===- CorrelatedValuePropagation.cpp - Propagate CFG-derived info --------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Correlated Value Propagation pass. 
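+//
+// A small illustration: in
+//
+//   %c = icmp sgt i32 %x, 0
+//   br i1 %c, label %pos, label %rest
+// pos:
+//   %d = icmp eq i32 %x, 0
+//
+// LazyValueInfo knows %x > 0 on the edge into %pos, so this pass can fold
+// %d to false even though no single block contains enough information.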
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/CorrelatedValuePropagation.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LazyValueInfo.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include <cassert> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "correlated-value-propagation" + +STATISTIC(NumPhis, "Number of phis propagated"); +STATISTIC(NumSelects, "Number of selects propagated"); +STATISTIC(NumMemAccess, "Number of memory access targets propagated"); +STATISTIC(NumCmps, "Number of comparisons propagated"); +STATISTIC(NumReturns, "Number of return values propagated"); +STATISTIC(NumDeadCases, "Number of switch cases removed"); +STATISTIC(NumSDivs, "Number of sdiv converted to udiv"); +STATISTIC(NumAShrs, "Number of ashr converted to lshr"); +STATISTIC(NumSRems, "Number of srem converted to urem"); +STATISTIC(NumOverflows, "Number of overflow checks removed"); + +static cl::opt<bool> DontProcessAdds("cvp-dont-process-adds", cl::init(true)); + +namespace { + + class CorrelatedValuePropagation : public FunctionPass { + public: + static char ID; + + CorrelatedValuePropagation(): FunctionPass(ID) { + initializeCorrelatedValuePropagationPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<LazyValueInfoWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } + }; + +} // end anonymous namespace + +char CorrelatedValuePropagation::ID = 0; + +INITIALIZE_PASS_BEGIN(CorrelatedValuePropagation, "correlated-propagation", + "Value Propagation", false, false) +INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass) +INITIALIZE_PASS_END(CorrelatedValuePropagation, "correlated-propagation", + "Value Propagation", false, false) + +// Public interface to the Value Propagation pass +Pass *llvm::createCorrelatedValuePropagationPass() { + return new CorrelatedValuePropagation(); +} + +static bool processSelect(SelectInst *S, LazyValueInfo *LVI) { + if (S->getType()->isVectorTy()) return false; + if (isa<Constant>(S->getOperand(0))) return false; + + Constant *C = LVI->getConstant(S->getOperand(0), S->getParent(), S); + if (!C) return false; + + ConstantInt *CI = dyn_cast<ConstantInt>(C); + if (!CI) return false; + + Value *ReplaceWith = S->getOperand(1); + Value *Other = S->getOperand(2); + if (!CI->isOne()) std::swap(ReplaceWith, Other); + if (ReplaceWith == S) ReplaceWith = UndefValue::get(S->getType()); + + 
S->replaceAllUsesWith(ReplaceWith);
+  S->eraseFromParent();
+
+  ++NumSelects;
+
+  return true;
+}
+
+static bool processPHI(PHINode *P, LazyValueInfo *LVI,
+                       const SimplifyQuery &SQ) {
+  bool Changed = false;
+
+  BasicBlock *BB = P->getParent();
+  for (unsigned i = 0, e = P->getNumIncomingValues(); i < e; ++i) {
+    Value *Incoming = P->getIncomingValue(i);
+    if (isa<Constant>(Incoming)) continue;
+
+    Value *V = LVI->getConstantOnEdge(Incoming, P->getIncomingBlock(i), BB, P);
+
+    // Check whether the incoming value is a select with a scalar condition
+    // for which LVI can tell us the value. In that case replace the incoming
+    // value with the appropriate value of the select. This often allows us to
+    // remove the select later.
+    if (!V) {
+      SelectInst *SI = dyn_cast<SelectInst>(Incoming);
+      if (!SI) continue;
+
+      Value *Condition = SI->getCondition();
+      if (!Condition->getType()->isVectorTy()) {
+        if (Constant *C = LVI->getConstantOnEdge(
+                Condition, P->getIncomingBlock(i), BB, P)) {
+          if (C->isOneValue()) {
+            V = SI->getTrueValue();
+          } else if (C->isZeroValue()) {
+            V = SI->getFalseValue();
+          }
+          // Once LVI learns to handle vector types, we could also add support
+          // for vector type constants that are not all zeroes or all ones.
+        }
+      }
+
+      // Check whether the select's false value is a constant that LVI says
+      // the incoming value can never be. In that case replace the incoming
+      // value with the other value of the select. This often allows us to
+      // remove the select later.
+      if (!V) {
+        Constant *C = dyn_cast<Constant>(SI->getFalseValue());
+        if (!C) continue;
+
+        if (LVI->getPredicateOnEdge(ICmpInst::ICMP_EQ, SI, C,
+                                    P->getIncomingBlock(i), BB, P) !=
+            LazyValueInfo::False)
+          continue;
+        V = SI->getTrueValue();
+      }
+
+      DEBUG(dbgs() << "CVP: Threading PHI over " << *SI << '\n');
+    }
+
+    P->setIncomingValue(i, V);
+    Changed = true;
+  }
+
+  if (Value *V = SimplifyInstruction(P, SQ)) {
+    P->replaceAllUsesWith(V);
+    P->eraseFromParent();
+    Changed = true;
+  }
+
+  if (Changed)
+    ++NumPhis;
+
+  return Changed;
+}
+
+static bool processMemAccess(Instruction *I, LazyValueInfo *LVI) {
+  Value *Pointer = nullptr;
+  if (LoadInst *L = dyn_cast<LoadInst>(I))
+    Pointer = L->getPointerOperand();
+  else
+    Pointer = cast<StoreInst>(I)->getPointerOperand();
+
+  if (isa<Constant>(Pointer)) return false;
+
+  Constant *C = LVI->getConstant(Pointer, I->getParent(), I);
+  if (!C) return false;
+
+  ++NumMemAccess;
+  I->replaceUsesOfWith(Pointer, C);
+  return true;
+}
+
+/// See if LazyValueInfo's ability to exploit edge conditions or range
+/// information is sufficient to prove this comparison. Even for local
+/// conditions, this can sometimes prove conditions instcombine can't by
+/// exploiting range information.
+static bool processCmp(CmpInst *C, LazyValueInfo *LVI) {
+  Value *Op0 = C->getOperand(0);
+  Constant *Op1 = dyn_cast<Constant>(C->getOperand(1));
+  if (!Op1) return false;
+
+  // As a policy choice, we choose not to waste compile time on anything where
+  // the comparison is testing local values. While LVI can sometimes reason
+  // about such cases, it's not its primary purpose. We do make sure to do
+  // the block local query for uses from terminator instructions, but that's
+  // handled in the code for each terminator.
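+  // For instance (hypothetical IR): after "br i1 (icmp ult i32 %x, 10), ...",
+  // a comparison "icmp ult i32 %x, 20" in the taken successor, with %x
+  // defined in some earlier block, is known true from the edge alone.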
+ auto *I = dyn_cast<Instruction>(Op0); + if (I && I->getParent() == C->getParent()) + return false; + + LazyValueInfo::Tristate Result = + LVI->getPredicateAt(C->getPredicate(), Op0, Op1, C); + if (Result == LazyValueInfo::Unknown) return false; + + ++NumCmps; + if (Result == LazyValueInfo::True) + C->replaceAllUsesWith(ConstantInt::getTrue(C->getContext())); + else + C->replaceAllUsesWith(ConstantInt::getFalse(C->getContext())); + C->eraseFromParent(); + + return true; +} + +/// Simplify a switch instruction by removing cases which can never fire. If the +/// uselessness of a case could be determined locally then constant propagation +/// would already have figured it out. Instead, walk the predecessors and +/// statically evaluate cases based on information available on that edge. Cases +/// that cannot fire no matter what the incoming edge can safely be removed. If +/// a case fires on every incoming edge then the entire switch can be removed +/// and replaced with a branch to the case destination. +static bool processSwitch(SwitchInst *SI, LazyValueInfo *LVI) { + Value *Cond = SI->getCondition(); + BasicBlock *BB = SI->getParent(); + + // If the condition was defined in same block as the switch then LazyValueInfo + // currently won't say anything useful about it, though in theory it could. + if (isa<Instruction>(Cond) && cast<Instruction>(Cond)->getParent() == BB) + return false; + + // If the switch is unreachable then trying to improve it is a waste of time. + pred_iterator PB = pred_begin(BB), PE = pred_end(BB); + if (PB == PE) return false; + + // Analyse each switch case in turn. + bool Changed = false; + for (auto CI = SI->case_begin(), CE = SI->case_end(); CI != CE;) { + ConstantInt *Case = CI->getCaseValue(); + + // Check to see if the switch condition is equal to/not equal to the case + // value on every incoming edge, equal/not equal being the same each time. + LazyValueInfo::Tristate State = LazyValueInfo::Unknown; + for (pred_iterator PI = PB; PI != PE; ++PI) { + // Is the switch condition equal to the case value? + LazyValueInfo::Tristate Value = LVI->getPredicateOnEdge(CmpInst::ICMP_EQ, + Cond, Case, *PI, + BB, SI); + // Give up on this case if nothing is known. + if (Value == LazyValueInfo::Unknown) { + State = LazyValueInfo::Unknown; + break; + } + + // If this was the first edge to be visited, record that all other edges + // need to give the same result. + if (PI == PB) { + State = Value; + continue; + } + + // If this case is known to fire for some edges and known not to fire for + // others then there is nothing we can do - give up. + if (Value != State) { + State = LazyValueInfo::Unknown; + break; + } + } + + if (State == LazyValueInfo::False) { + // This case never fires - remove it. + CI->getCaseSuccessor()->removePredecessor(BB); + CI = SI->removeCase(CI); + CE = SI->case_end(); + + // The condition can be modified by removePredecessor's PHI simplification + // logic. + Cond = SI->getCondition(); + + ++NumDeadCases; + Changed = true; + continue; + } + if (State == LazyValueInfo::True) { + // This case always fires. Arrange for the switch to be turned into an + // unconditional branch by replacing the switch condition with the case + // value. + SI->setCondition(Case); + NumDeadCases += SI->getNumCases(); + Changed = true; + break; + } + + // Increment the case iterator since we didn't delete it. + ++CI; + } + + if (Changed) + // If the switch has been simplified to the point where it can be replaced + // by a branch then do so now. 
+ ConstantFoldTerminator(BB); + + return Changed; +} + +// See if we can prove that the given overflow intrinsic will not overflow. +static bool willNotOverflow(IntrinsicInst *II, LazyValueInfo *LVI) { + using OBO = OverflowingBinaryOperator; + auto NoWrap = [&] (Instruction::BinaryOps BinOp, unsigned NoWrapKind) { + Value *RHS = II->getOperand(1); + ConstantRange RRange = LVI->getConstantRange(RHS, II->getParent(), II); + ConstantRange NWRegion = ConstantRange::makeGuaranteedNoWrapRegion( + BinOp, RRange, NoWrapKind); + // As an optimization, do not compute LRange if we do not need it. + if (NWRegion.isEmptySet()) + return false; + Value *LHS = II->getOperand(0); + ConstantRange LRange = LVI->getConstantRange(LHS, II->getParent(), II); + return NWRegion.contains(LRange); + }; + switch (II->getIntrinsicID()) { + default: + break; + case Intrinsic::uadd_with_overflow: + return NoWrap(Instruction::Add, OBO::NoUnsignedWrap); + case Intrinsic::sadd_with_overflow: + return NoWrap(Instruction::Add, OBO::NoSignedWrap); + case Intrinsic::usub_with_overflow: + return NoWrap(Instruction::Sub, OBO::NoUnsignedWrap); + case Intrinsic::ssub_with_overflow: + return NoWrap(Instruction::Sub, OBO::NoSignedWrap); + } + return false; +} + +static void processOverflowIntrinsic(IntrinsicInst *II) { + Value *NewOp = nullptr; + switch (II->getIntrinsicID()) { + default: + llvm_unreachable("Unexpected instruction."); + case Intrinsic::uadd_with_overflow: + case Intrinsic::sadd_with_overflow: + NewOp = BinaryOperator::CreateAdd(II->getOperand(0), II->getOperand(1), + II->getName(), II); + break; + case Intrinsic::usub_with_overflow: + case Intrinsic::ssub_with_overflow: + NewOp = BinaryOperator::CreateSub(II->getOperand(0), II->getOperand(1), + II->getName(), II); + break; + } + ++NumOverflows; + IRBuilder<> B(II); + Value *NewI = B.CreateInsertValue(UndefValue::get(II->getType()), NewOp, 0); + NewI = B.CreateInsertValue(NewI, ConstantInt::getFalse(II->getContext()), 1); + II->replaceAllUsesWith(NewI); + II->eraseFromParent(); +} + +/// Infer nonnull attributes for the arguments at the specified callsite. +static bool processCallSite(CallSite CS, LazyValueInfo *LVI) { + SmallVector<unsigned, 4> ArgNos; + unsigned ArgNo = 0; + + if (auto *II = dyn_cast<IntrinsicInst>(CS.getInstruction())) { + if (willNotOverflow(II, LVI)) { + processOverflowIntrinsic(II); + return true; + } + } + + for (Value *V : CS.args()) { + PointerType *Type = dyn_cast<PointerType>(V->getType()); + // Try to mark pointer typed parameters as non-null. We skip the + // relatively expensive analysis for constants which are obviously either + // null or non-null to start with. 
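+    // E.g. (sketch): for a call "f(%p)" guarded by "if (%p != null)", the
+    // ICMP_EQ query below returns False on the guarded path, so %p's
+    // argument position is recorded and later gets the nonnull attribute.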
+    if (Type && !CS.paramHasAttr(ArgNo, Attribute::NonNull) &&
+        !isa<Constant>(V) &&
+        LVI->getPredicateAt(ICmpInst::ICMP_EQ, V,
+                            ConstantPointerNull::get(Type),
+                            CS.getInstruction()) == LazyValueInfo::False)
+      ArgNos.push_back(ArgNo);
+    ArgNo++;
+  }
+
+  assert(ArgNo == CS.arg_size() && "sanity check");
+
+  if (ArgNos.empty())
+    return false;
+
+  AttributeList AS = CS.getAttributes();
+  LLVMContext &Ctx = CS.getInstruction()->getContext();
+  AS = AS.addParamAttribute(Ctx, ArgNos,
+                            Attribute::get(Ctx, Attribute::NonNull));
+  CS.setAttributes(AS);
+
+  return true;
+}
+
+static bool hasPositiveOperands(BinaryOperator *SDI, LazyValueInfo *LVI) {
+  Constant *Zero = ConstantInt::get(SDI->getType(), 0);
+  for (Value *O : SDI->operands()) {
+    auto Result = LVI->getPredicateAt(ICmpInst::ICMP_SGE, O, Zero, SDI);
+    if (Result != LazyValueInfo::True)
+      return false;
+  }
+  return true;
+}
+
+static bool processSRem(BinaryOperator *SDI, LazyValueInfo *LVI) {
+  if (SDI->getType()->isVectorTy() ||
+      !hasPositiveOperands(SDI, LVI))
+    return false;
+
+  ++NumSRems;
+  auto *BO = BinaryOperator::CreateURem(SDI->getOperand(0), SDI->getOperand(1),
+                                        SDI->getName(), SDI);
+  SDI->replaceAllUsesWith(BO);
+  SDI->eraseFromParent();
+  return true;
+}
+
+/// See if LazyValueInfo's ability to exploit edge conditions or range
+/// information is sufficient to prove that both operands of this SDiv are
+/// positive. If this is the case, replace the SDiv with a UDiv. Even for
+/// local conditions, this can sometimes prove conditions instcombine can't by
+/// exploiting range information.
+static bool processSDiv(BinaryOperator *SDI, LazyValueInfo *LVI) {
+  if (SDI->getType()->isVectorTy() ||
+      !hasPositiveOperands(SDI, LVI))
+    return false;
+
+  ++NumSDivs;
+  auto *BO = BinaryOperator::CreateUDiv(SDI->getOperand(0), SDI->getOperand(1),
+                                        SDI->getName(), SDI);
+  BO->setIsExact(SDI->isExact());
+  SDI->replaceAllUsesWith(BO);
+  SDI->eraseFromParent();
+
+  return true;
+}
+
+static bool processAShr(BinaryOperator *SDI, LazyValueInfo *LVI) {
+  if (SDI->getType()->isVectorTy())
+    return false;
+
+  Constant *Zero = ConstantInt::get(SDI->getType(), 0);
+  if (LVI->getPredicateAt(ICmpInst::ICMP_SGE, SDI->getOperand(0), Zero, SDI) !=
+      LazyValueInfo::True)
+    return false;
+
+  ++NumAShrs;
+  auto *BO = BinaryOperator::CreateLShr(SDI->getOperand(0), SDI->getOperand(1),
+                                        SDI->getName(), SDI);
+  BO->setIsExact(SDI->isExact());
+  SDI->replaceAllUsesWith(BO);
+  SDI->eraseFromParent();
+
+  return true;
+}
+
+static bool processAdd(BinaryOperator *AddOp, LazyValueInfo *LVI) {
+  using OBO = OverflowingBinaryOperator;
+
+  if (DontProcessAdds)
+    return false;
+
+  if (AddOp->getType()->isVectorTy())
+    return false;
+
+  bool NSW = AddOp->hasNoSignedWrap();
+  bool NUW = AddOp->hasNoUnsignedWrap();
+  if (NSW && NUW)
+    return false;
+
+  BasicBlock *BB = AddOp->getParent();
+
+  Value *LHS = AddOp->getOperand(0);
+  Value *RHS = AddOp->getOperand(1);
+
+  ConstantRange LRange = LVI->getConstantRange(LHS, BB, AddOp);
+
+  // Initialize RRange only if we need it. If we know that the guaranteed
+  // no-wrap region for the given LHS range is empty, don't spend time
+  // calculating the range of the RHS.
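+  // (Illustration: for an i8 add with LHS known to lie in [0, 100), the
+  // guaranteed no-signed-wrap region for the other operand is [-128, 29);
+  // if LVI shows the RHS inside that region, e.g. the constant 1, the nsw
+  // flag can be set without further work.)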
+ Optional<ConstantRange> RRange; + auto LazyRRange = [&] () { + if (!RRange) + RRange = LVI->getConstantRange(RHS, BB, AddOp); + return RRange.getValue(); + }; + + bool Changed = false; + if (!NUW) { + ConstantRange NUWRange = ConstantRange::makeGuaranteedNoWrapRegion( + BinaryOperator::Add, LRange, OBO::NoUnsignedWrap); + if (!NUWRange.isEmptySet()) { + bool NewNUW = NUWRange.contains(LazyRRange()); + AddOp->setHasNoUnsignedWrap(NewNUW); + Changed |= NewNUW; + } + } + if (!NSW) { + ConstantRange NSWRange = ConstantRange::makeGuaranteedNoWrapRegion( + BinaryOperator::Add, LRange, OBO::NoSignedWrap); + if (!NSWRange.isEmptySet()) { + bool NewNSW = NSWRange.contains(LazyRRange()); + AddOp->setHasNoSignedWrap(NewNSW); + Changed |= NewNSW; + } + } + + return Changed; +} + +static Constant *getConstantAt(Value *V, Instruction *At, LazyValueInfo *LVI) { + if (Constant *C = LVI->getConstant(V, At->getParent(), At)) + return C; + + // TODO: The following really should be sunk inside LVI's core algorithm, or + // at least the outer shims around such. + auto *C = dyn_cast<CmpInst>(V); + if (!C) return nullptr; + + Value *Op0 = C->getOperand(0); + Constant *Op1 = dyn_cast<Constant>(C->getOperand(1)); + if (!Op1) return nullptr; + + LazyValueInfo::Tristate Result = + LVI->getPredicateAt(C->getPredicate(), Op0, Op1, At); + if (Result == LazyValueInfo::Unknown) + return nullptr; + + return (Result == LazyValueInfo::True) ? + ConstantInt::getTrue(C->getContext()) : + ConstantInt::getFalse(C->getContext()); +} + +static bool runImpl(Function &F, LazyValueInfo *LVI, const SimplifyQuery &SQ) { + bool FnChanged = false; + // Visiting in a pre-order depth-first traversal causes us to simplify early + // blocks before querying later blocks (which require us to analyze early + // blocks). Eagerly simplifying shallow blocks means there is strictly less + // work to do for deep blocks. This also means we don't visit unreachable + // blocks. + for (BasicBlock *BB : depth_first(&F.getEntryBlock())) { + bool BBChanged = false; + for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) { + Instruction *II = &*BI++; + switch (II->getOpcode()) { + case Instruction::Select: + BBChanged |= processSelect(cast<SelectInst>(II), LVI); + break; + case Instruction::PHI: + BBChanged |= processPHI(cast<PHINode>(II), LVI, SQ); + break; + case Instruction::ICmp: + case Instruction::FCmp: + BBChanged |= processCmp(cast<CmpInst>(II), LVI); + break; + case Instruction::Load: + case Instruction::Store: + BBChanged |= processMemAccess(II, LVI); + break; + case Instruction::Call: + case Instruction::Invoke: + BBChanged |= processCallSite(CallSite(II), LVI); + break; + case Instruction::SRem: + BBChanged |= processSRem(cast<BinaryOperator>(II), LVI); + break; + case Instruction::SDiv: + BBChanged |= processSDiv(cast<BinaryOperator>(II), LVI); + break; + case Instruction::AShr: + BBChanged |= processAShr(cast<BinaryOperator>(II), LVI); + break; + case Instruction::Add: + BBChanged |= processAdd(cast<BinaryOperator>(II), LVI); + break; + } + } + + Instruction *Term = BB->getTerminator(); + switch (Term->getOpcode()) { + case Instruction::Switch: + BBChanged |= processSwitch(cast<SwitchInst>(Term), LVI); + break; + case Instruction::Ret: { + auto *RI = cast<ReturnInst>(Term); + // Try to determine the return value if we can. This is mainly here to + // simplify the writing of unit tests, but also helps to enable IPO by + // constant folding the return values of callees. 
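+      // E.g. (sketch): for "ret i32 %x" where LVI can pin %x to the constant
+      // 7 at this return, the instruction is rewritten to "ret i32 7".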
+ auto *RetVal = RI->getReturnValue(); + if (!RetVal) break; // handle "ret void" + if (isa<Constant>(RetVal)) break; // nothing to do + if (auto *C = getConstantAt(RetVal, RI, LVI)) { + ++NumReturns; + RI->replaceUsesOfWith(RetVal, C); + BBChanged = true; + } + } + } + + FnChanged |= BBChanged; + } + + return FnChanged; +} + +bool CorrelatedValuePropagation::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + + LazyValueInfo *LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI(); + return runImpl(F, LVI, getBestSimplifyQuery(*this, F)); +} + +PreservedAnalyses +CorrelatedValuePropagationPass::run(Function &F, FunctionAnalysisManager &AM) { + + LazyValueInfo *LVI = &AM.getResult<LazyValueAnalysis>(F); + bool Changed = runImpl(F, LVI, getBestSimplifyQuery(AM, F)); + + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + return PA; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/DCE.cpp b/contrib/llvm/lib/Transforms/Scalar/DCE.cpp new file mode 100644 index 000000000000..fa4806e884c3 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/DCE.cpp @@ -0,0 +1,163 @@ +//===- DCE.cpp - Code to perform dead code elimination --------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements dead inst elimination and dead code elimination. +// +// Dead Inst Elimination performs a single pass over the function removing +// instructions that are obviously dead. Dead Code Elimination is similar, but +// it rechecks instructions that were used by removed instructions to see if +// they are newly dead. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/DCE.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instruction.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +using namespace llvm; + +#define DEBUG_TYPE "dce" + +STATISTIC(DIEEliminated, "Number of insts removed by DIE pass"); +STATISTIC(DCEEliminated, "Number of insts removed"); + +namespace { + //===--------------------------------------------------------------------===// + // DeadInstElimination pass implementation + // + struct DeadInstElimination : public BasicBlockPass { + static char ID; // Pass identification, replacement for typeid + DeadInstElimination() : BasicBlockPass(ID) { + initializeDeadInstEliminationPass(*PassRegistry::getPassRegistry()); + } + bool runOnBasicBlock(BasicBlock &BB) override { + if (skipBasicBlock(BB)) + return false; + auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); + TargetLibraryInfo *TLI = TLIP ? 
&TLIP->getTLI() : nullptr; + bool Changed = false; + for (BasicBlock::iterator DI = BB.begin(); DI != BB.end(); ) { + Instruction *Inst = &*DI++; + if (isInstructionTriviallyDead(Inst, TLI)) { + Inst->eraseFromParent(); + Changed = true; + ++DIEEliminated; + } + } + return Changed; + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + } + }; +} + +char DeadInstElimination::ID = 0; +INITIALIZE_PASS(DeadInstElimination, "die", + "Dead Instruction Elimination", false, false) + +Pass *llvm::createDeadInstEliminationPass() { + return new DeadInstElimination(); +} + +static bool DCEInstruction(Instruction *I, + SmallSetVector<Instruction *, 16> &WorkList, + const TargetLibraryInfo *TLI) { + if (isInstructionTriviallyDead(I, TLI)) { + // Null out all of the instruction's operands to see if any operand becomes + // dead as we go. + for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i) { + Value *OpV = I->getOperand(i); + I->setOperand(i, nullptr); + + if (!OpV->use_empty() || I == OpV) + continue; + + // If the operand is an instruction that became dead as we nulled out the + // operand, and if it is 'trivially' dead, delete it in a future loop + // iteration. + if (Instruction *OpI = dyn_cast<Instruction>(OpV)) + if (isInstructionTriviallyDead(OpI, TLI)) + WorkList.insert(OpI); + } + + I->eraseFromParent(); + ++DCEEliminated; + return true; + } + return false; +} + +static bool eliminateDeadCode(Function &F, TargetLibraryInfo *TLI) { + bool MadeChange = false; + SmallSetVector<Instruction *, 16> WorkList; + // Iterate over the original function, only adding insts to the worklist + // if they actually need to be revisited. This avoids having to pre-init + // the worklist with the entire function's worth of instructions. + for (inst_iterator FI = inst_begin(F), FE = inst_end(F); FI != FE;) { + Instruction *I = &*FI; + ++FI; + + // We're visiting this instruction now, so make sure it's not in the + // worklist from an earlier visit. + if (!WorkList.count(I)) + MadeChange |= DCEInstruction(I, WorkList, TLI); + } + + while (!WorkList.empty()) { + Instruction *I = WorkList.pop_back_val(); + MadeChange |= DCEInstruction(I, WorkList, TLI); + } + return MadeChange; +} + +PreservedAnalyses DCEPass::run(Function &F, FunctionAnalysisManager &AM) { + if (!eliminateDeadCode(F, AM.getCachedResult<TargetLibraryAnalysis>(F))) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; +} + +namespace { +struct DCELegacyPass : public FunctionPass { + static char ID; // Pass identification, replacement for typeid + DCELegacyPass() : FunctionPass(ID) { + initializeDCELegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); + TargetLibraryInfo *TLI = TLIP ? 
&TLIP->getTLI() : nullptr; + + return eliminateDeadCode(F, TLI); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + } +}; +} + +char DCELegacyPass::ID = 0; +INITIALIZE_PASS(DCELegacyPass, "dce", "Dead Code Elimination", false, false) + +FunctionPass *llvm::createDeadCodeEliminationPass() { + return new DCELegacyPass(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp b/contrib/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp new file mode 100644 index 000000000000..b665d94a70aa --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/DeadStoreElimination.cpp @@ -0,0 +1,1360 @@ +//===- DeadStoreElimination.cpp - Fast Dead Store Elimination -------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements a trivial dead store elimination that only considers +// basic-block local redundant stores. +// +// FIXME: This should eventually be extended to be a post-dominator tree +// traversal. Doing so would be pretty trivial. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/DeadStoreElimination.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <cstddef> +#include <iterator> +#include <map> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "dse" + +STATISTIC(NumRedundantStores, "Number of redundant stores deleted"); +STATISTIC(NumFastStores, "Number of stores deleted"); +STATISTIC(NumFastOther , "Number of other instrs removed"); +STATISTIC(NumCompletePartials, "Number of stores dead by later partials"); +STATISTIC(NumModifiedStores, "Number of stores modified"); + +static cl::opt<bool> +EnablePartialOverwriteTracking("enable-dse-partial-overwrite-tracking", + cl::init(true), cl::Hidden, + cl::desc("Enable partial-overwrite tracking in DSE")); + +static 
cl::opt<bool> +EnablePartialStoreMerging("enable-dse-partial-store-merging", + cl::init(true), cl::Hidden, + cl::desc("Enable partial store merging in DSE")); + +//===----------------------------------------------------------------------===// +// Helper functions +//===----------------------------------------------------------------------===// +using OverlapIntervalsTy = std::map<int64_t, int64_t>; +using InstOverlapIntervalsTy = DenseMap<Instruction *, OverlapIntervalsTy>; + +/// Delete this instruction. Before we do, go through and zero out all the +/// operands of this instruction. If any of them become dead, delete them and +/// the computation tree that feeds them. +/// If ValueSet is non-null, remove any deleted instructions from it as well. +static void +deleteDeadInstruction(Instruction *I, BasicBlock::iterator *BBI, + MemoryDependenceResults &MD, const TargetLibraryInfo &TLI, + InstOverlapIntervalsTy &IOL, + DenseMap<Instruction*, size_t> *InstrOrdering, + SmallSetVector<Value *, 16> *ValueSet = nullptr) { + SmallVector<Instruction*, 32> NowDeadInsts; + + NowDeadInsts.push_back(I); + --NumFastOther; + + // Keeping the iterator straight is a pain, so we let this routine tell the + // caller what the next instruction is after we're done mucking about. + BasicBlock::iterator NewIter = *BBI; + + // Before we touch this instruction, remove it from memdep! + do { + Instruction *DeadInst = NowDeadInsts.pop_back_val(); + ++NumFastOther; + + // This instruction is dead, zap it, in stages. Start by removing it from + // MemDep, which needs to know the operands and needs it to be in the + // function. + MD.removeInstruction(DeadInst); + + for (unsigned op = 0, e = DeadInst->getNumOperands(); op != e; ++op) { + Value *Op = DeadInst->getOperand(op); + DeadInst->setOperand(op, nullptr); + + // If this operand just became dead, add it to the NowDeadInsts list. + if (!Op->use_empty()) continue; + + if (Instruction *OpI = dyn_cast<Instruction>(Op)) + if (isInstructionTriviallyDead(OpI, &TLI)) + NowDeadInsts.push_back(OpI); + } + + if (ValueSet) ValueSet->remove(DeadInst); + InstrOrdering->erase(DeadInst); + IOL.erase(DeadInst); + + if (NewIter == DeadInst->getIterator()) + NewIter = DeadInst->eraseFromParent(); + else + DeadInst->eraseFromParent(); + } while (!NowDeadInsts.empty()); + *BBI = NewIter; +} + +/// Does this instruction write some memory? This only returns true for things +/// that we can analyze with other helpers below. +static bool hasMemoryWrite(Instruction *I, const TargetLibraryInfo &TLI) { + if (isa<StoreInst>(I)) + return true; + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + switch (II->getIntrinsicID()) { + default: + return false; + case Intrinsic::memset: + case Intrinsic::memmove: + case Intrinsic::memcpy: + case Intrinsic::init_trampoline: + case Intrinsic::lifetime_end: + return true; + } + } + if (auto CS = CallSite(I)) { + if (Function *F = CS.getCalledFunction()) { + StringRef FnName = F->getName(); + if (TLI.has(LibFunc_strcpy) && FnName == TLI.getName(LibFunc_strcpy)) + return true; + if (TLI.has(LibFunc_strncpy) && FnName == TLI.getName(LibFunc_strncpy)) + return true; + if (TLI.has(LibFunc_strcat) && FnName == TLI.getName(LibFunc_strcat)) + return true; + if (TLI.has(LibFunc_strncat) && FnName == TLI.getName(LibFunc_strncat)) + return true; + } + } + return false; +} + +/// Return a Location stored to by the specified instruction. 
If isRemovable
+/// returns true, this function and getLocForRead completely describe the memory
+/// operations for this instruction.
+static MemoryLocation getLocForWrite(Instruction *Inst, AliasAnalysis &AA) {
+  if (StoreInst *SI = dyn_cast<StoreInst>(Inst))
+    return MemoryLocation::get(SI);
+
+  if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(Inst)) {
+    // memcpy/memmove/memset.
+    MemoryLocation Loc = MemoryLocation::getForDest(MI);
+    return Loc;
+  }
+
+  IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst);
+  if (!II)
+    return MemoryLocation();
+
+  switch (II->getIntrinsicID()) {
+  default:
+    return MemoryLocation(); // Unhandled intrinsic.
+  case Intrinsic::init_trampoline:
+    // FIXME: We don't know the size of the trampoline, so we can't really
+    // handle it here.
+    return MemoryLocation(II->getArgOperand(0));
+  case Intrinsic::lifetime_end: {
+    uint64_t Len = cast<ConstantInt>(II->getArgOperand(0))->getZExtValue();
+    return MemoryLocation(II->getArgOperand(1), Len);
+  }
+  }
+}
+
+/// Return the location read by the specified "hasMemoryWrite" instruction if
+/// any.
+static MemoryLocation getLocForRead(Instruction *Inst,
+                                    const TargetLibraryInfo &TLI) {
+  assert(hasMemoryWrite(Inst, TLI) && "Unknown instruction case");
+
+  // The only instructions that both read and write are the mem transfer
+  // instructions (memcpy/memmove).
+  if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(Inst))
+    return MemoryLocation::getForSource(MTI);
+  return MemoryLocation();
+}
+
+/// If the value of this instruction and the memory it writes to is unused, may
+/// we delete this instruction?
+static bool isRemovable(Instruction *I) {
+  // Don't remove volatile/atomic stores.
+  if (StoreInst *SI = dyn_cast<StoreInst>(I))
+    return SI->isUnordered();
+
+  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+    switch (II->getIntrinsicID()) {
+    default: llvm_unreachable("doesn't pass 'hasMemoryWrite' predicate");
+    case Intrinsic::lifetime_end:
+      // Never remove dead lifetime_end's, e.g. because it is followed by a
+      // free.
+      return false;
+    case Intrinsic::init_trampoline:
+      // Always safe to remove init_trampoline.
+      return true;
+    case Intrinsic::memset:
+    case Intrinsic::memmove:
+    case Intrinsic::memcpy:
+      // Don't remove volatile memory intrinsics.
+      return !cast<MemIntrinsic>(II)->isVolatile();
+    }
+  }
+
+  if (auto CS = CallSite(I))
+    return CS.getInstruction()->use_empty();
+
+  return false;
+}
+
+/// Returns true if the end of this instruction can be safely shortened in
+/// length.
+static bool isShortenableAtTheEnd(Instruction *I) {
+  // Don't shorten stores for now.
+  if (isa<StoreInst>(I))
+    return false;
+
+  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+    switch (II->getIntrinsicID()) {
+      default: return false;
+      case Intrinsic::memset:
+      case Intrinsic::memcpy:
+        // Do shorten memory intrinsics.
+        // FIXME: Add memmove if it's also safe to transform.
+        return true;
+    }
+  }
+
+  // Don't shorten libcalls for now.
+
+  return false;
+}
+
+/// Returns true if the beginning of this instruction can be safely shortened
+/// in length.
+static bool isShortenableAtTheBeginning(Instruction *I) {
+  // FIXME: Handle only memset for now. Supporting memcpy/memmove should be
+  // easily done by offsetting the source address.
+  IntrinsicInst *II = dyn_cast<IntrinsicInst>(I);
+  return II && II->getIntrinsicID() == Intrinsic::memset;
+}
+
+/// Return the pointer that is being written to.
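+/// For example, this returns %p for 'store i32 0, i32* %p' and the destination
+/// operand for a memset/memcpy/memmove.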
+static Value *getStoredPointerOperand(Instruction *I) { + if (StoreInst *SI = dyn_cast<StoreInst>(I)) + return SI->getPointerOperand(); + if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I)) + return MI->getDest(); + + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + switch (II->getIntrinsicID()) { + default: llvm_unreachable("Unexpected intrinsic!"); + case Intrinsic::init_trampoline: + return II->getArgOperand(0); + } + } + + CallSite CS(I); + // All the supported functions so far happen to have dest as their first + // argument. + return CS.getArgument(0); +} + +static uint64_t getPointerSize(const Value *V, const DataLayout &DL, + const TargetLibraryInfo &TLI) { + uint64_t Size; + if (getObjectSize(V, Size, DL, &TLI)) + return Size; + return MemoryLocation::UnknownSize; +} + +namespace { + +enum OverwriteResult { + OW_Begin, + OW_Complete, + OW_End, + OW_PartialEarlierWithFullLater, + OW_Unknown +}; + +} // end anonymous namespace + +/// Return 'OW_Complete' if a store to the 'Later' location completely +/// overwrites a store to the 'Earlier' location, 'OW_End' if the end of the +/// 'Earlier' location is completely overwritten by 'Later', 'OW_Begin' if the +/// beginning of the 'Earlier' location is overwritten by 'Later'. +/// 'OW_PartialEarlierWithFullLater' means that an earlier (big) store was +/// overwritten by a latter (smaller) store which doesn't write outside the big +/// store's memory locations. Returns 'OW_Unknown' if nothing can be determined. +static OverwriteResult isOverwrite(const MemoryLocation &Later, + const MemoryLocation &Earlier, + const DataLayout &DL, + const TargetLibraryInfo &TLI, + int64_t &EarlierOff, int64_t &LaterOff, + Instruction *DepWrite, + InstOverlapIntervalsTy &IOL) { + // If we don't know the sizes of either access, then we can't do a comparison. + if (Later.Size == MemoryLocation::UnknownSize || + Earlier.Size == MemoryLocation::UnknownSize) + return OW_Unknown; + + const Value *P1 = Earlier.Ptr->stripPointerCasts(); + const Value *P2 = Later.Ptr->stripPointerCasts(); + + // If the start pointers are the same, we just have to compare sizes to see if + // the later store was larger than the earlier store. + if (P1 == P2) { + // Make sure that the Later size is >= the Earlier size. + if (Later.Size >= Earlier.Size) + return OW_Complete; + } + + // Check to see if the later store is to the entire object (either a global, + // an alloca, or a byval/inalloca argument). If so, then it clearly + // overwrites any other store to the same object. + const Value *UO1 = GetUnderlyingObject(P1, DL), + *UO2 = GetUnderlyingObject(P2, DL); + + // If we can't resolve the same pointers to the same object, then we can't + // analyze them at all. + if (UO1 != UO2) + return OW_Unknown; + + // If the "Later" store is to a recognizable object, get its size. + uint64_t ObjectSize = getPointerSize(UO2, DL, TLI); + if (ObjectSize != MemoryLocation::UnknownSize) + if (ObjectSize == Later.Size && ObjectSize >= Earlier.Size) + return OW_Complete; + + // Okay, we have stores to two completely different pointers. Try to + // decompose the pointer into a "base + constant_offset" form. If the base + // pointers are equal, then we can reason about the two stores. + EarlierOff = 0; + LaterOff = 0; + const Value *BP1 = GetPointerBaseWithConstantOffset(P1, EarlierOff, DL); + const Value *BP2 = GetPointerBaseWithConstantOffset(P2, LaterOff, DL); + + // If the base pointers still differ, we have two completely different stores. 
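+  // (For example, 'gep i8, i8* %base, i64 4' and 'gep i8, i8* %base, i64 8'
+  // decompose to the same base %base with offsets 4 and 8, so the two stores
+  // can be compared by offset and size below.)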
+  if (BP1 != BP2)
+    return OW_Unknown;
+
+  // The later store completely overlaps the earlier store if:
+  //
+  // 1. Both start at the same offset and the later one's size is greater than
+  //    or equal to the earlier one's, or
+  //
+  //      |--earlier--|
+  //      |--   later   --|
+  //
+  // 2. The earlier store has an offset greater than the later offset, but which
+  //    still lies completely within the later store.
+  //
+  //        |--earlier--|
+  //    |-----  later  ------|
+  //
+  // We have to be careful here as *Off is signed while *.Size is unsigned.
+  if (EarlierOff >= LaterOff &&
+      Later.Size >= Earlier.Size &&
+      uint64_t(EarlierOff - LaterOff) + Earlier.Size <= Later.Size)
+    return OW_Complete;
+
+  // We may now overlap, although the overlap is not complete. There might also
+  // be other incomplete overlaps, and together, they might cover the complete
+  // earlier write.
+  // Note: The correctness of this logic depends on the fact that this function
+  // is not even called providing DepWrite when there are any intervening reads.
+  if (EnablePartialOverwriteTracking &&
+      LaterOff < int64_t(EarlierOff + Earlier.Size) &&
+      int64_t(LaterOff + Later.Size) >= EarlierOff) {
+
+    // Insert our part of the overlap into the map.
+    auto &IM = IOL[DepWrite];
+    DEBUG(dbgs() << "DSE: Partial overwrite: Earlier [" << EarlierOff << ", " <<
+                    int64_t(EarlierOff + Earlier.Size) << ") Later [" <<
+                    LaterOff << ", " << int64_t(LaterOff + Later.Size) << ")\n");
+
+    // Make sure that we only insert non-overlapping intervals and combine
+    // adjacent intervals. The intervals are stored in the map with the ending
+    // offset as the key (in the half-open sense) and the starting offset as
+    // the value.
+    int64_t LaterIntStart = LaterOff, LaterIntEnd = LaterOff + Later.Size;
+
+    // Find any intervals ending at, or after, LaterIntStart which start
+    // before LaterIntEnd.
+    auto ILI = IM.lower_bound(LaterIntStart);
+    if (ILI != IM.end() && ILI->second <= LaterIntEnd) {
+      // This existing interval is overlapped with the current store somewhere
+      // in [LaterIntStart, LaterIntEnd]. Merge them by erasing the existing
+      // intervals and adjusting our start and end.
+      LaterIntStart = std::min(LaterIntStart, ILI->second);
+      LaterIntEnd = std::max(LaterIntEnd, ILI->first);
+      ILI = IM.erase(ILI);
+
+      // Continue erasing and adjusting our end in case other previous
+      // intervals are also overlapped with the current store.
+      //
+      // |--- earlier 1 ---|  |--- earlier 2 ---|
+      //     |------- later---------|
+      //
+      while (ILI != IM.end() && ILI->second <= LaterIntEnd) {
+        assert(ILI->second > LaterIntStart && "Unexpected interval");
+        LaterIntEnd = std::max(LaterIntEnd, ILI->first);
+        ILI = IM.erase(ILI);
+      }
+    }
+
+    IM[LaterIntEnd] = LaterIntStart;
+
+    ILI = IM.begin();
+    if (ILI->second <= EarlierOff &&
+        ILI->first >= int64_t(EarlierOff + Earlier.Size)) {
+      DEBUG(dbgs() << "DSE: Full overwrite from partials: Earlier [" <<
+                      EarlierOff << ", " <<
+                      int64_t(EarlierOff + Earlier.Size) <<
+                      ") Composite Later [" <<
+                      ILI->second << ", " << ILI->first << ")\n");
+      ++NumCompletePartials;
+      return OW_Complete;
+    }
+  }
+
+  // Check for an earlier store which writes to all the memory locations that
+  // the later store writes to.
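+  // For example (illustrative IR):
+  //   store i32 0, i32* %p          ; earlier store, bytes [0, 4)
+  //   store i8 1, i8* %p1           ; later store to byte [1, 2) of %p
+  // The later store lies entirely inside the earlier one, so this is
+  // OW_PartialEarlierWithFullLater.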
+  if (EnablePartialStoreMerging && LaterOff >= EarlierOff &&
+      int64_t(EarlierOff + Earlier.Size) > LaterOff &&
+      uint64_t(LaterOff - EarlierOff) + Later.Size <= Earlier.Size) {
+    DEBUG(dbgs() << "DSE: Partial overwrite of an earlier store [" << EarlierOff
+                 << ", " << int64_t(EarlierOff + Earlier.Size)
+                 << ") by a later store [" << LaterOff << ", "
+                 << int64_t(LaterOff + Later.Size) << ")\n");
+    // TODO: Maybe come up with a better name?
+    return OW_PartialEarlierWithFullLater;
+  }
+
+  // Another interesting case is if the later store overwrites the end of the
+  // earlier store.
+  //
+  //      |--earlier--|
+  //                |--   later   --|
+  //
+  // In this case we may want to trim the size of earlier to avoid generating
+  // writes to addresses which will definitely be overwritten later.
+  if (!EnablePartialOverwriteTracking &&
+      (LaterOff > EarlierOff && LaterOff < int64_t(EarlierOff + Earlier.Size) &&
+       int64_t(LaterOff + Later.Size) >= int64_t(EarlierOff + Earlier.Size)))
+    return OW_End;
+
+  // Finally, we also need to check if the later store overwrites the beginning
+  // of the earlier store.
+  //
+  //                |--earlier--|
+  //      |--   later   --|
+  //
+  // In this case we may want to move the destination address and trim the size
+  // of earlier to avoid generating writes to addresses which will definitely
+  // be overwritten later.
+  if (!EnablePartialOverwriteTracking &&
+      (LaterOff <= EarlierOff && int64_t(LaterOff + Later.Size) > EarlierOff)) {
+    assert(int64_t(LaterOff + Later.Size) <
+               int64_t(EarlierOff + Earlier.Size) &&
+           "Expect to be handled as OW_Complete");
+    return OW_Begin;
+  }
+  // Otherwise, they don't completely overlap.
+  return OW_Unknown;
+}
+
+/// If 'Inst' might be a self read (i.e. a noop copy of a
+/// memory region into an identical pointer) then it doesn't actually make its
+/// input dead in the traditional sense. Consider this case:
+///
+///   memcpy(A <- B)
+///   memcpy(A <- A)
+///
+/// In this case, the second store to A does not make the first store to A dead.
+/// The usual situation isn't an explicit A<-A store like this (which can be
+/// trivially removed) but a case where two pointers may alias.
+///
+/// This function detects when it is unsafe to remove a dependent instruction
+/// because the DSE inducing instruction may be a self-read.
+static bool isPossibleSelfRead(Instruction *Inst,
+                               const MemoryLocation &InstStoreLoc,
+                               Instruction *DepWrite,
+                               const TargetLibraryInfo &TLI,
+                               AliasAnalysis &AA) {
+  // Self reads can only happen for instructions that read memory. Get the
+  // location read.
+  MemoryLocation InstReadLoc = getLocForRead(Inst, TLI);
+  if (!InstReadLoc.Ptr) return false; // Not a reading instruction.
+
+  // If the read and written loc obviously don't alias, it isn't a read.
+  if (AA.isNoAlias(InstReadLoc, InstStoreLoc)) return false;
+
+  // Okay, 'Inst' may copy over itself. However, we can still remove the
+  // DepWrite instruction if we can prove that it reads from the same location
+  // as Inst. This handles useful cases like:
+  //   memcpy(A <- B)
+  //   memcpy(A <- B)
+  // Here we don't know if A/B may alias, but we do know that B/B are must
+  // aliases, so removing the first memcpy is safe (assuming it writes <= #
+  // bytes as the second one).
+  MemoryLocation DepReadLoc = getLocForRead(DepWrite, TLI);
+
+  if (DepReadLoc.Ptr && AA.isMustAlias(InstReadLoc.Ptr, DepReadLoc.Ptr))
+    return false;
+
+  // If DepWrite doesn't read memory or if we can't prove it is a must alias,
+  // then it can't be considered dead.
+  return true;
+}
+
+/// Returns true if the memory which is accessed by the second instruction is not
+/// modified between the first and the second instruction.
+/// Precondition: Second instruction must be dominated by the first
+/// instruction.
+static bool memoryIsNotModifiedBetween(Instruction *FirstI,
+                                       Instruction *SecondI,
+                                       AliasAnalysis *AA) {
+  SmallVector<BasicBlock *, 16> WorkList;
+  SmallPtrSet<BasicBlock *, 8> Visited;
+  BasicBlock::iterator FirstBBI(FirstI);
+  ++FirstBBI;
+  BasicBlock::iterator SecondBBI(SecondI);
+  BasicBlock *FirstBB = FirstI->getParent();
+  BasicBlock *SecondBB = SecondI->getParent();
+  MemoryLocation MemLoc = MemoryLocation::get(SecondI);
+
+  // Start checking the store-block.
+  WorkList.push_back(SecondBB);
+  bool isFirstBlock = true;
+
+  // Check all blocks going backward until we reach the load-block.
+  while (!WorkList.empty()) {
+    BasicBlock *B = WorkList.pop_back_val();
+
+    // Ignore instructions before FirstI if this is FirstBB.
+    BasicBlock::iterator BI = (B == FirstBB ? FirstBBI : B->begin());
+
+    BasicBlock::iterator EI;
+    if (isFirstBlock) {
+      // Ignore instructions after SecondI if this is the first visit of
+      // SecondBB.
+      assert(B == SecondBB && "first block is not the store block");
+      EI = SecondBBI;
+      isFirstBlock = false;
+    } else {
+      // It's not SecondBB or (in case of a loop) the second visit of SecondBB.
+      // In this case we also have to look at instructions after SecondI.
+      EI = B->end();
+    }
+    for (; BI != EI; ++BI) {
+      Instruction *I = &*BI;
+      if (I->mayWriteToMemory() && I != SecondI)
+        if (isModSet(AA->getModRefInfo(I, MemLoc)))
+          return false;
+    }
+    if (B != FirstBB) {
+      assert(B != &FirstBB->getParent()->getEntryBlock() &&
+             "Should not hit the entry block because SecondI must be dominated "
+             "by FirstI");
+      for (auto PredI = pred_begin(B), PE = pred_end(B); PredI != PE; ++PredI) {
+        if (!Visited.insert(*PredI).second)
+          continue;
+        WorkList.push_back(*PredI);
+      }
+    }
+  }
+  return true;
+}
+
+/// Find all blocks that will unconditionally lead to the block BB and append
+/// them to F.
+static void findUnconditionalPreds(SmallVectorImpl<BasicBlock *> &Blocks,
+                                   BasicBlock *BB, DominatorTree *DT) {
+  for (pred_iterator I = pred_begin(BB), E = pred_end(BB); I != E; ++I) {
+    BasicBlock *Pred = *I;
+    if (Pred == BB) continue;
+    TerminatorInst *PredTI = Pred->getTerminator();
+    if (PredTI->getNumSuccessors() != 1)
+      continue;
+
+    if (DT->isReachableFromEntry(Pred))
+      Blocks.push_back(Pred);
+  }
+}
+
+/// Handle frees of entire structures whose dependency is a store
+/// to a field of that structure.
+static bool handleFree(CallInst *F, AliasAnalysis *AA,
+                       MemoryDependenceResults *MD, DominatorTree *DT,
+                       const TargetLibraryInfo *TLI,
+                       InstOverlapIntervalsTy &IOL,
+                       DenseMap<Instruction*, size_t> *InstrOrdering) {
+  bool MadeChange = false;
+
+  MemoryLocation Loc = MemoryLocation(F->getOperand(0));
+  SmallVector<BasicBlock *, 16> Blocks;
+  Blocks.push_back(F->getParent());
+  const DataLayout &DL = F->getModule()->getDataLayout();
+
+  while (!Blocks.empty()) {
+    BasicBlock *BB = Blocks.pop_back_val();
+    Instruction *InstPt = BB->getTerminator();
+    if (BB == F->getParent()) InstPt = F;
+
+    MemDepResult Dep =
+        MD->getPointerDependencyFrom(Loc, false, InstPt->getIterator(), BB);
+    while (Dep.isDef() || Dep.isClobber()) {
+      Instruction *Dependency = Dep.getInst();
+      if (!hasMemoryWrite(Dependency, *TLI) || !isRemovable(Dependency))
+        break;
+
+      Value *DepPointer =
+          GetUnderlyingObject(getStoredPointerOperand(Dependency), DL);
+
+      // Check for aliasing.
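+      // For example, 'store i32 0, i32* %s' just before 'free(%s)' is dead
+      // only when the stored-to pointer must-alias the freed pointer; a
+      // may-alias relationship is not enough to remove the store.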
+ if (!AA->isMustAlias(F->getArgOperand(0), DepPointer)) + break; + + DEBUG(dbgs() << "DSE: Dead Store to soon to be freed memory:\n DEAD: " + << *Dependency << '\n'); + + // DCE instructions only used to calculate that store. + BasicBlock::iterator BBI(Dependency); + deleteDeadInstruction(Dependency, &BBI, *MD, *TLI, IOL, InstrOrdering); + ++NumFastStores; + MadeChange = true; + + // Inst's old Dependency is now deleted. Compute the next dependency, + // which may also be dead, as in + // s[0] = 0; + // s[1] = 0; // This has just been deleted. + // free(s); + Dep = MD->getPointerDependencyFrom(Loc, false, BBI, BB); + } + + if (Dep.isNonLocal()) + findUnconditionalPreds(Blocks, BB, DT); + } + + return MadeChange; +} + +/// Check to see if the specified location may alias any of the stack objects in +/// the DeadStackObjects set. If so, they become live because the location is +/// being loaded. +static void removeAccessedObjects(const MemoryLocation &LoadedLoc, + SmallSetVector<Value *, 16> &DeadStackObjects, + const DataLayout &DL, AliasAnalysis *AA, + const TargetLibraryInfo *TLI) { + const Value *UnderlyingPointer = GetUnderlyingObject(LoadedLoc.Ptr, DL); + + // A constant can't be in the dead pointer set. + if (isa<Constant>(UnderlyingPointer)) + return; + + // If the kill pointer can be easily reduced to an alloca, don't bother doing + // extraneous AA queries. + if (isa<AllocaInst>(UnderlyingPointer) || isa<Argument>(UnderlyingPointer)) { + DeadStackObjects.remove(const_cast<Value*>(UnderlyingPointer)); + return; + } + + // Remove objects that could alias LoadedLoc. + DeadStackObjects.remove_if([&](Value *I) { + // See if the loaded location could alias the stack location. + MemoryLocation StackLoc(I, getPointerSize(I, DL, *TLI)); + return !AA->isNoAlias(StackLoc, LoadedLoc); + }); +} + +/// Remove dead stores to stack-allocated locations in the function end block. +/// Ex: +/// %A = alloca i32 +/// ... +/// store i32 1, i32* %A +/// ret void +static bool handleEndBlock(BasicBlock &BB, AliasAnalysis *AA, + MemoryDependenceResults *MD, + const TargetLibraryInfo *TLI, + InstOverlapIntervalsTy &IOL, + DenseMap<Instruction*, size_t> *InstrOrdering) { + bool MadeChange = false; + + // Keep track of all of the stack objects that are dead at the end of the + // function. + SmallSetVector<Value*, 16> DeadStackObjects; + + // Find all of the alloca'd pointers in the entry block. + BasicBlock &Entry = BB.getParent()->front(); + for (Instruction &I : Entry) { + if (isa<AllocaInst>(&I)) + DeadStackObjects.insert(&I); + + // Okay, so these are dead heap objects, but if the pointer never escapes + // then it's leaked by this function anyways. + else if (isAllocLikeFn(&I, TLI) && !PointerMayBeCaptured(&I, true, true)) + DeadStackObjects.insert(&I); + } + + // Treat byval or inalloca arguments the same, stores to them are dead at the + // end of the function. + for (Argument &AI : BB.getParent()->args()) + if (AI.hasByValOrInAllocaAttr()) + DeadStackObjects.insert(&AI); + + const DataLayout &DL = BB.getModule()->getDataLayout(); + + // Scan the basic block backwards + for (BasicBlock::iterator BBI = BB.end(); BBI != BB.begin(); ){ + --BBI; + + // If we find a store, check to see if it points into a dead stack value. + if (hasMemoryWrite(&*BBI, *TLI) && isRemovable(&*BBI)) { + // See through pointer-to-pointer bitcasts + SmallVector<Value *, 4> Pointers; + GetUnderlyingObjects(getStoredPointerOperand(&*BBI), Pointers, DL); + + // Stores to stack values are valid candidates for removal. 
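+      // For example, a store through 'select i1 %c, i32* %a, i32* %b' has two
+      // underlying objects, and it is removable only if both %a and %b are in
+      // the dead set, which is what the AllDead loop below checks.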
+      bool AllDead = true;
+      for (Value *Pointer : Pointers)
+        if (!DeadStackObjects.count(Pointer)) {
+          AllDead = false;
+          break;
+        }
+
+      if (AllDead) {
+        Instruction *Dead = &*BBI;
+
+        DEBUG(dbgs() << "DSE: Dead Store at End of Block:\n  DEAD: "
+                     << *Dead << "\n  Objects: ";
+              for (SmallVectorImpl<Value *>::iterator I = Pointers.begin(),
+                   E = Pointers.end(); I != E; ++I) {
+                dbgs() << **I;
+                if (std::next(I) != E)
+                  dbgs() << ", ";
+              }
+              dbgs() << '\n');
+
+        // DCE instructions only used to calculate that store.
+        deleteDeadInstruction(Dead, &BBI, *MD, *TLI, IOL, InstrOrdering, &DeadStackObjects);
+        ++NumFastStores;
+        MadeChange = true;
+        continue;
+      }
+    }
+
+    // Remove any dead non-memory-mutating instructions.
+    if (isInstructionTriviallyDead(&*BBI, TLI)) {
+      DEBUG(dbgs() << "DSE: Removing trivially dead instruction:\n  DEAD: "
+                   << *&*BBI << '\n');
+      deleteDeadInstruction(&*BBI, &BBI, *MD, *TLI, IOL, InstrOrdering, &DeadStackObjects);
+      ++NumFastOther;
+      MadeChange = true;
+      continue;
+    }
+
+    if (isa<AllocaInst>(BBI)) {
+      // Remove allocas from the list of dead stack objects; there can't be
+      // any references before the definition.
+      DeadStackObjects.remove(&*BBI);
+      continue;
+    }
+
+    if (auto CS = CallSite(&*BBI)) {
+      // Remove allocation function calls from the list of dead stack objects;
+      // there can't be any references before the definition.
+      if (isAllocLikeFn(&*BBI, TLI))
+        DeadStackObjects.remove(&*BBI);
+
+      // If this call does not access memory, it can't be loading any of our
+      // pointers.
+      if (AA->doesNotAccessMemory(CS))
+        continue;
+
+      // If the call might load from any of our allocas, then any store above
+      // the call is live.
+      DeadStackObjects.remove_if([&](Value *I) {
+        // See if the call site touches the value.
+        return isRefSet(AA->getModRefInfo(CS, I, getPointerSize(I, DL, *TLI)));
+      });
+
+      // If all of the allocas were clobbered by the call then we're not going
+      // to find anything else to process.
+      if (DeadStackObjects.empty())
+        break;
+
+      continue;
+    }
+
+    // We can remove the dead stores, irrespective of the fence and its
+    // ordering (release/acquire/seq_cst). Fences only constrain the ordering
+    // of already visible stores; they do not make a store visible to other
+    // threads. So, skipping over a fence does not change a store from being
+    // dead.
+    if (isa<FenceInst>(*BBI))
+      continue;
+
+    MemoryLocation LoadedLoc;
+
+    // If we encounter a use of the pointer, it is no longer considered dead.
+    if (LoadInst *L = dyn_cast<LoadInst>(BBI)) {
+      if (!L->isUnordered()) // Be conservative with atomic/volatile load
+        break;
+      LoadedLoc = MemoryLocation::get(L);
+    } else if (VAArgInst *V = dyn_cast<VAArgInst>(BBI)) {
+      LoadedLoc = MemoryLocation::get(V);
+    } else if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(BBI)) {
+      LoadedLoc = MemoryLocation::getForSource(MTI);
+    } else if (!BBI->mayReadFromMemory()) {
+      // Instruction doesn't read memory. Note that stores that weren't removed
+      // above will hit this case.
+      continue;
+    } else {
+      // Unknown inst; assume it clobbers everything.
+      break;
+    }
+
+    // Remove any allocas from the DeadPointer set that are loaded, as this
+    // makes any stores above the access live.
+    removeAccessedObjects(LoadedLoc, DeadStackObjects, DL, AA, TLI);
+
+    // If all of the allocas were clobbered by the access then we're not going
+    // to find anything else to process.
+    if (DeadStackObjects.empty())
+      break;
+  }
+
+  return MadeChange;
+}
+
+static bool tryToShorten(Instruction *EarlierWrite, int64_t &EarlierOffset,
+                         int64_t &EarlierSize, int64_t LaterOffset,
+                         int64_t LaterSize, bool IsOverwriteEnd) {
+  // TODO: Base this on the target vector size so that if the earlier store
+  // was too small to get vector writes anyway, then it's likely a good idea
+  // to shorten it. Power-of-2 vector writes are probably always a bad idea
+  // to optimize, as any store/memset/memcpy is likely using vector
+  // instructions, so shortening it to a non-vector size is likely to be
+  // slower.
+  MemIntrinsic *EarlierIntrinsic = cast<MemIntrinsic>(EarlierWrite);
+  unsigned EarlierWriteAlign = EarlierIntrinsic->getAlignment();
+  if (!IsOverwriteEnd)
+    LaterOffset = int64_t(LaterOffset + LaterSize);
+
+  if (!(isPowerOf2_64(LaterOffset) && EarlierWriteAlign <= LaterOffset) &&
+      !((EarlierWriteAlign != 0) && LaterOffset % EarlierWriteAlign == 0))
+    return false;
+
+  DEBUG(dbgs() << "DSE: Remove Dead Store:\n  OW "
+               << (IsOverwriteEnd ? "END" : "BEGIN") << ": " << *EarlierWrite
+               << "\n  KILLER (offset " << LaterOffset << ", " << EarlierSize
+               << ")\n");
+
+  int64_t NewLength = IsOverwriteEnd
+                          ? LaterOffset - EarlierOffset
+                          : EarlierSize - (LaterOffset - EarlierOffset);
+
+  Value *EarlierWriteLength = EarlierIntrinsic->getLength();
+  Value *TrimmedLength =
+      ConstantInt::get(EarlierWriteLength->getType(), NewLength);
+  EarlierIntrinsic->setLength(TrimmedLength);
+
+  EarlierSize = NewLength;
+  if (!IsOverwriteEnd) {
+    int64_t OffsetMoved = (LaterOffset - EarlierOffset);
+    Value *Indices[1] = {
+        ConstantInt::get(EarlierWriteLength->getType(), OffsetMoved)};
+    GetElementPtrInst *NewDestGEP = GetElementPtrInst::CreateInBounds(
+        EarlierIntrinsic->getRawDest(), Indices, "", EarlierWrite);
+    EarlierIntrinsic->setDest(NewDestGEP);
+    EarlierOffset = EarlierOffset + OffsetMoved;
+  }
+  return true;
+}
+
+static bool tryToShortenEnd(Instruction *EarlierWrite,
+                            OverlapIntervalsTy &IntervalMap,
+                            int64_t &EarlierStart, int64_t &EarlierSize) {
+  if (IntervalMap.empty() || !isShortenableAtTheEnd(EarlierWrite))
+    return false;
+
+  OverlapIntervalsTy::iterator OII = --IntervalMap.end();
+  int64_t LaterStart = OII->second;
+  int64_t LaterSize = OII->first - LaterStart;
+
+  if (LaterStart > EarlierStart && LaterStart < EarlierStart + EarlierSize &&
+      LaterStart + LaterSize >= EarlierStart + EarlierSize) {
+    if (tryToShorten(EarlierWrite, EarlierStart, EarlierSize, LaterStart,
+                     LaterSize, true)) {
+      IntervalMap.erase(OII);
+      return true;
+    }
+  }
+  return false;
+}
+
+static bool tryToShortenBegin(Instruction *EarlierWrite,
+                              OverlapIntervalsTy &IntervalMap,
+                              int64_t &EarlierStart, int64_t &EarlierSize) {
+  if (IntervalMap.empty() || !isShortenableAtTheBeginning(EarlierWrite))
+    return false;
+
+  OverlapIntervalsTy::iterator OII = IntervalMap.begin();
+  int64_t LaterStart = OII->second;
+  int64_t LaterSize = OII->first - LaterStart;
+
+  if (LaterStart <= EarlierStart && LaterStart + LaterSize > EarlierStart) {
+    assert(LaterStart + LaterSize < EarlierStart + EarlierSize &&
+           "Should have been handled as OW_Complete");
+    if (tryToShorten(EarlierWrite, EarlierStart, EarlierSize, LaterStart,
+                     LaterSize, false)) {
+      IntervalMap.erase(OII);
+      return true;
+    }
+  }
+  return false;
+}
+
+static bool removePartiallyOverlappedStores(AliasAnalysis *AA,
+                                            const DataLayout &DL,
+                                            InstOverlapIntervalsTy &IOL) {
+  bool Changed = false;
+  for (auto OI : IOL) {
+    Instruction *EarlierWrite = OI.first;
MemoryLocation Loc = getLocForWrite(EarlierWrite, *AA); + assert(isRemovable(EarlierWrite) && "Expect only removable instruction"); + assert(Loc.Size != MemoryLocation::UnknownSize && "Unexpected mem loc"); + + const Value *Ptr = Loc.Ptr->stripPointerCasts(); + int64_t EarlierStart = 0; + int64_t EarlierSize = int64_t(Loc.Size); + GetPointerBaseWithConstantOffset(Ptr, EarlierStart, DL); + OverlapIntervalsTy &IntervalMap = OI.second; + Changed |= + tryToShortenEnd(EarlierWrite, IntervalMap, EarlierStart, EarlierSize); + if (IntervalMap.empty()) + continue; + Changed |= + tryToShortenBegin(EarlierWrite, IntervalMap, EarlierStart, EarlierSize); + } + return Changed; +} + +static bool eliminateNoopStore(Instruction *Inst, BasicBlock::iterator &BBI, + AliasAnalysis *AA, MemoryDependenceResults *MD, + const DataLayout &DL, + const TargetLibraryInfo *TLI, + InstOverlapIntervalsTy &IOL, + DenseMap<Instruction*, size_t> *InstrOrdering) { + // Must be a store instruction. + StoreInst *SI = dyn_cast<StoreInst>(Inst); + if (!SI) + return false; + + // If we're storing the same value back to a pointer that we just loaded from, + // then the store can be removed. + if (LoadInst *DepLoad = dyn_cast<LoadInst>(SI->getValueOperand())) { + if (SI->getPointerOperand() == DepLoad->getPointerOperand() && + isRemovable(SI) && memoryIsNotModifiedBetween(DepLoad, SI, AA)) { + + DEBUG(dbgs() << "DSE: Remove Store Of Load from same pointer:\n LOAD: " + << *DepLoad << "\n STORE: " << *SI << '\n'); + + deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, InstrOrdering); + ++NumRedundantStores; + return true; + } + } + + // Remove null stores into the calloc'ed objects + Constant *StoredConstant = dyn_cast<Constant>(SI->getValueOperand()); + if (StoredConstant && StoredConstant->isNullValue() && isRemovable(SI)) { + Instruction *UnderlyingPointer = + dyn_cast<Instruction>(GetUnderlyingObject(SI->getPointerOperand(), DL)); + + if (UnderlyingPointer && isCallocLikeFn(UnderlyingPointer, TLI) && + memoryIsNotModifiedBetween(UnderlyingPointer, SI, AA)) { + DEBUG( + dbgs() << "DSE: Remove null store to the calloc'ed object:\n DEAD: " + << *Inst << "\n OBJECT: " << *UnderlyingPointer << '\n'); + + deleteDeadInstruction(SI, &BBI, *MD, *TLI, IOL, InstrOrdering); + ++NumRedundantStores; + return true; + } + } + return false; +} + +static bool eliminateDeadStores(BasicBlock &BB, AliasAnalysis *AA, + MemoryDependenceResults *MD, DominatorTree *DT, + const TargetLibraryInfo *TLI) { + const DataLayout &DL = BB.getModule()->getDataLayout(); + bool MadeChange = false; + + // FIXME: Maybe change this to use some abstraction like OrderedBasicBlock? + // The current OrderedBasicBlock can't deal with mutation at the moment. + size_t LastThrowingInstIndex = 0; + DenseMap<Instruction*, size_t> InstrOrdering; + size_t InstrIndex = 1; + + // A map of interval maps representing partially-overwritten value parts. + InstOverlapIntervalsTy IOL; + + // Do a top-down walk on the BB. + for (BasicBlock::iterator BBI = BB.begin(), BBE = BB.end(); BBI != BBE; ) { + // Handle 'free' calls specially. + if (CallInst *F = isFreeCall(&*BBI, TLI)) { + MadeChange |= handleFree(F, AA, MD, DT, TLI, IOL, &InstrOrdering); + // Increment BBI after handleFree has potentially deleted instructions. + // This ensures we maintain a valid iterator. 
+      ++BBI;
+      continue;
+    }
+
+    Instruction *Inst = &*BBI++;
+
+    size_t CurInstNumber = InstrIndex++;
+    InstrOrdering.insert(std::make_pair(Inst, CurInstNumber));
+    if (Inst->mayThrow()) {
+      LastThrowingInstIndex = CurInstNumber;
+      continue;
+    }
+
+    // Check to see if Inst writes to memory. If not, continue.
+    if (!hasMemoryWrite(Inst, *TLI))
+      continue;
+
+    // eliminateNoopStore will update the iterator, if necessary.
+    if (eliminateNoopStore(Inst, BBI, AA, MD, DL, TLI, IOL, &InstrOrdering)) {
+      MadeChange = true;
+      continue;
+    }
+
+    // If we find something that writes memory, get its memory dependence.
+    MemDepResult InstDep = MD->getDependency(Inst);
+
+    // Ignore any store where we can't find a local dependence.
+    // FIXME: cross-block DSE would be fun. :)
+    if (!InstDep.isDef() && !InstDep.isClobber())
+      continue;
+
+    // Figure out what location is being stored to.
+    MemoryLocation Loc = getLocForWrite(Inst, *AA);
+
+    // If we didn't get a useful location, fail.
+    if (!Loc.Ptr)
+      continue;
+
+    // Loop until we find a store we can eliminate or a load that
+    // invalidates the analysis. Without an upper bound on the number of
+    // instructions examined, this analysis can become very time-consuming.
+    // However, the potential gain diminishes as we process more instructions
+    // without eliminating any of them. Therefore, we limit the number of
+    // instructions we look at.
+    auto Limit = MD->getDefaultBlockScanLimit();
+    while (InstDep.isDef() || InstDep.isClobber()) {
+      // Get the memory clobbered by the instruction we depend on. MemDep will
+      // skip any instructions that 'Loc' clearly doesn't interact with. If we
+      // end up depending on a may- or must-aliased load, then we can't optimize
+      // away the store and we bail out. However, if we depend on something
+      // that overwrites the memory location we *can* potentially optimize it.
+      //
+      // Find out what memory location the dependent instruction stores.
+      Instruction *DepWrite = InstDep.getInst();
+      MemoryLocation DepLoc = getLocForWrite(DepWrite, *AA);
+      // If we didn't get a useful location, or if it isn't a size, bail out.
+      if (!DepLoc.Ptr)
+        break;
+
+      // Make sure we don't look past a call which might throw. This is an
+      // issue because MemoryDependenceAnalysis works in the wrong direction:
+      // it finds instructions which dominate the current instruction, rather
+      // than instructions which are post-dominated by the current instruction.
+      //
+      // If the underlying object is a non-escaping memory allocation, any store
+      // to it is dead along the unwind edge. Otherwise, we need to preserve
+      // the store.
+      size_t DepIndex = InstrOrdering.lookup(DepWrite);
+      assert(DepIndex && "Unexpected instruction");
+      if (DepIndex <= LastThrowingInstIndex) {
+        const Value* Underlying = GetUnderlyingObject(DepLoc.Ptr, DL);
+        bool IsStoreDeadOnUnwind = isa<AllocaInst>(Underlying);
+        if (!IsStoreDeadOnUnwind) {
+          // We're looking for a call to an allocation function where the
+          // allocation doesn't escape before the last throwing instruction;
+          // PointerMayBeCaptured is a reasonably fast approximation.
+          IsStoreDeadOnUnwind = isAllocLikeFn(Underlying, TLI) &&
+              !PointerMayBeCaptured(Underlying, false, true);
+        }
+        if (!IsStoreDeadOnUnwind)
+          break;
+      }
+
+      // If we find a write that is a) removable (i.e., non-volatile), b) is
+      // completely obliterated by the store to 'Loc', and c) which we know that
+      // 'Inst' doesn't load from, then we can remove it.
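+      // The classic case (illustrative IR):
+      //   store i32 1, i32* %p   ; DepWrite, completely overwritten below
+      //   store i32 2, i32* %p   ; Inst
+      // With no read of %p in between, the first store is dead.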
+ // Also try to merge two stores if a later one only touches memory written + // to by the earlier one. + if (isRemovable(DepWrite) && + !isPossibleSelfRead(Inst, Loc, DepWrite, *TLI, *AA)) { + int64_t InstWriteOffset, DepWriteOffset; + OverwriteResult OR = + isOverwrite(Loc, DepLoc, DL, *TLI, DepWriteOffset, InstWriteOffset, + DepWrite, IOL); + if (OR == OW_Complete) { + DEBUG(dbgs() << "DSE: Remove Dead Store:\n DEAD: " + << *DepWrite << "\n KILLER: " << *Inst << '\n'); + + // Delete the store and now-dead instructions that feed it. + deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, &InstrOrdering); + ++NumFastStores; + MadeChange = true; + + // We erased DepWrite; start over. + InstDep = MD->getDependency(Inst); + continue; + } else if ((OR == OW_End && isShortenableAtTheEnd(DepWrite)) || + ((OR == OW_Begin && + isShortenableAtTheBeginning(DepWrite)))) { + assert(!EnablePartialOverwriteTracking && "Do not expect to perform " + "when partial-overwrite " + "tracking is enabled"); + int64_t EarlierSize = DepLoc.Size; + int64_t LaterSize = Loc.Size; + bool IsOverwriteEnd = (OR == OW_End); + MadeChange |= tryToShorten(DepWrite, DepWriteOffset, EarlierSize, + InstWriteOffset, LaterSize, IsOverwriteEnd); + } else if (EnablePartialStoreMerging && + OR == OW_PartialEarlierWithFullLater) { + auto *Earlier = dyn_cast<StoreInst>(DepWrite); + auto *Later = dyn_cast<StoreInst>(Inst); + if (Earlier && isa<ConstantInt>(Earlier->getValueOperand()) && + Later && isa<ConstantInt>(Later->getValueOperand()) && + memoryIsNotModifiedBetween(Earlier, Later, AA)) { + // If the store we find is: + // a) partially overwritten by the store to 'Loc' + // b) the later store is fully contained in the earlier one and + // c) they both have a constant value + // Merge the two stores, replacing the earlier store's value with a + // merge of both values. + // TODO: Deal with other constant types (vectors, etc), and probably + // some mem intrinsics (if needed) + + APInt EarlierValue = + cast<ConstantInt>(Earlier->getValueOperand())->getValue(); + APInt LaterValue = + cast<ConstantInt>(Later->getValueOperand())->getValue(); + unsigned LaterBits = LaterValue.getBitWidth(); + assert(EarlierValue.getBitWidth() > LaterValue.getBitWidth()); + LaterValue = LaterValue.zext(EarlierValue.getBitWidth()); + + // Offset of the smaller store inside the larger store + unsigned BitOffsetDiff = (InstWriteOffset - DepWriteOffset) * 8; + unsigned LShiftAmount = + DL.isBigEndian() + ? EarlierValue.getBitWidth() - BitOffsetDiff - LaterBits + : BitOffsetDiff; + APInt Mask = + APInt::getBitsSet(EarlierValue.getBitWidth(), LShiftAmount, + LShiftAmount + LaterBits); + // Clear the bits we'll be replacing, then OR with the smaller + // store, shifted appropriately. 
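+            // Worked example (little-endian): an earlier 'store i32 0'
+            // (EarlierValue = 0x00000000) overwritten at byte offset 1 by a
+            // later 'store i8 -1' (LaterValue = 0xFF) gives BitOffsetDiff = 8,
+            // LShiftAmount = 8, Mask = 0x0000FF00, and Merged = 0x0000FF00.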
+ APInt Merged = + (EarlierValue & ~Mask) | (LaterValue << LShiftAmount); + DEBUG(dbgs() << "DSE: Merge Stores:\n Earlier: " << *DepWrite + << "\n Later: " << *Inst + << "\n Merged Value: " << Merged << '\n'); + + auto *SI = new StoreInst( + ConstantInt::get(Earlier->getValueOperand()->getType(), Merged), + Earlier->getPointerOperand(), false, Earlier->getAlignment(), + Earlier->getOrdering(), Earlier->getSyncScopeID(), DepWrite); + + unsigned MDToKeep[] = {LLVMContext::MD_dbg, LLVMContext::MD_tbaa, + LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, + LLVMContext::MD_nontemporal}; + SI->copyMetadata(*DepWrite, MDToKeep); + ++NumModifiedStores; + + // Remove earlier, wider, store + size_t Idx = InstrOrdering.lookup(DepWrite); + InstrOrdering.erase(DepWrite); + InstrOrdering.insert(std::make_pair(SI, Idx)); + + // Delete the old stores and now-dead instructions that feed them. + deleteDeadInstruction(Inst, &BBI, *MD, *TLI, IOL, &InstrOrdering); + deleteDeadInstruction(DepWrite, &BBI, *MD, *TLI, IOL, + &InstrOrdering); + MadeChange = true; + + // We erased DepWrite and Inst (Loc); start over. + break; + } + } + } + + // If this is a may-aliased store that is clobbering the store value, we + // can keep searching past it for another must-aliased pointer that stores + // to the same location. For example, in: + // store -> P + // store -> Q + // store -> P + // we can remove the first store to P even though we don't know if P and Q + // alias. + if (DepWrite == &BB.front()) break; + + // Can't look past this instruction if it might read 'Loc'. + if (isRefSet(AA->getModRefInfo(DepWrite, Loc))) + break; + + InstDep = MD->getPointerDependencyFrom(Loc, /*isLoad=*/ false, + DepWrite->getIterator(), &BB, + /*QueryInst=*/ nullptr, &Limit); + } + } + + if (EnablePartialOverwriteTracking) + MadeChange |= removePartiallyOverlappedStores(AA, DL, IOL); + + // If this block ends in a return, unwind, or unreachable, all allocas are + // dead at its end, which means stores to them are also dead. + if (BB.getTerminator()->getNumSuccessors() == 0) + MadeChange |= handleEndBlock(BB, AA, MD, TLI, IOL, &InstrOrdering); + + return MadeChange; +} + +static bool eliminateDeadStores(Function &F, AliasAnalysis *AA, + MemoryDependenceResults *MD, DominatorTree *DT, + const TargetLibraryInfo *TLI) { + bool MadeChange = false; + for (BasicBlock &BB : F) + // Only check non-dead blocks. Dead blocks may have strange pointer + // cycles that will confuse alias analysis. + if (DT->isReachableFromEntry(&BB)) + MadeChange |= eliminateDeadStores(BB, AA, MD, DT, TLI); + + return MadeChange; +} + +//===----------------------------------------------------------------------===// +// DSE Pass +//===----------------------------------------------------------------------===// +PreservedAnalyses DSEPass::run(Function &F, FunctionAnalysisManager &AM) { + AliasAnalysis *AA = &AM.getResult<AAManager>(F); + DominatorTree *DT = &AM.getResult<DominatorTreeAnalysis>(F); + MemoryDependenceResults *MD = &AM.getResult<MemoryDependenceAnalysis>(F); + const TargetLibraryInfo *TLI = &AM.getResult<TargetLibraryAnalysis>(F); + + if (!eliminateDeadStores(F, AA, MD, DT, TLI)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + PA.preserve<GlobalsAA>(); + PA.preserve<MemoryDependenceAnalysis>(); + return PA; +} + +namespace { + +/// A legacy pass for the legacy pass manager that wraps \c DSEPass. 
+class DSELegacyPass : public FunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + + DSELegacyPass() : FunctionPass(ID) { + initializeDSELegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + MemoryDependenceResults *MD = + &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); + const TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + + return eliminateDeadStores(F, AA, MD, DT, TLI); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<MemoryDependenceWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<MemoryDependenceWrapperPass>(); + } +}; + +} // end anonymous namespace + +char DSELegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(DSELegacyPass, "dse", "Dead Store Elimination", false, + false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(DSELegacyPass, "dse", "Dead Store Elimination", false, + false) + +FunctionPass *llvm::createDeadStoreEliminationPass() { + return new DSELegacyPass(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/DivRemPairs.cpp b/contrib/llvm/lib/Transforms/Scalar/DivRemPairs.cpp new file mode 100644 index 000000000000..e383af89a384 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/DivRemPairs.cpp @@ -0,0 +1,206 @@ +//===- DivRemPairs.cpp - Hoist/decompose division and remainder -*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass hoists and/or decomposes integer division and remainder +// instructions to enable CFG improvements and better codegen. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/DivRemPairs.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BypassSlowDivision.h" +using namespace llvm; + +#define DEBUG_TYPE "div-rem-pairs" +STATISTIC(NumPairs, "Number of div/rem pairs"); +STATISTIC(NumHoisted, "Number of instructions hoisted"); +STATISTIC(NumDecomposed, "Number of instructions decomposed"); + +/// Find matching pairs of integer div/rem ops (they have the same numerator, +/// denominator, and signedness). If they exist in different basic blocks, bring +/// them together by hoisting or replace the common division operation that is +/// implicit in the remainder: +/// X % Y <--> X - ((X / Y) * Y). 
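+/// For example, with X = 17 and Y = 5: X / Y = 3, and 17 - (3 * 5) = 2, which
+/// is exactly 17 % 5.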
+/// +/// We can largely ignore the normal safety and cost constraints on speculation +/// of these ops when we find a matching pair. This is because we are already +/// guaranteed that any exceptions and most cost are already incurred by the +/// first member of the pair. +/// +/// Note: This transform could be an oddball enhancement to EarlyCSE, GVN, or +/// SimplifyCFG, but it's split off on its own because it's different enough +/// that it doesn't quite match the stated objectives of those passes. +static bool optimizeDivRem(Function &F, const TargetTransformInfo &TTI, + const DominatorTree &DT) { + bool Changed = false; + + // Insert all divide and remainder instructions into maps keyed by their + // operands and opcode (signed or unsigned). + DenseMap<DivRemMapKey, Instruction *> DivMap, RemMap; + for (auto &BB : F) { + for (auto &I : BB) { + if (I.getOpcode() == Instruction::SDiv) + DivMap[DivRemMapKey(true, I.getOperand(0), I.getOperand(1))] = &I; + else if (I.getOpcode() == Instruction::UDiv) + DivMap[DivRemMapKey(false, I.getOperand(0), I.getOperand(1))] = &I; + else if (I.getOpcode() == Instruction::SRem) + RemMap[DivRemMapKey(true, I.getOperand(0), I.getOperand(1))] = &I; + else if (I.getOpcode() == Instruction::URem) + RemMap[DivRemMapKey(false, I.getOperand(0), I.getOperand(1))] = &I; + } + } + + // We can iterate over either map because we are only looking for matched + // pairs. Choose remainders for efficiency because they are usually even more + // rare than division. + for (auto &RemPair : RemMap) { + // Find the matching division instruction from the division map. + Instruction *DivInst = DivMap[RemPair.getFirst()]; + if (!DivInst) + continue; + + // We have a matching pair of div/rem instructions. If one dominates the + // other, hoist and/or replace one. + NumPairs++; + Instruction *RemInst = RemPair.getSecond(); + bool IsSigned = DivInst->getOpcode() == Instruction::SDiv; + bool HasDivRemOp = TTI.hasDivRemOp(DivInst->getType(), IsSigned); + + // If the target supports div+rem and the instructions are in the same block + // already, there's nothing to do. The backend should handle this. If the + // target does not support div+rem, then we will decompose the rem. + if (HasDivRemOp && RemInst->getParent() == DivInst->getParent()) + continue; + + bool DivDominates = DT.dominates(DivInst, RemInst); + if (!DivDominates && !DT.dominates(RemInst, DivInst)) + continue; + + if (HasDivRemOp) { + // The target has a single div/rem operation. Hoist the lower instruction + // to make the matched pair visible to the backend. + if (DivDominates) + RemInst->moveAfter(DivInst); + else + DivInst->moveAfter(RemInst); + NumHoisted++; + } else { + // The target does not have a single div/rem operation. Decompose the + // remainder calculation as: + // X % Y --> X - ((X / Y) * Y). + Value *X = RemInst->getOperand(0); + Value *Y = RemInst->getOperand(1); + Instruction *Mul = BinaryOperator::CreateMul(DivInst, Y); + Instruction *Sub = BinaryOperator::CreateSub(X, Mul); + + // If the remainder dominates, then hoist the division up to that block: + // + // bb1: + // %rem = srem %x, %y + // bb2: + // %div = sdiv %x, %y + // --> + // bb1: + // %div = sdiv %x, %y + // %mul = mul %div, %y + // %rem = sub %x, %mul + // + // If the division dominates, it's already in the right place. 
The mul+sub + // will be in a different block because we don't assume that they are + // cheap to speculatively execute: + // + // bb1: + // %div = sdiv %x, %y + // bb2: + // %rem = srem %x, %y + // --> + // bb1: + // %div = sdiv %x, %y + // bb2: + // %mul = mul %div, %y + // %rem = sub %x, %mul + // + // If the div and rem are in the same block, we do the same transform, + // but any code movement would be within the same block. + + if (!DivDominates) + DivInst->moveBefore(RemInst); + Mul->insertAfter(RemInst); + Sub->insertAfter(Mul); + + // Now kill the explicit remainder. We have replaced it with: + // (sub X, (mul (div X, Y), Y) + RemInst->replaceAllUsesWith(Sub); + RemInst->eraseFromParent(); + NumDecomposed++; + } + Changed = true; + } + + return Changed; +} + +// Pass manager boilerplate below here. + +namespace { +struct DivRemPairsLegacyPass : public FunctionPass { + static char ID; + DivRemPairsLegacyPass() : FunctionPass(ID) { + initializeDivRemPairsLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.setPreservesCFG(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + FunctionPass::getAnalysisUsage(AU); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + return optimizeDivRem(F, TTI, DT); + } +}; +} + +char DivRemPairsLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(DivRemPairsLegacyPass, "div-rem-pairs", + "Hoist/decompose integer division and remainder", false, + false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(DivRemPairsLegacyPass, "div-rem-pairs", + "Hoist/decompose integer division and remainder", false, + false) +FunctionPass *llvm::createDivRemPairsPass() { + return new DivRemPairsLegacyPass(); +} + +PreservedAnalyses DivRemPairsPass::run(Function &F, + FunctionAnalysisManager &FAM) { + TargetTransformInfo &TTI = FAM.getResult<TargetIRAnalysis>(F); + DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F); + if (!optimizeDivRem(F, TTI, DT)) + return PreservedAnalyses::all(); + // TODO: This pass just hoists/replaces math ops - all analyses are preserved? + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + PA.preserve<GlobalsAA>(); + return PA; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/EarlyCSE.cpp b/contrib/llvm/lib/Transforms/Scalar/EarlyCSE.cpp new file mode 100644 index 000000000000..5798e1c4ee99 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/EarlyCSE.cpp @@ -0,0 +1,1185 @@ +//===- EarlyCSE.cpp - Simple and fast CSE pass ----------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass performs a simple dominator tree walk that eliminates trivially +// redundant instructions. 
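+// For example, if the same 'add i32 %a, %b' is computed on both sides of a
+// dominance relation, the dominated copy is replaced by the dominating one
+// and erased.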
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/EarlyCSE.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/ScopedHashTable.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/RecyclingAllocator.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include <cassert> +#include <deque> +#include <memory> +#include <utility> + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "early-cse" + +STATISTIC(NumSimplify, "Number of instructions simplified or DCE'd"); +STATISTIC(NumCSE, "Number of instructions CSE'd"); +STATISTIC(NumCSECVP, "Number of compare instructions CVP'd"); +STATISTIC(NumCSELoad, "Number of load instructions CSE'd"); +STATISTIC(NumCSECall, "Number of call instructions CSE'd"); +STATISTIC(NumDSE, "Number of trivial dead stores removed"); + +//===----------------------------------------------------------------------===// +// SimpleValue +//===----------------------------------------------------------------------===// + +namespace { + +/// \brief Struct representing the available values in the scoped hash table. +struct SimpleValue { + Instruction *Inst; + + SimpleValue(Instruction *I) : Inst(I) { + assert((isSentinel() || canHandle(I)) && "Inst can't be handled!"); + } + + bool isSentinel() const { + return Inst == DenseMapInfo<Instruction *>::getEmptyKey() || + Inst == DenseMapInfo<Instruction *>::getTombstoneKey(); + } + + static bool canHandle(Instruction *Inst) { + // This can only handle non-void readnone functions. 
+ if (CallInst *CI = dyn_cast<CallInst>(Inst)) + return CI->doesNotAccessMemory() && !CI->getType()->isVoidTy(); + return isa<CastInst>(Inst) || isa<BinaryOperator>(Inst) || + isa<GetElementPtrInst>(Inst) || isa<CmpInst>(Inst) || + isa<SelectInst>(Inst) || isa<ExtractElementInst>(Inst) || + isa<InsertElementInst>(Inst) || isa<ShuffleVectorInst>(Inst) || + isa<ExtractValueInst>(Inst) || isa<InsertValueInst>(Inst); + } +}; + +} // end anonymous namespace + +namespace llvm { + +template <> struct DenseMapInfo<SimpleValue> { + static inline SimpleValue getEmptyKey() { + return DenseMapInfo<Instruction *>::getEmptyKey(); + } + + static inline SimpleValue getTombstoneKey() { + return DenseMapInfo<Instruction *>::getTombstoneKey(); + } + + static unsigned getHashValue(SimpleValue Val); + static bool isEqual(SimpleValue LHS, SimpleValue RHS); +}; + +} // end namespace llvm + +unsigned DenseMapInfo<SimpleValue>::getHashValue(SimpleValue Val) { + Instruction *Inst = Val.Inst; + // Hash in all of the operands as pointers. + if (BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Inst)) { + Value *LHS = BinOp->getOperand(0); + Value *RHS = BinOp->getOperand(1); + if (BinOp->isCommutative() && BinOp->getOperand(0) > BinOp->getOperand(1)) + std::swap(LHS, RHS); + + return hash_combine(BinOp->getOpcode(), LHS, RHS); + } + + if (CmpInst *CI = dyn_cast<CmpInst>(Inst)) { + Value *LHS = CI->getOperand(0); + Value *RHS = CI->getOperand(1); + CmpInst::Predicate Pred = CI->getPredicate(); + if (Inst->getOperand(0) > Inst->getOperand(1)) { + std::swap(LHS, RHS); + Pred = CI->getSwappedPredicate(); + } + return hash_combine(Inst->getOpcode(), Pred, LHS, RHS); + } + + // Hash min/max/abs (cmp + select) to allow for commuted operands. + // Min/max may also have non-canonical compare predicate (eg, the compare for + // smin may use 'sgt' rather than 'slt'), and non-canonical operands in the + // compare. + Value *A, *B; + SelectPatternFlavor SPF = matchSelectPattern(Inst, A, B).Flavor; + // TODO: We should also detect FP min/max. + if (SPF == SPF_SMIN || SPF == SPF_SMAX || + SPF == SPF_UMIN || SPF == SPF_UMAX || + SPF == SPF_ABS || SPF == SPF_NABS) { + if (A > B) + std::swap(A, B); + return hash_combine(Inst->getOpcode(), SPF, A, B); + } + + if (CastInst *CI = dyn_cast<CastInst>(Inst)) + return hash_combine(CI->getOpcode(), CI->getType(), CI->getOperand(0)); + + if (const ExtractValueInst *EVI = dyn_cast<ExtractValueInst>(Inst)) + return hash_combine(EVI->getOpcode(), EVI->getOperand(0), + hash_combine_range(EVI->idx_begin(), EVI->idx_end())); + + if (const InsertValueInst *IVI = dyn_cast<InsertValueInst>(Inst)) + return hash_combine(IVI->getOpcode(), IVI->getOperand(0), + IVI->getOperand(1), + hash_combine_range(IVI->idx_begin(), IVI->idx_end())); + + assert((isa<CallInst>(Inst) || isa<BinaryOperator>(Inst) || + isa<GetElementPtrInst>(Inst) || isa<SelectInst>(Inst) || + isa<ExtractElementInst>(Inst) || isa<InsertElementInst>(Inst) || + isa<ShuffleVectorInst>(Inst)) && + "Invalid/unknown instruction"); + + // Mix in the opcode. 
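+  // (Illustrative: the commutative canonicalizations above mean that
+  //  'add %a, %b' and 'add %b, %a' hash identically, as do
+  //  'icmp sgt %a, %b' and 'icmp slt %b, %a' via the swapped predicate.)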
+ return hash_combine( + Inst->getOpcode(), + hash_combine_range(Inst->value_op_begin(), Inst->value_op_end())); +} + +bool DenseMapInfo<SimpleValue>::isEqual(SimpleValue LHS, SimpleValue RHS) { + Instruction *LHSI = LHS.Inst, *RHSI = RHS.Inst; + + if (LHS.isSentinel() || RHS.isSentinel()) + return LHSI == RHSI; + + if (LHSI->getOpcode() != RHSI->getOpcode()) + return false; + if (LHSI->isIdenticalToWhenDefined(RHSI)) + return true; + + // If we're not strictly identical, we still might be a commutable instruction + if (BinaryOperator *LHSBinOp = dyn_cast<BinaryOperator>(LHSI)) { + if (!LHSBinOp->isCommutative()) + return false; + + assert(isa<BinaryOperator>(RHSI) && + "same opcode, but different instruction type?"); + BinaryOperator *RHSBinOp = cast<BinaryOperator>(RHSI); + + // Commuted equality + return LHSBinOp->getOperand(0) == RHSBinOp->getOperand(1) && + LHSBinOp->getOperand(1) == RHSBinOp->getOperand(0); + } + if (CmpInst *LHSCmp = dyn_cast<CmpInst>(LHSI)) { + assert(isa<CmpInst>(RHSI) && + "same opcode, but different instruction type?"); + CmpInst *RHSCmp = cast<CmpInst>(RHSI); + // Commuted equality + return LHSCmp->getOperand(0) == RHSCmp->getOperand(1) && + LHSCmp->getOperand(1) == RHSCmp->getOperand(0) && + LHSCmp->getSwappedPredicate() == RHSCmp->getPredicate(); + } + + // Min/max/abs can occur with commuted operands, non-canonical predicates, + // and/or non-canonical operands. + Value *LHSA, *LHSB; + SelectPatternFlavor LSPF = matchSelectPattern(LHSI, LHSA, LHSB).Flavor; + // TODO: We should also detect FP min/max. + if (LSPF == SPF_SMIN || LSPF == SPF_SMAX || + LSPF == SPF_UMIN || LSPF == SPF_UMAX || + LSPF == SPF_ABS || LSPF == SPF_NABS) { + Value *RHSA, *RHSB; + SelectPatternFlavor RSPF = matchSelectPattern(RHSI, RHSA, RHSB).Flavor; + return (LSPF == RSPF && ((LHSA == RHSA && LHSB == RHSB) || + (LHSA == RHSB && LHSB == RHSA))); + } + + return false; +} + +//===----------------------------------------------------------------------===// +// CallValue +//===----------------------------------------------------------------------===// + +namespace { + +/// \brief Struct representing the available call values in the scoped hash +/// table. +struct CallValue { + Instruction *Inst; + + CallValue(Instruction *I) : Inst(I) { + assert((isSentinel() || canHandle(I)) && "Inst can't be handled!"); + } + + bool isSentinel() const { + return Inst == DenseMapInfo<Instruction *>::getEmptyKey() || + Inst == DenseMapInfo<Instruction *>::getTombstoneKey(); + } + + static bool canHandle(Instruction *Inst) { + // Don't value number anything that returns void. + if (Inst->getType()->isVoidTy()) + return false; + + CallInst *CI = dyn_cast<CallInst>(Inst); + if (!CI || !CI->onlyReadsMemory()) + return false; + return true; + } +}; + +} // end anonymous namespace + +namespace llvm { + +template <> struct DenseMapInfo<CallValue> { + static inline CallValue getEmptyKey() { + return DenseMapInfo<Instruction *>::getEmptyKey(); + } + + static inline CallValue getTombstoneKey() { + return DenseMapInfo<Instruction *>::getTombstoneKey(); + } + + static unsigned getHashValue(CallValue Val); + static bool isEqual(CallValue LHS, CallValue RHS); +}; + +} // end namespace llvm + +unsigned DenseMapInfo<CallValue>::getHashValue(CallValue Val) { + Instruction *Inst = Val.Inst; + // Hash all of the operands as pointers and mix in the opcode. 
+  return hash_combine(
+      Inst->getOpcode(),
+      hash_combine_range(Inst->value_op_begin(), Inst->value_op_end()));
+}
+
+bool DenseMapInfo<CallValue>::isEqual(CallValue LHS, CallValue RHS) {
+  Instruction *LHSI = LHS.Inst, *RHSI = RHS.Inst;
+  if (LHS.isSentinel() || RHS.isSentinel())
+    return LHSI == RHSI;
+  return LHSI->isIdenticalTo(RHSI);
+}
+
+//===----------------------------------------------------------------------===//
+// EarlyCSE implementation
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// \brief A simple and fast domtree-based CSE pass.
+///
+/// This pass does a simple depth-first walk over the dominator tree,
+/// eliminating trivially redundant instructions and using instsimplify to
+/// canonicalize things as it goes. It is intended to be fast and catch obvious
+/// cases so that instcombine and other passes are more effective. It is
+/// expected that a later pass of GVN will catch the interesting/hard cases.
+class EarlyCSE {
+public:
+  const TargetLibraryInfo &TLI;
+  const TargetTransformInfo &TTI;
+  DominatorTree &DT;
+  AssumptionCache &AC;
+  const SimplifyQuery SQ;
+  MemorySSA *MSSA;
+  std::unique_ptr<MemorySSAUpdater> MSSAUpdater;
+
+  using AllocatorTy =
+      RecyclingAllocator<BumpPtrAllocator,
+                         ScopedHashTableVal<SimpleValue, Value *>>;
+  using ScopedHTType =
+      ScopedHashTable<SimpleValue, Value *, DenseMapInfo<SimpleValue>,
+                      AllocatorTy>;
+
+  /// \brief A scoped hash table of the current values of all of our simple
+  /// scalar expressions.
+  ///
+  /// As we walk down the domtree, we look to see if instructions are in this:
+  /// if so, we replace them with what we find, otherwise we insert them so
+  /// that dominated values can succeed in their lookup.
+  ScopedHTType AvailableValues;
+
+  /// A scoped hash table of the current values of previously encountered
+  /// memory locations.
+  ///
+  /// This allows us to get efficient access to dominating loads or stores when
+  /// we have a fully redundant load. In addition to the most recent load, we
+  /// keep track of a generation count of the read, which is compared against
+  /// the current generation count. The current generation count is incremented
+  /// after every possibly writing memory operation, which ensures that we only
+  /// CSE loads with other loads that have no intervening store. Ordering
+  /// events (such as fences or atomic instructions) increment the generation
+  /// count as well; essentially, we model these as writes to all possible
+  /// locations. Note that atomic and/or volatile loads and stores can be
+  /// present in the table; it is the responsibility of the consumer to inspect
+  /// the atomicity/volatility if needed.
+  struct LoadValue {
+    Instruction *DefInst = nullptr;
+    unsigned Generation = 0;
+    int MatchingId = -1;
+    bool IsAtomic = false;
+    bool IsInvariant = false;
+
+    LoadValue() = default;
+    LoadValue(Instruction *Inst, unsigned Generation, unsigned MatchingId,
+              bool IsAtomic, bool IsInvariant)
+        : DefInst(Inst), Generation(Generation), MatchingId(MatchingId),
+          IsAtomic(IsAtomic), IsInvariant(IsInvariant) {}
+  };
+
+  using LoadMapAllocator =
+      RecyclingAllocator<BumpPtrAllocator,
+                         ScopedHashTableVal<Value *, LoadValue>>;
+  using LoadHTType =
+      ScopedHashTable<Value *, LoadValue, DenseMapInfo<Value *>,
+                      LoadMapAllocator>;
+
+  LoadHTType AvailableLoads;
+
+  /// \brief A scoped hash table of the current values of read-only call
+  /// values.
+  ///
+  /// It uses the same generation count as loads.
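+  /// For example, a call recorded at generation 3 can only be reused while
+  /// the current generation is still 3, or when MemorySSA can prove that no
+  /// intervening instruction clobbers memory; any may-write instruction
+  /// increments the generation and so invalidates the entry.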
+  using CallHTType =
+      ScopedHashTable<CallValue, std::pair<Instruction *, unsigned>>;
+  CallHTType AvailableCalls;
+
+  /// \brief This is the current generation of the memory value.
+  unsigned CurrentGeneration = 0;
+
+  /// \brief Set up the EarlyCSE runner for a particular function.
+  EarlyCSE(const DataLayout &DL, const TargetLibraryInfo &TLI,
+           const TargetTransformInfo &TTI, DominatorTree &DT,
+           AssumptionCache &AC, MemorySSA *MSSA)
+      : TLI(TLI), TTI(TTI), DT(DT), AC(AC), SQ(DL, &TLI, &DT, &AC), MSSA(MSSA),
+        MSSAUpdater(llvm::make_unique<MemorySSAUpdater>(MSSA)) {}
+
+  bool run();
+
+private:
+  // Almost a POD, but needs to call the constructors for the scoped hash
+  // tables so that a new scope gets pushed on. These are RAII so that the
+  // scope gets popped when the NodeScope is destroyed.
+  class NodeScope {
+  public:
+    NodeScope(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads,
+              CallHTType &AvailableCalls)
+        : Scope(AvailableValues), LoadScope(AvailableLoads),
+          CallScope(AvailableCalls) {}
+    NodeScope(const NodeScope &) = delete;
+    NodeScope &operator=(const NodeScope &) = delete;
+
+  private:
+    ScopedHTType::ScopeTy Scope;
+    LoadHTType::ScopeTy LoadScope;
+    CallHTType::ScopeTy CallScope;
+  };
+
+  // Contains all the needed information to create a stack for doing a depth
+  // first traversal of the tree. This includes scopes for values, loads, and
+  // calls as well as the generation. There is a child iterator so that the
+  // children do not need to be stored separately.
+  class StackNode {
+  public:
+    StackNode(ScopedHTType &AvailableValues, LoadHTType &AvailableLoads,
+              CallHTType &AvailableCalls, unsigned cg, DomTreeNode *n,
+              DomTreeNode::iterator child, DomTreeNode::iterator end)
+        : CurrentGeneration(cg), ChildGeneration(cg), Node(n), ChildIter(child),
+          EndIter(end), Scopes(AvailableValues, AvailableLoads, AvailableCalls)
+          {}
+    StackNode(const StackNode &) = delete;
+    StackNode &operator=(const StackNode &) = delete;
+
+    // Accessors.
+    unsigned currentGeneration() { return CurrentGeneration; }
+    unsigned childGeneration() { return ChildGeneration; }
+    void childGeneration(unsigned generation) { ChildGeneration = generation; }
+    DomTreeNode *node() { return Node; }
+    DomTreeNode::iterator childIter() { return ChildIter; }
+
+    DomTreeNode *nextChild() {
+      DomTreeNode *child = *ChildIter;
+      ++ChildIter;
+      return child;
+    }
+
+    DomTreeNode::iterator end() { return EndIter; }
+    bool isProcessed() { return Processed; }
+    void process() { Processed = true; }
+
+  private:
+    unsigned CurrentGeneration;
+    unsigned ChildGeneration;
+    DomTreeNode *Node;
+    DomTreeNode::iterator ChildIter;
+    DomTreeNode::iterator EndIter;
+    NodeScope Scopes;
+    bool Processed = false;
+  };
+
+  /// \brief Wrapper class to handle memory instructions, including loads,
+  /// stores and intrinsic loads and stores defined by the target.
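+  /// (For a plain LoadInst or StoreInst the queries below fall through to
+  ///  the IR-level predicates; for a recognized target intrinsic, TTI fills
+  ///  in a MemIntrinsicInfo describing its pointer operand, ordering and
+  ///  read/write behaviour instead.)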
+ class ParseMemoryInst { + public: + ParseMemoryInst(Instruction *Inst, const TargetTransformInfo &TTI) + : Inst(Inst) { + if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) + if (TTI.getTgtMemIntrinsic(II, Info)) + IsTargetMemInst = true; + } + + bool isLoad() const { + if (IsTargetMemInst) return Info.ReadMem; + return isa<LoadInst>(Inst); + } + + bool isStore() const { + if (IsTargetMemInst) return Info.WriteMem; + return isa<StoreInst>(Inst); + } + + bool isAtomic() const { + if (IsTargetMemInst) + return Info.Ordering != AtomicOrdering::NotAtomic; + return Inst->isAtomic(); + } + + bool isUnordered() const { + if (IsTargetMemInst) + return Info.isUnordered(); + + if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { + return LI->isUnordered(); + } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + return SI->isUnordered(); + } + // Conservative answer + return !Inst->isAtomic(); + } + + bool isVolatile() const { + if (IsTargetMemInst) + return Info.IsVolatile; + + if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { + return LI->isVolatile(); + } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + return SI->isVolatile(); + } + // Conservative answer + return true; + } + + bool isInvariantLoad() const { + if (auto *LI = dyn_cast<LoadInst>(Inst)) + return LI->getMetadata(LLVMContext::MD_invariant_load) != nullptr; + return false; + } + + bool isMatchingMemLoc(const ParseMemoryInst &Inst) const { + return (getPointerOperand() == Inst.getPointerOperand() && + getMatchingId() == Inst.getMatchingId()); + } + + bool isValid() const { return getPointerOperand() != nullptr; } + + // For regular (non-intrinsic) loads/stores, this is set to -1. For + // intrinsic loads/stores, the id is retrieved from the corresponding + // field in the MemIntrinsicInfo structure. That field contains + // non-negative values only. + int getMatchingId() const { + if (IsTargetMemInst) return Info.MatchingId; + return -1; + } + + Value *getPointerOperand() const { + if (IsTargetMemInst) return Info.PtrVal; + if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) { + return LI->getPointerOperand(); + } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + return SI->getPointerOperand(); + } + return nullptr; + } + + bool mayReadFromMemory() const { + if (IsTargetMemInst) return Info.ReadMem; + return Inst->mayReadFromMemory(); + } + + bool mayWriteToMemory() const { + if (IsTargetMemInst) return Info.WriteMem; + return Inst->mayWriteToMemory(); + } + + private: + bool IsTargetMemInst = false; + MemIntrinsicInfo Info; + Instruction *Inst; + }; + + bool processNode(DomTreeNode *Node); + + Value *getOrCreateResult(Value *Inst, Type *ExpectedType) const { + if (auto *LI = dyn_cast<LoadInst>(Inst)) + return LI; + if (auto *SI = dyn_cast<StoreInst>(Inst)) + return SI->getValueOperand(); + assert(isa<IntrinsicInst>(Inst) && "Instruction not supported"); + return TTI.getOrCreateResultFromMemIntrinsic(cast<IntrinsicInst>(Inst), + ExpectedType); + } + + bool isSameMemGeneration(unsigned EarlierGeneration, unsigned LaterGeneration, + Instruction *EarlierInst, Instruction *LaterInst); + + void removeMSSA(Instruction *Inst) { + if (!MSSA) + return; + // Removing a store here can leave MemorySSA in an unoptimized state by + // creating MemoryPhis that have identical arguments and by creating + // MemoryUses whose defining access is not an actual clobber. We handle the + // phi case eagerly here. The non-optimized MemoryUse case is lazily + // updated by MemorySSA getClobberingMemoryAccess. 
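+    // (For example, removing the only MemoryDef feeding one arm of a diamond
+    //  can leave a MemoryPhi whose incoming values are all identical; the
+    //  worklist loop below detects such phis and removes them too.)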
+ if (MemoryAccess *MA = MSSA->getMemoryAccess(Inst)) { + // Optimize MemoryPhi nodes that may become redundant by having all the + // same input values once MA is removed. + SmallSetVector<MemoryPhi *, 4> PhisToCheck; + SmallVector<MemoryAccess *, 8> WorkQueue; + WorkQueue.push_back(MA); + // Process MemoryPhi nodes in FIFO order using a ever-growing vector since + // we shouldn't be processing that many phis and this will avoid an + // allocation in almost all cases. + for (unsigned I = 0; I < WorkQueue.size(); ++I) { + MemoryAccess *WI = WorkQueue[I]; + + for (auto *U : WI->users()) + if (MemoryPhi *MP = dyn_cast<MemoryPhi>(U)) + PhisToCheck.insert(MP); + + MSSAUpdater->removeMemoryAccess(WI); + + for (MemoryPhi *MP : PhisToCheck) { + MemoryAccess *FirstIn = MP->getIncomingValue(0); + if (llvm::all_of(MP->incoming_values(), + [=](Use &In) { return In == FirstIn; })) + WorkQueue.push_back(MP); + } + PhisToCheck.clear(); + } + } + } +}; + +} // end anonymous namespace + +/// Determine if the memory referenced by LaterInst is from the same heap +/// version as EarlierInst. +/// This is currently called in two scenarios: +/// +/// load p +/// ... +/// load p +/// +/// and +/// +/// x = load p +/// ... +/// store x, p +/// +/// in both cases we want to verify that there are no possible writes to the +/// memory referenced by p between the earlier and later instruction. +bool EarlyCSE::isSameMemGeneration(unsigned EarlierGeneration, + unsigned LaterGeneration, + Instruction *EarlierInst, + Instruction *LaterInst) { + // Check the simple memory generation tracking first. + if (EarlierGeneration == LaterGeneration) + return true; + + if (!MSSA) + return false; + + // If MemorySSA has determined that one of EarlierInst or LaterInst does not + // read/write memory, then we can safely return true here. + // FIXME: We could be more aggressive when checking doesNotAccessMemory(), + // onlyReadsMemory(), mayReadFromMemory(), and mayWriteToMemory() in this pass + // by also checking the MemorySSA MemoryAccess on the instruction. Initial + // experiments suggest this isn't worthwhile, at least for C/C++ code compiled + // with the default optimization pipeline. + auto *EarlierMA = MSSA->getMemoryAccess(EarlierInst); + if (!EarlierMA) + return true; + auto *LaterMA = MSSA->getMemoryAccess(LaterInst); + if (!LaterMA) + return true; + + // Since we know LaterDef dominates LaterInst and EarlierInst dominates + // LaterInst, if LaterDef dominates EarlierInst then it can't occur between + // EarlierInst and LaterInst and neither can any other write that potentially + // clobbers LaterInst. + MemoryAccess *LaterDef = + MSSA->getWalker()->getClobberingMemoryAccess(LaterInst); + return MSSA->dominates(LaterDef, EarlierMA); +} + +bool EarlyCSE::processNode(DomTreeNode *Node) { + bool Changed = false; + BasicBlock *BB = Node->getBlock(); + + // If this block has a single predecessor, then the predecessor is the parent + // of the domtree node and all of the live out memory values are still current + // in this block. If this block has multiple predecessors, then they could + // have invalidated the live-out memory values of our parent value. For now, + // just be conservative and invalidate memory if this block has multiple + // predecessors. + if (!BB->getSinglePredecessor()) + ++CurrentGeneration; + + // If this node has a single predecessor which ends in a conditional branch, + // we can infer the value of the branch condition given that we took this + // path. 
We need the single predecessor to ensure there's not another path
+  // which reaches this block where the condition might hold a different
+  // value. Since we're adding this to the scoped hash table (like any other
+  // def), it will have been popped if we encounter a future merge block.
+  if (BasicBlock *Pred = BB->getSinglePredecessor()) {
+    auto *BI = dyn_cast<BranchInst>(Pred->getTerminator());
+    if (BI && BI->isConditional()) {
+      auto *CondInst = dyn_cast<Instruction>(BI->getCondition());
+      if (CondInst && SimpleValue::canHandle(CondInst)) {
+        assert(BI->getSuccessor(0) == BB || BI->getSuccessor(1) == BB);
+        auto *TorF = (BI->getSuccessor(0) == BB)
+                         ? ConstantInt::getTrue(BB->getContext())
+                         : ConstantInt::getFalse(BB->getContext());
+        AvailableValues.insert(CondInst, TorF);
+        DEBUG(dbgs() << "EarlyCSE CVP: Add conditional value for '"
+                     << CondInst->getName() << "' as " << *TorF << " in "
+                     << BB->getName() << "\n");
+        // Replace all dominated uses with the known value.
+        if (unsigned Count = replaceDominatedUsesWith(
+                CondInst, TorF, DT, BasicBlockEdge(Pred, BB))) {
+          Changed = true;
+          NumCSECVP += Count;
+        }
+      }
+    }
+  }
+
+  /// LastStore - Keep track of the last non-volatile store that we saw... for
+  /// as long as there is no instruction that reads memory. If we see a store
+  /// to the same location, we delete the dead store. This zaps trivial dead
+  /// stores which can occur in bitfield code among other things.
+  Instruction *LastStore = nullptr;
+
+  // See if any instructions in the block can be eliminated. If so, do it. If
+  // not, add them to AvailableValues.
+  for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) {
+    Instruction *Inst = &*I++;
+
+    // Dead instructions should just be removed.
+    if (isInstructionTriviallyDead(Inst, &TLI)) {
+      DEBUG(dbgs() << "EarlyCSE DCE: " << *Inst << '\n');
+      removeMSSA(Inst);
+      Inst->eraseFromParent();
+      Changed = true;
+      ++NumSimplify;
+      continue;
+    }
+
+    // Skip assume intrinsics; they don't really have side effects (although
+    // they're marked as such to ensure preservation of control dependencies),
+    // and this pass will not bother with their removal. However, we should
+    // mark the condition as true for all dominated blocks.
+    if (match(Inst, m_Intrinsic<Intrinsic::assume>())) {
+      auto *CondI =
+          dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0));
+      if (CondI && SimpleValue::canHandle(CondI)) {
+        DEBUG(dbgs() << "EarlyCSE considering assumption: " << *Inst << '\n');
+        AvailableValues.insert(CondI, ConstantInt::getTrue(BB->getContext()));
+      } else
+        DEBUG(dbgs() << "EarlyCSE skipping assumption: " << *Inst << '\n');
+      continue;
+    }
+
+    // Skip sideeffect intrinsics, for the same reason as assume intrinsics.
+    if (match(Inst, m_Intrinsic<Intrinsic::sideeffect>())) {
+      DEBUG(dbgs() << "EarlyCSE skipping sideeffect: " << *Inst << '\n');
+      continue;
+    }
+
+    // Skip invariant.start intrinsics since they only read memory, and we can
+    // forward values across them. Also, we don't need to consume the last
+    // store since the semantics of invariant.start allow us to perform DSE of
+    // the last store, if there was a store following invariant.start.
+    // Consider:
+    //   store 30, i8* p
+    //   invariant.start(p)
+    //   store 40, i8* p
+    // We can DSE the store to 30, since the store 40 to invariant location p
+    // causes undefined behaviour.
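+    // In IR form the example reads roughly (intrinsic signature abbreviated):
+    //   store i8 30, i8* %p
+    //   %i = call {}* @llvm.invariant.start.p0i8(i64 1, i8* %p)
+    //   store i8 40, i8* %p   ; makes the store of 30 removable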
+    if (match(Inst, m_Intrinsic<Intrinsic::invariant_start>()))
+      continue;
+
+    if (match(Inst, m_Intrinsic<Intrinsic::experimental_guard>())) {
+      if (auto *CondI =
+              dyn_cast<Instruction>(cast<CallInst>(Inst)->getArgOperand(0))) {
+        if (SimpleValue::canHandle(CondI)) {
+          // Do we already know the actual value of this condition?
+          if (auto *KnownCond = AvailableValues.lookup(CondI)) {
+            // Is the condition known to be true?
+            if (isa<ConstantInt>(KnownCond) &&
+                cast<ConstantInt>(KnownCond)->isOne()) {
+              DEBUG(dbgs() << "EarlyCSE removing guard: " << *Inst << '\n');
+              removeMSSA(Inst);
+              Inst->eraseFromParent();
+              Changed = true;
+              continue;
+            } else
+              // Use the known value if it wasn't true.
+              cast<CallInst>(Inst)->setArgOperand(0, KnownCond);
+          }
+          // The condition we're guarding on here is true for all dominated
+          // locations.
+          AvailableValues.insert(CondI, ConstantInt::getTrue(BB->getContext()));
+        }
+      }
+
+      // Guard intrinsics read all memory, but don't write any memory.
+      // Accordingly, don't update the generation but consume the last store
+      // (to avoid an incorrect DSE).
+      LastStore = nullptr;
+      continue;
+    }
+
+    // If the instruction can be simplified (e.g. X+0 = X) then replace it with
+    // its simpler value.
+    if (Value *V = SimplifyInstruction(Inst, SQ)) {
+      DEBUG(dbgs() << "EarlyCSE Simplify: " << *Inst << " to: " << *V << '\n');
+      bool Killed = false;
+      if (!Inst->use_empty()) {
+        Inst->replaceAllUsesWith(V);
+        Changed = true;
+      }
+      if (isInstructionTriviallyDead(Inst, &TLI)) {
+        removeMSSA(Inst);
+        Inst->eraseFromParent();
+        Changed = true;
+        Killed = true;
+      }
+      if (Changed)
+        ++NumSimplify;
+      if (Killed)
+        continue;
+    }
+
+    // If this is a simple instruction that we can value number, process it.
+    if (SimpleValue::canHandle(Inst)) {
+      // See if the instruction has an available value. If so, use it.
+      if (Value *V = AvailableValues.lookup(Inst)) {
+        DEBUG(dbgs() << "EarlyCSE CSE: " << *Inst << " to: " << *V << '\n');
+        if (auto *I = dyn_cast<Instruction>(V))
+          I->andIRFlags(Inst);
+        Inst->replaceAllUsesWith(V);
+        removeMSSA(Inst);
+        Inst->eraseFromParent();
+        Changed = true;
+        ++NumCSE;
+        continue;
+      }
+
+      // Otherwise, just remember that this value is available.
+      AvailableValues.insert(Inst, Inst);
+      continue;
+    }
+
+    ParseMemoryInst MemInst(Inst, TTI);
+    // If this is a non-volatile load, process it.
+    if (MemInst.isValid() && MemInst.isLoad()) {
+      // (Conservatively) we can't peek past the ordering implied by this
+      // operation, but we can add this load to our set of available values.
+      if (MemInst.isVolatile() || !MemInst.isUnordered()) {
+        LastStore = nullptr;
+        ++CurrentGeneration;
+      }
+
+      // If we have an available version of this load, and if it is the right
+      // generation or the load is known to be from an invariant location,
+      // replace this instruction.
+      //
+      // If either the dominating load or the current load are invariant, then
+      // we can assume the current load loads the same value as the dominating
+      // load.
+      LoadValue InVal = AvailableLoads.lookup(MemInst.getPointerOperand());
+      if (InVal.DefInst != nullptr &&
+          InVal.MatchingId == MemInst.getMatchingId() &&
+          // We don't yet handle removing loads with ordering of any kind.
+          !MemInst.isVolatile() && MemInst.isUnordered() &&
+          // We can't replace an atomic load with one which isn't also atomic.
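+          // (The '>=' below works because bools compare as 0/1: an atomic
+          //  available value may satisfy a non-atomic load, but a non-atomic
+          //  value may not satisfy an atomic load.)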
+ InVal.IsAtomic >= MemInst.isAtomic() && + (InVal.IsInvariant || MemInst.isInvariantLoad() || + isSameMemGeneration(InVal.Generation, CurrentGeneration, + InVal.DefInst, Inst))) { + Value *Op = getOrCreateResult(InVal.DefInst, Inst->getType()); + if (Op != nullptr) { + DEBUG(dbgs() << "EarlyCSE CSE LOAD: " << *Inst + << " to: " << *InVal.DefInst << '\n'); + if (!Inst->use_empty()) + Inst->replaceAllUsesWith(Op); + removeMSSA(Inst); + Inst->eraseFromParent(); + Changed = true; + ++NumCSELoad; + continue; + } + } + + // Otherwise, remember that we have this instruction. + AvailableLoads.insert( + MemInst.getPointerOperand(), + LoadValue(Inst, CurrentGeneration, MemInst.getMatchingId(), + MemInst.isAtomic(), MemInst.isInvariantLoad())); + LastStore = nullptr; + continue; + } + + // If this instruction may read from memory or throw (and potentially read + // from memory in the exception handler), forget LastStore. Load/store + // intrinsics will indicate both a read and a write to memory. The target + // may override this (e.g. so that a store intrinsic does not read from + // memory, and thus will be treated the same as a regular store for + // commoning purposes). + if ((Inst->mayReadFromMemory() || Inst->mayThrow()) && + !(MemInst.isValid() && !MemInst.mayReadFromMemory())) + LastStore = nullptr; + + // If this is a read-only call, process it. + if (CallValue::canHandle(Inst)) { + // If we have an available version of this call, and if it is the right + // generation, replace this instruction. + std::pair<Instruction *, unsigned> InVal = AvailableCalls.lookup(Inst); + if (InVal.first != nullptr && + isSameMemGeneration(InVal.second, CurrentGeneration, InVal.first, + Inst)) { + DEBUG(dbgs() << "EarlyCSE CSE CALL: " << *Inst + << " to: " << *InVal.first << '\n'); + if (!Inst->use_empty()) + Inst->replaceAllUsesWith(InVal.first); + removeMSSA(Inst); + Inst->eraseFromParent(); + Changed = true; + ++NumCSECall; + continue; + } + + // Otherwise, remember that we have this instruction. + AvailableCalls.insert( + Inst, std::pair<Instruction *, unsigned>(Inst, CurrentGeneration)); + continue; + } + + // A release fence requires that all stores complete before it, but does + // not prevent the reordering of following loads 'before' the fence. As a + // result, we don't need to consider it as writing to memory and don't need + // to advance the generation. We do need to prevent DSE across the fence, + // but that's handled above. + if (FenceInst *FI = dyn_cast<FenceInst>(Inst)) + if (FI->getOrdering() == AtomicOrdering::Release) { + assert(Inst->mayReadFromMemory() && "relied on to prevent DSE above"); + continue; + } + + // write back DSE - If we write back the same value we just loaded from + // the same location and haven't passed any intervening writes or ordering + // operations, we can remove the write. The primary benefit is in allowing + // the available load table to remain valid and value forward past where + // the store originally was. + if (MemInst.isValid() && MemInst.isStore()) { + LoadValue InVal = AvailableLoads.lookup(MemInst.getPointerOperand()); + if (InVal.DefInst && + InVal.DefInst == getOrCreateResult(Inst, InVal.DefInst->getType()) && + InVal.MatchingId == MemInst.getMatchingId() && + // We don't yet handle removing stores with ordering of any kind. 
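+          // (Write-back example: '%v = load i32, i32* %p' followed, with no
+          //  intervening clobber, by 'store i32 %v, i32* %p'; the store is a
+          //  no-op, so it is deleted without bumping the generation count.)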
+          !MemInst.isVolatile() && MemInst.isUnordered() &&
+          isSameMemGeneration(InVal.Generation, CurrentGeneration,
+                              InVal.DefInst, Inst)) {
+        // It is okay to have a LastStore to a different pointer here if
+        // MemorySSA tells us that the load and store are from the same memory
+        // generation. In that case, LastStore should keep its present value
+        // since we're removing the current store.
+        assert((!LastStore ||
+                ParseMemoryInst(LastStore, TTI).getPointerOperand() ==
+                    MemInst.getPointerOperand() ||
+                MSSA) &&
+               "can't have an intervening store if not using MemorySSA!");
+        DEBUG(dbgs() << "EarlyCSE DSE (writeback): " << *Inst << '\n');
+        removeMSSA(Inst);
+        Inst->eraseFromParent();
+        Changed = true;
+        ++NumDSE;
+        // We can avoid incrementing the generation count since we were able
+        // to eliminate this store.
+        continue;
+      }
+    }
+
+    // Okay, this isn't something we can CSE at all. Check to see if it is
+    // something that could modify memory. If so, our available memory values
+    // cannot be used so bump the generation count.
+    if (Inst->mayWriteToMemory()) {
+      ++CurrentGeneration;
+
+      if (MemInst.isValid() && MemInst.isStore()) {
+        // We do a trivial form of DSE if there are two stores to the same
+        // location with no intervening loads. Delete the earlier store.
+        // At the moment, we don't remove ordered stores, but do remove
+        // unordered atomic stores. There's no special requirement (for
+        // unordered atomics) about removing atomic stores only in favor of
+        // other atomic stores since we're going to execute the non-atomic
+        // one anyway and the atomic one might never have become visible.
+        if (LastStore) {
+          ParseMemoryInst LastStoreMemInst(LastStore, TTI);
+          assert(LastStoreMemInst.isUnordered() &&
+                 !LastStoreMemInst.isVolatile() &&
+                 "Violated invariant");
+          if (LastStoreMemInst.isMatchingMemLoc(MemInst)) {
+            DEBUG(dbgs() << "EarlyCSE DEAD STORE: " << *LastStore
+                         << " due to: " << *Inst << '\n');
+            removeMSSA(LastStore);
+            LastStore->eraseFromParent();
+            Changed = true;
+            ++NumDSE;
+            LastStore = nullptr;
+          }
+          // fallthrough - we can exploit information about this store
+        }
+
+        // Okay, we just invalidated anything we knew about loaded values. Try
+        // to salvage *something* by remembering that the stored value is a
+        // live version of the pointer. It is safe to forward from volatile
+        // stores to non-volatile loads, so we don't have to check for
+        // volatility of the store.
+        AvailableLoads.insert(
+            MemInst.getPointerOperand(),
+            LoadValue(Inst, CurrentGeneration, MemInst.getMatchingId(),
+                      MemInst.isAtomic(), /*IsInvariant=*/false));
+
+        // Remember that this was the last unordered store we saw for DSE. We
+        // don't yet handle DSE on ordered or volatile stores since we don't
+        // have a good way to model the ordering requirement for following
+        // passes once the store is removed. We could insert a fence, but
+        // since fences are slightly stronger than stores in their ordering,
+        // it's not clear this is a profitable transform. Another option would
+        // be to merge the ordering with that of the post dominating store.
+        if (MemInst.isUnordered() && !MemInst.isVolatile())
+          LastStore = Inst;
+        else
+          LastStore = nullptr;
+      }
+    }
+  }
+
+  return Changed;
+}
+
+bool EarlyCSE::run() {
+  // Note, deque is being used here because there are significant performance
+  // gains over vector when the container becomes very large due to the
+  // specific access patterns. 
For more information see the mailing list + // discussion on this: + // http://lists.llvm.org/pipermail/llvm-commits/Week-of-Mon-20120116/135228.html + std::deque<StackNode *> nodesToProcess; + + bool Changed = false; + + // Process the root node. + nodesToProcess.push_back(new StackNode( + AvailableValues, AvailableLoads, AvailableCalls, CurrentGeneration, + DT.getRootNode(), DT.getRootNode()->begin(), DT.getRootNode()->end())); + + // Save the current generation. + unsigned LiveOutGeneration = CurrentGeneration; + + // Process the stack. + while (!nodesToProcess.empty()) { + // Grab the first item off the stack. Set the current generation, remove + // the node from the stack, and process it. + StackNode *NodeToProcess = nodesToProcess.back(); + + // Initialize class members. + CurrentGeneration = NodeToProcess->currentGeneration(); + + // Check if the node needs to be processed. + if (!NodeToProcess->isProcessed()) { + // Process the node. + Changed |= processNode(NodeToProcess->node()); + NodeToProcess->childGeneration(CurrentGeneration); + NodeToProcess->process(); + } else if (NodeToProcess->childIter() != NodeToProcess->end()) { + // Push the next child onto the stack. + DomTreeNode *child = NodeToProcess->nextChild(); + nodesToProcess.push_back( + new StackNode(AvailableValues, AvailableLoads, AvailableCalls, + NodeToProcess->childGeneration(), child, child->begin(), + child->end())); + } else { + // It has been processed, and there are no more children to process, + // so delete it and pop it off the stack. + delete NodeToProcess; + nodesToProcess.pop_back(); + } + } // while (!nodes...) + + // Reset the current generation. + CurrentGeneration = LiveOutGeneration; + + return Changed; +} + +PreservedAnalyses EarlyCSEPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto *MSSA = + UseMemorySSA ? &AM.getResult<MemorySSAAnalysis>(F).getMSSA() : nullptr; + + EarlyCSE CSE(F.getParent()->getDataLayout(), TLI, TTI, DT, AC, MSSA); + + if (!CSE.run()) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + PA.preserve<GlobalsAA>(); + if (UseMemorySSA) + PA.preserve<MemorySSAAnalysis>(); + return PA; +} + +namespace { + +/// \brief A simple and fast domtree-based CSE pass. +/// +/// This pass does a simple depth-first walk over the dominator tree, +/// eliminating trivially redundant instructions and using instsimplify to +/// canonicalize things as it goes. It is intended to be fast and catch obvious +/// cases so that instcombine and other passes are more effective. It is +/// expected that a later pass of GVN will catch the interesting/hard cases. 
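+/// (This template is instantiated twice below: with UseMemorySSA == false as
+/// the 'early-cse' pass and with UseMemorySSA == true as 'early-cse-memssa'.)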
+template<bool UseMemorySSA> +class EarlyCSELegacyCommonPass : public FunctionPass { +public: + static char ID; + + EarlyCSELegacyCommonPass() : FunctionPass(ID) { + if (UseMemorySSA) + initializeEarlyCSEMemSSALegacyPassPass(*PassRegistry::getPassRegistry()); + else + initializeEarlyCSELegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto *MSSA = + UseMemorySSA ? &getAnalysis<MemorySSAWrapperPass>().getMSSA() : nullptr; + + EarlyCSE CSE(F.getParent()->getDataLayout(), TLI, TTI, DT, AC, MSSA); + + return CSE.run(); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + if (UseMemorySSA) { + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); + } + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.setPreservesCFG(); + } +}; + +} // end anonymous namespace + +using EarlyCSELegacyPass = EarlyCSELegacyCommonPass</*UseMemorySSA=*/false>; + +template<> +char EarlyCSELegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(EarlyCSELegacyPass, "early-cse", "Early CSE", false, + false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(EarlyCSELegacyPass, "early-cse", "Early CSE", false, false) + +using EarlyCSEMemSSALegacyPass = + EarlyCSELegacyCommonPass</*UseMemorySSA=*/true>; + +template<> +char EarlyCSEMemSSALegacyPass::ID = 0; + +FunctionPass *llvm::createEarlyCSEPass(bool UseMemorySSA) { + if (UseMemorySSA) + return new EarlyCSEMemSSALegacyPass(); + else + return new EarlyCSELegacyPass(); +} + +INITIALIZE_PASS_BEGIN(EarlyCSEMemSSALegacyPass, "early-cse-memssa", + "Early CSE w/ MemorySSA", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) +INITIALIZE_PASS_END(EarlyCSEMemSSALegacyPass, "early-cse-memssa", + "Early CSE w/ MemorySSA", false, false) diff --git a/contrib/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp b/contrib/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp new file mode 100644 index 000000000000..063df779a30b --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/FlattenCFGPass.cpp @@ -0,0 +1,80 @@ +//===- FlattenCFGPass.cpp - CFG Flatten Pass ----------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements flattening of CFG. 
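+// (It repeatedly runs the FlattenCFG utility over every block, which folds
+// nested conditional branches into a single block where it can prove that is
+// safe, and then deletes any blocks left unreachable.)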
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/IR/CFG.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +using namespace llvm; + +#define DEBUG_TYPE "flattencfg" + +namespace { +struct FlattenCFGPass : public FunctionPass { + static char ID; // Pass identification, replacement for typeid +public: + FlattenCFGPass() : FunctionPass(ID) { + initializeFlattenCFGPassPass(*PassRegistry::getPassRegistry()); + } + bool runOnFunction(Function &F) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AAResultsWrapperPass>(); + } + +private: + AliasAnalysis *AA; +}; +} + +char FlattenCFGPass::ID = 0; +INITIALIZE_PASS_BEGIN(FlattenCFGPass, "flattencfg", "Flatten the CFG", false, + false) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(FlattenCFGPass, "flattencfg", "Flatten the CFG", false, + false) + +// Public interface to the FlattenCFG pass +FunctionPass *llvm::createFlattenCFGPass() { return new FlattenCFGPass(); } + +/// iterativelyFlattenCFG - Call FlattenCFG on all the blocks in the function, +/// iterating until no more changes are made. +static bool iterativelyFlattenCFG(Function &F, AliasAnalysis *AA) { + bool Changed = false; + bool LocalChange = true; + while (LocalChange) { + LocalChange = false; + + // Loop over all of the basic blocks and remove them if they are unneeded... + // + for (Function::iterator BBIt = F.begin(); BBIt != F.end();) { + if (FlattenCFG(&*BBIt++, AA)) { + LocalChange = true; + } + } + Changed |= LocalChange; + } + return Changed; +} + +bool FlattenCFGPass::runOnFunction(Function &F) { + AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + bool EverChanged = false; + // iterativelyFlattenCFG can make some blocks dead. + while (iterativelyFlattenCFG(F, AA)) { + removeUnreachableBlocks(F); + EverChanged = true; + } + return EverChanged; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/Float2Int.cpp b/contrib/llvm/lib/Transforms/Scalar/Float2Int.cpp new file mode 100644 index 000000000000..b105ece8dc7c --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/Float2Int.cpp @@ -0,0 +1,525 @@ +//===- Float2Int.cpp - Demote floating point ops to work on integers ------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Float2Int pass, which aims to demote floating +// point operations to work on integers, where that is losslessly possible. +// +//===----------------------------------------------------------------------===// + +#define DEBUG_TYPE "float2int" + +#include "llvm/Transforms/Scalar/Float2Int.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include <deque> +#include <functional> // For std::function +using namespace llvm; + +// The algorithm is simple. 
Start at instructions that convert from the
+// float to the int domain: fptoui, fptosi and fcmp. Walk up the def-use
+// graph, using an equivalence data structure to unify graphs that interfere.
+//
+// Mappable instructions are those with an integer corollary that, given
+// integer domain inputs, produce an integer output; fadd, for example.
+//
+// If a non-mappable instruction is seen, this entire def-use graph is marked
+// as non-transformable. If we see an instruction that converts from the
+// integer domain to FP domain (uitofp, sitofp), we terminate our walk.
+
+/// The largest integer type worth dealing with.
+static cl::opt<unsigned>
+MaxIntegerBW("float2int-max-integer-bw", cl::init(64), cl::Hidden,
+             cl::desc("Max integer bitwidth to consider in float2int "
+                      "(default=64)"));
+
+namespace {
+  struct Float2IntLegacyPass : public FunctionPass {
+    static char ID; // Pass identification, replacement for typeid
+    Float2IntLegacyPass() : FunctionPass(ID) {
+      initializeFloat2IntLegacyPassPass(*PassRegistry::getPassRegistry());
+    }
+
+    bool runOnFunction(Function &F) override {
+      if (skipFunction(F))
+        return false;
+
+      return Impl.runImpl(F);
+    }
+
+    void getAnalysisUsage(AnalysisUsage &AU) const override {
+      AU.setPreservesCFG();
+      AU.addPreserved<GlobalsAAWrapperPass>();
+    }
+
+  private:
+    Float2IntPass Impl;
+  };
+}
+
+char Float2IntLegacyPass::ID = 0;
+INITIALIZE_PASS(Float2IntLegacyPass, "float2int", "Float to int", false, false)
+
+// Given a FCmp predicate, return a matching ICmp predicate if one
+// exists, otherwise return BAD_ICMP_PREDICATE.
+static CmpInst::Predicate mapFCmpPred(CmpInst::Predicate P) {
+  switch (P) {
+  case CmpInst::FCMP_OEQ:
+  case CmpInst::FCMP_UEQ:
+    return CmpInst::ICMP_EQ;
+  case CmpInst::FCMP_OGT:
+  case CmpInst::FCMP_UGT:
+    return CmpInst::ICMP_SGT;
+  case CmpInst::FCMP_OGE:
+  case CmpInst::FCMP_UGE:
+    return CmpInst::ICMP_SGE;
+  case CmpInst::FCMP_OLT:
+  case CmpInst::FCMP_ULT:
+    return CmpInst::ICMP_SLT;
+  case CmpInst::FCMP_OLE:
+  case CmpInst::FCMP_ULE:
+    return CmpInst::ICMP_SLE;
+  case CmpInst::FCMP_ONE:
+  case CmpInst::FCMP_UNE:
+    return CmpInst::ICMP_NE;
+  default:
+    return CmpInst::BAD_ICMP_PREDICATE;
+  }
+}
+
+// Given a floating point binary operator, return the matching
+// integer version.
+static Instruction::BinaryOps mapBinOpcode(unsigned Opcode) {
+  switch (Opcode) {
+  default: llvm_unreachable("Unhandled opcode!");
+  case Instruction::FAdd: return Instruction::Add;
+  case Instruction::FSub: return Instruction::Sub;
+  case Instruction::FMul: return Instruction::Mul;
+  }
+}
+
+// Find the roots - instructions that convert from the FP domain to
+// integer domain.
+void Float2IntPass::findRoots(Function &F, SmallPtrSet<Instruction*,8> &Roots) {
+  for (auto &I : instructions(F)) {
+    if (isa<VectorType>(I.getType()))
+      continue;
+    switch (I.getOpcode()) {
+    default: break;
+    case Instruction::FPToUI:
+    case Instruction::FPToSI:
+      Roots.insert(&I);
+      break;
+    case Instruction::FCmp:
+      if (mapFCmpPred(cast<CmpInst>(&I)->getPredicate()) !=
+          CmpInst::BAD_ICMP_PREDICATE)
+        Roots.insert(&I);
+      break;
+    }
+  }
+}
+
+// Helper - mark I as having been traversed, having range R.
+void Float2IntPass::seen(Instruction *I, ConstantRange R) {
+  DEBUG(dbgs() << "F2I: " << *I << ":" << R << "\n");
+  auto IT = SeenInsts.find(I);
+  if (IT != SeenInsts.end())
+    IT->second = std::move(R);
+  else
+    SeenInsts.insert(std::make_pair(I, std::move(R)));
+}
+
+// Helper - get a range representing a poison value. 
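+// (badRange() is the full set at MaxIntegerBW + 1 bits, used as a "poison"
+//  marker; unknownRange() is the empty set at the same width, meaning "not
+//  yet computed"; validateRange() demotes anything wider than
+//  MaxIntegerBW + 1 bits to badRange().)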
+ConstantRange Float2IntPass::badRange() {
+  return ConstantRange(MaxIntegerBW + 1, true);
+}
+ConstantRange Float2IntPass::unknownRange() {
+  return ConstantRange(MaxIntegerBW + 1, false);
+}
+ConstantRange Float2IntPass::validateRange(ConstantRange R) {
+  if (R.getBitWidth() > MaxIntegerBW + 1)
+    return badRange();
+  return R;
+}
+
+// The most obvious way to structure the search is a depth-first, eager
+// search from each root. However, that requires direct recursion and so
+// can only handle small instruction sequences. Instead, we split the search
+// up into two phases:
+// - walkBackwards:  A breadth-first walk of the use-def graph starting from
+//                   the roots. Populate "SeenInsts" with interesting
+//                   instructions and poison values if they're obvious and
+//                   cheap to compute. Calculate the equivalence set structure
+//                   while we're here too.
+// - walkForwards:  Iterate over SeenInsts in reverse order, so we visit
+//                  defs before their uses. Calculate the real range info.
+
+// Breadth-first walk of the use-def graph; determine the set of nodes
+// we care about and eagerly determine if some of them are poisonous.
+void Float2IntPass::walkBackwards(const SmallPtrSetImpl<Instruction*> &Roots) {
+  std::deque<Instruction*> Worklist(Roots.begin(), Roots.end());
+  while (!Worklist.empty()) {
+    Instruction *I = Worklist.back();
+    Worklist.pop_back();
+
+    if (SeenInsts.find(I) != SeenInsts.end())
+      // Seen already.
+      continue;
+
+    switch (I->getOpcode()) {
+    // FIXME: Handle select and phi nodes.
+    default:
+      // Path terminated uncleanly.
+      seen(I, badRange());
+      break;
+
+    case Instruction::UIToFP:
+    case Instruction::SIToFP: {
+      // Path terminated cleanly - use the type of the integer input to seed
+      // the analysis.
+      unsigned BW = I->getOperand(0)->getType()->getPrimitiveSizeInBits();
+      auto Input = ConstantRange(BW, true);
+      auto CastOp = (Instruction::CastOps)I->getOpcode();
+      seen(I, validateRange(Input.castOp(CastOp, MaxIntegerBW+1)));
+      continue;
+    }
+
+    case Instruction::FAdd:
+    case Instruction::FSub:
+    case Instruction::FMul:
+    case Instruction::FPToUI:
+    case Instruction::FPToSI:
+    case Instruction::FCmp:
+      seen(I, unknownRange());
+      break;
+    }
+
+    for (Value *O : I->operands()) {
+      if (Instruction *OI = dyn_cast<Instruction>(O)) {
+        // Unify def-use chains if they interfere.
+        ECs.unionSets(I, OI);
+        if (SeenInsts.find(I)->second != badRange())
+          Worklist.push_back(OI);
+      } else if (!isa<ConstantFP>(O)) {
+        // Not an instruction or ConstantFP? We can't do anything.
+        seen(I, badRange());
+      }
+    }
+  }
+}
+
+// Walk forwards down the list of seen instructions, so we visit defs before
+// uses.
+void Float2IntPass::walkForwards() {
+  for (auto &It : reverse(SeenInsts)) {
+    if (It.second != unknownRange())
+      continue;
+
+    Instruction *I = It.first;
+    std::function<ConstantRange(ArrayRef<ConstantRange>)> Op;
+    switch (I->getOpcode()) {
+    // FIXME: Handle select and phi nodes.
+    default:
+    case Instruction::UIToFP:
+    case Instruction::SIToFP:
+      llvm_unreachable("Should have been handled in walkBackwards!");
+
+    case Instruction::FAdd:
+    case Instruction::FSub:
+    case Instruction::FMul:
+      Op = [I](ArrayRef<ConstantRange> Ops) {
+        assert(Ops.size() == 2 && "it's a binary operator!");
+        auto BinOp = (Instruction::BinaryOps) I->getOpcode();
+        return Ops[0].binaryOp(BinOp, Ops[1]);
+      };
+      break;
+
+    //
+    // Root-only instructions - we'll only see these if they're the
+    // first node in a walk. 
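+    // (For instance, a root 'fptosi double %d to i32' has its operand range
+    //  propagated at MaxIntegerBW + 1 bits here; the root's own result width
+    //  only matters later, when validateAndTransform() picks the final
+    //  integer type.)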
+ // + case Instruction::FPToUI: + case Instruction::FPToSI: + Op = [I](ArrayRef<ConstantRange> Ops) { + assert(Ops.size() == 1 && "FPTo[US]I is a unary operator!"); + // Note: We're ignoring the casts output size here as that's what the + // caller expects. + auto CastOp = (Instruction::CastOps)I->getOpcode(); + return Ops[0].castOp(CastOp, MaxIntegerBW+1); + }; + break; + + case Instruction::FCmp: + Op = [](ArrayRef<ConstantRange> Ops) { + assert(Ops.size() == 2 && "FCmp is a binary operator!"); + return Ops[0].unionWith(Ops[1]); + }; + break; + } + + bool Abort = false; + SmallVector<ConstantRange,4> OpRanges; + for (Value *O : I->operands()) { + if (Instruction *OI = dyn_cast<Instruction>(O)) { + assert(SeenInsts.find(OI) != SeenInsts.end() && + "def not seen before use!"); + OpRanges.push_back(SeenInsts.find(OI)->second); + } else if (ConstantFP *CF = dyn_cast<ConstantFP>(O)) { + // Work out if the floating point number can be losslessly represented + // as an integer. + // APFloat::convertToInteger(&Exact) purports to do what we want, but + // the exactness can be too precise. For example, negative zero can + // never be exactly converted to an integer. + // + // Instead, we ask APFloat to round itself to an integral value - this + // preserves sign-of-zero - then compare the result with the original. + // + const APFloat &F = CF->getValueAPF(); + + // First, weed out obviously incorrect values. Non-finite numbers + // can't be represented and neither can negative zero, unless + // we're in fast math mode. + if (!F.isFinite() || + (F.isZero() && F.isNegative() && isa<FPMathOperator>(I) && + !I->hasNoSignedZeros())) { + seen(I, badRange()); + Abort = true; + break; + } + + APFloat NewF = F; + auto Res = NewF.roundToIntegral(APFloat::rmNearestTiesToEven); + if (Res != APFloat::opOK || NewF.compare(F) != APFloat::cmpEqual) { + seen(I, badRange()); + Abort = true; + break; + } + // OK, it's representable. Now get it. + APSInt Int(MaxIntegerBW+1, false); + bool Exact; + CF->getValueAPF().convertToInteger(Int, + APFloat::rmNearestTiesToEven, + &Exact); + OpRanges.push_back(ConstantRange(Int)); + } else { + llvm_unreachable("Should have already marked this as badRange!"); + } + } + + // Reduce the operands' ranges to a single range and return. + if (!Abort) + seen(I, Op(OpRanges)); + } +} + +// If there is a valid transform to be done, do it. +bool Float2IntPass::validateAndTransform() { + bool MadeChange = false; + + // Iterate over every disjoint partition of the def-use graph. + for (auto It = ECs.begin(), E = ECs.end(); It != E; ++It) { + ConstantRange R(MaxIntegerBW + 1, false); + bool Fail = false; + Type *ConvertedToTy = nullptr; + + // For every member of the partition, union all the ranges together. + for (auto MI = ECs.member_begin(It), ME = ECs.member_end(); + MI != ME; ++MI) { + Instruction *I = *MI; + auto SeenI = SeenInsts.find(I); + if (SeenI == SeenInsts.end()) + continue; + + R = R.unionWith(SeenI->second); + // We need to ensure I has no users that have not been seen. + // If it does, transformation would be illegal. + // + // Don't count the roots, as they terminate the graphs. + if (Roots.count(I) == 0) { + // Set the type of the conversion while we're here. 
+ if (!ConvertedToTy) + ConvertedToTy = I->getType(); + for (User *U : I->users()) { + Instruction *UI = dyn_cast<Instruction>(U); + if (!UI || SeenInsts.find(UI) == SeenInsts.end()) { + DEBUG(dbgs() << "F2I: Failing because of " << *U << "\n"); + Fail = true; + break; + } + } + } + if (Fail) + break; + } + + // If the set was empty, or we failed, or the range is poisonous, + // bail out. + if (ECs.member_begin(It) == ECs.member_end() || Fail || + R.isFullSet() || R.isSignWrappedSet()) + continue; + assert(ConvertedToTy && "Must have set the convertedtoty by this point!"); + + // The number of bits required is the maximum of the upper and + // lower limits, plus one so it can be signed. + unsigned MinBW = std::max(R.getLower().getMinSignedBits(), + R.getUpper().getMinSignedBits()) + 1; + DEBUG(dbgs() << "F2I: MinBitwidth=" << MinBW << ", R: " << R << "\n"); + + // If we've run off the realms of the exactly representable integers, + // the floating point result will differ from an integer approximation. + + // Do we need more bits than are in the mantissa of the type we converted + // to? semanticsPrecision returns the number of mantissa bits plus one + // for the sign bit. + unsigned MaxRepresentableBits + = APFloat::semanticsPrecision(ConvertedToTy->getFltSemantics()) - 1; + if (MinBW > MaxRepresentableBits) { + DEBUG(dbgs() << "F2I: Value not guaranteed to be representable!\n"); + continue; + } + if (MinBW > 64) { + DEBUG(dbgs() << "F2I: Value requires more than 64 bits to represent!\n"); + continue; + } + + // OK, R is known to be representable. Now pick a type for it. + // FIXME: Pick the smallest legal type that will fit. + Type *Ty = (MinBW > 32) ? Type::getInt64Ty(*Ctx) : Type::getInt32Ty(*Ctx); + + for (auto MI = ECs.member_begin(It), ME = ECs.member_end(); + MI != ME; ++MI) + convert(*MI, Ty); + MadeChange = true; + } + + return MadeChange; +} + +Value *Float2IntPass::convert(Instruction *I, Type *ToTy) { + if (ConvertedInsts.find(I) != ConvertedInsts.end()) + // Already converted this instruction. + return ConvertedInsts[I]; + + SmallVector<Value*,4> NewOperands; + for (Value *V : I->operands()) { + // Don't recurse if we're an instruction that terminates the path. + if (I->getOpcode() == Instruction::UIToFP || + I->getOpcode() == Instruction::SIToFP) { + NewOperands.push_back(V); + } else if (Instruction *VI = dyn_cast<Instruction>(V)) { + NewOperands.push_back(convert(VI, ToTy)); + } else if (ConstantFP *CF = dyn_cast<ConstantFP>(V)) { + APSInt Val(ToTy->getPrimitiveSizeInBits(), /*IsUnsigned=*/false); + bool Exact; + CF->getValueAPF().convertToInteger(Val, + APFloat::rmNearestTiesToEven, + &Exact); + NewOperands.push_back(ConstantInt::get(ToTy, Val)); + } else { + llvm_unreachable("Unhandled operand type?"); + } + } + + // Now create a new instruction. 
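+  // (Editorial illustration, not part of the upstream sources.) For a
+  // partition that was assigned Ty == i64, the rewrite below turns, e.g.:
+  //   %f = uitofp i32 %x to double
+  //   %g = fadd double %f, 1.0
+  //   %r = fptoui double %g to i32
+  // into:
+  //   %f.int = zext i32 %x to i64
+  //   %g.int = add i64 %f.int, 1
+  //   %r     = trunc i64 %g.int to i32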
+ IRBuilder<> IRB(I); + Value *NewV = nullptr; + switch (I->getOpcode()) { + default: llvm_unreachable("Unhandled instruction!"); + + case Instruction::FPToUI: + NewV = IRB.CreateZExtOrTrunc(NewOperands[0], I->getType()); + break; + + case Instruction::FPToSI: + NewV = IRB.CreateSExtOrTrunc(NewOperands[0], I->getType()); + break; + + case Instruction::FCmp: { + CmpInst::Predicate P = mapFCmpPred(cast<CmpInst>(I)->getPredicate()); + assert(P != CmpInst::BAD_ICMP_PREDICATE && "Unhandled predicate!"); + NewV = IRB.CreateICmp(P, NewOperands[0], NewOperands[1], I->getName()); + break; + } + + case Instruction::UIToFP: + NewV = IRB.CreateZExtOrTrunc(NewOperands[0], ToTy); + break; + + case Instruction::SIToFP: + NewV = IRB.CreateSExtOrTrunc(NewOperands[0], ToTy); + break; + + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: + NewV = IRB.CreateBinOp(mapBinOpcode(I->getOpcode()), + NewOperands[0], NewOperands[1], + I->getName()); + break; + } + + // If we're a root instruction, RAUW. + if (Roots.count(I)) + I->replaceAllUsesWith(NewV); + + ConvertedInsts[I] = NewV; + return NewV; +} + +// Perform dead code elimination on the instructions we just modified. +void Float2IntPass::cleanup() { + for (auto &I : reverse(ConvertedInsts)) + I.first->eraseFromParent(); +} + +bool Float2IntPass::runImpl(Function &F) { + DEBUG(dbgs() << "F2I: Looking at function " << F.getName() << "\n"); + // Clear out all state. + ECs = EquivalenceClasses<Instruction*>(); + SeenInsts.clear(); + ConvertedInsts.clear(); + Roots.clear(); + + Ctx = &F.getParent()->getContext(); + + findRoots(F, Roots); + + walkBackwards(Roots); + walkForwards(); + + bool Modified = validateAndTransform(); + if (Modified) + cleanup(); + return Modified; +} + +namespace llvm { +FunctionPass *createFloat2IntPass() { return new Float2IntLegacyPass(); } + +PreservedAnalyses Float2IntPass::run(Function &F, FunctionAnalysisManager &) { + if (!runImpl(F)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + PA.preserve<GlobalsAA>(); + return PA; +} +} // End namespace llvm diff --git a/contrib/llvm/lib/Transforms/Scalar/GVN.cpp b/contrib/llvm/lib/Transforms/Scalar/GVN.cpp new file mode 100644 index 000000000000..e2c1eaf58e43 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/GVN.cpp @@ -0,0 +1,2667 @@ +//===- GVN.cpp - Eliminate redundant values and loads ---------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass performs global value numbering to eliminate fully redundant +// instructions. It also performs simple dead load elimination. +// +// Note that this pass does the value numbering itself; it does not use the +// ValueNumbering analysis passes. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PointerIntPair.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/PHITransAddr.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugLoc.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/SSAUpdater.h" +#include "llvm/Transforms/Utils/VNCoercion.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <utility> +#include <vector> + +using namespace llvm; +using namespace llvm::gvn; +using namespace llvm::VNCoercion; +using namespace PatternMatch; + +#define DEBUG_TYPE "gvn" + +STATISTIC(NumGVNInstr, "Number of instructions deleted"); +STATISTIC(NumGVNLoad, "Number of loads deleted"); +STATISTIC(NumGVNPRE, "Number of instructions PRE'd"); +STATISTIC(NumGVNBlocks, "Number of blocks merged"); +STATISTIC(NumGVNSimpl, "Number of instructions simplified"); +STATISTIC(NumGVNEqProp, "Number of equalities propagated"); +STATISTIC(NumPRELoad, "Number of loads PRE'd"); + +static cl::opt<bool> EnablePRE("enable-pre", + cl::init(true), cl::Hidden); +static cl::opt<bool> EnableLoadPRE("enable-load-pre", cl::init(true)); + +// Maximum allowed recursion depth. 
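+// (Editorial note, not part of the upstream sources: this bounds the
+// recursion in IsValueFullyAvailableInBlock below; past the limit a block is
+// conservatively treated as not fully available.)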
+static cl::opt<uint32_t> +MaxRecurseDepth("max-recurse-depth", cl::Hidden, cl::init(1000), cl::ZeroOrMore, + cl::desc("Max recurse depth (default = 1000)")); + +struct llvm::GVN::Expression { + uint32_t opcode; + Type *type; + bool commutative = false; + SmallVector<uint32_t, 4> varargs; + + Expression(uint32_t o = ~2U) : opcode(o) {} + + bool operator==(const Expression &other) const { + if (opcode != other.opcode) + return false; + if (opcode == ~0U || opcode == ~1U) + return true; + if (type != other.type) + return false; + if (varargs != other.varargs) + return false; + return true; + } + + friend hash_code hash_value(const Expression &Value) { + return hash_combine( + Value.opcode, Value.type, + hash_combine_range(Value.varargs.begin(), Value.varargs.end())); + } +}; + +namespace llvm { + +template <> struct DenseMapInfo<GVN::Expression> { + static inline GVN::Expression getEmptyKey() { return ~0U; } + static inline GVN::Expression getTombstoneKey() { return ~1U; } + + static unsigned getHashValue(const GVN::Expression &e) { + using llvm::hash_value; + + return static_cast<unsigned>(hash_value(e)); + } + + static bool isEqual(const GVN::Expression &LHS, const GVN::Expression &RHS) { + return LHS == RHS; + } +}; + +} // end namespace llvm + +/// Represents a particular available value that we know how to materialize. +/// Materialization of an AvailableValue never fails. An AvailableValue is +/// implicitly associated with a rematerialization point which is the +/// location of the instruction from which it was formed. +struct llvm::gvn::AvailableValue { + enum ValType { + SimpleVal, // A simple offsetted value that is accessed. + LoadVal, // A value produced by a load. + MemIntrin, // A memory intrinsic which is loaded from. + UndefVal // A UndefValue representing a value from dead block (which + // is not yet physically removed from the CFG). + }; + + /// V - The value that is live out of the block. + PointerIntPair<Value *, 2, ValType> Val; + + /// Offset - The byte offset in Val that is interesting for the load query. 
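+  /// (Editorial example, not part of the upstream sources: if an i64 store
+  /// covers a narrower i32 load of the same memory, Offset records where the
+  /// loaded bytes begin within the stored value.)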
+ unsigned Offset; + + static AvailableValue get(Value *V, unsigned Offset = 0) { + AvailableValue Res; + Res.Val.setPointer(V); + Res.Val.setInt(SimpleVal); + Res.Offset = Offset; + return Res; + } + + static AvailableValue getMI(MemIntrinsic *MI, unsigned Offset = 0) { + AvailableValue Res; + Res.Val.setPointer(MI); + Res.Val.setInt(MemIntrin); + Res.Offset = Offset; + return Res; + } + + static AvailableValue getLoad(LoadInst *LI, unsigned Offset = 0) { + AvailableValue Res; + Res.Val.setPointer(LI); + Res.Val.setInt(LoadVal); + Res.Offset = Offset; + return Res; + } + + static AvailableValue getUndef() { + AvailableValue Res; + Res.Val.setPointer(nullptr); + Res.Val.setInt(UndefVal); + Res.Offset = 0; + return Res; + } + + bool isSimpleValue() const { return Val.getInt() == SimpleVal; } + bool isCoercedLoadValue() const { return Val.getInt() == LoadVal; } + bool isMemIntrinValue() const { return Val.getInt() == MemIntrin; } + bool isUndefValue() const { return Val.getInt() == UndefVal; } + + Value *getSimpleValue() const { + assert(isSimpleValue() && "Wrong accessor"); + return Val.getPointer(); + } + + LoadInst *getCoercedLoadValue() const { + assert(isCoercedLoadValue() && "Wrong accessor"); + return cast<LoadInst>(Val.getPointer()); + } + + MemIntrinsic *getMemIntrinValue() const { + assert(isMemIntrinValue() && "Wrong accessor"); + return cast<MemIntrinsic>(Val.getPointer()); + } + + /// Emit code at the specified insertion point to adjust the value defined + /// here to the specified type. This handles various coercion cases. + Value *MaterializeAdjustedValue(LoadInst *LI, Instruction *InsertPt, + GVN &gvn) const; +}; + +/// Represents an AvailableValue which can be rematerialized at the end of +/// the associated BasicBlock. +struct llvm::gvn::AvailableValueInBlock { + /// BB - The basic block in question. + BasicBlock *BB; + + /// AV - The actual available value + AvailableValue AV; + + static AvailableValueInBlock get(BasicBlock *BB, AvailableValue &&AV) { + AvailableValueInBlock Res; + Res.BB = BB; + Res.AV = std::move(AV); + return Res; + } + + static AvailableValueInBlock get(BasicBlock *BB, Value *V, + unsigned Offset = 0) { + return get(BB, AvailableValue::get(V, Offset)); + } + + static AvailableValueInBlock getUndef(BasicBlock *BB) { + return get(BB, AvailableValue::getUndef()); + } + + /// Emit code at the end of this block to adjust the value defined here to + /// the specified type. This handles various coercion cases. + Value *MaterializeAdjustedValue(LoadInst *LI, GVN &gvn) const { + return AV.MaterializeAdjustedValue(LI, BB->getTerminator(), gvn); + } +}; + +//===----------------------------------------------------------------------===// +// ValueTable Internal Functions +//===----------------------------------------------------------------------===// + +GVN::Expression GVN::ValueTable::createExpr(Instruction *I) { + Expression e; + e.type = I->getType(); + e.opcode = I->getOpcode(); + for (Instruction::op_iterator OI = I->op_begin(), OE = I->op_end(); + OI != OE; ++OI) + e.varargs.push_back(lookupOrAdd(*OI)); + if (I->isCommutative()) { + // Ensure that commutative instructions that only differ by a permutation + // of their operands get the same value number by sorting the operand value + // numbers. Since all commutative instructions have two operands it is more + // efficient to sort by hand rather than using, say, std::sort. 
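+    // (Editorial example, not part of the upstream sources: "add %a, %b" and
+    // "add %b, %a" both end up with their two operand value numbers in
+    // sorted order after the swap below, so the expressions hash and compare
+    // as equal.)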
+ assert(I->getNumOperands() == 2 && "Unsupported commutative instruction!"); + if (e.varargs[0] > e.varargs[1]) + std::swap(e.varargs[0], e.varargs[1]); + e.commutative = true; + } + + if (CmpInst *C = dyn_cast<CmpInst>(I)) { + // Sort the operand value numbers so x<y and y>x get the same value number. + CmpInst::Predicate Predicate = C->getPredicate(); + if (e.varargs[0] > e.varargs[1]) { + std::swap(e.varargs[0], e.varargs[1]); + Predicate = CmpInst::getSwappedPredicate(Predicate); + } + e.opcode = (C->getOpcode() << 8) | Predicate; + e.commutative = true; + } else if (InsertValueInst *E = dyn_cast<InsertValueInst>(I)) { + for (InsertValueInst::idx_iterator II = E->idx_begin(), IE = E->idx_end(); + II != IE; ++II) + e.varargs.push_back(*II); + } + + return e; +} + +GVN::Expression GVN::ValueTable::createCmpExpr(unsigned Opcode, + CmpInst::Predicate Predicate, + Value *LHS, Value *RHS) { + assert((Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) && + "Not a comparison!"); + Expression e; + e.type = CmpInst::makeCmpResultType(LHS->getType()); + e.varargs.push_back(lookupOrAdd(LHS)); + e.varargs.push_back(lookupOrAdd(RHS)); + + // Sort the operand value numbers so x<y and y>x get the same value number. + if (e.varargs[0] > e.varargs[1]) { + std::swap(e.varargs[0], e.varargs[1]); + Predicate = CmpInst::getSwappedPredicate(Predicate); + } + e.opcode = (Opcode << 8) | Predicate; + e.commutative = true; + return e; +} + +GVN::Expression GVN::ValueTable::createExtractvalueExpr(ExtractValueInst *EI) { + assert(EI && "Not an ExtractValueInst?"); + Expression e; + e.type = EI->getType(); + e.opcode = 0; + + IntrinsicInst *I = dyn_cast<IntrinsicInst>(EI->getAggregateOperand()); + if (I != nullptr && EI->getNumIndices() == 1 && *EI->idx_begin() == 0 ) { + // EI might be an extract from one of our recognised intrinsics. If it + // is we'll synthesize a semantically equivalent expression instead on + // an extract value expression. + switch (I->getIntrinsicID()) { + case Intrinsic::sadd_with_overflow: + case Intrinsic::uadd_with_overflow: + e.opcode = Instruction::Add; + break; + case Intrinsic::ssub_with_overflow: + case Intrinsic::usub_with_overflow: + e.opcode = Instruction::Sub; + break; + case Intrinsic::smul_with_overflow: + case Intrinsic::umul_with_overflow: + e.opcode = Instruction::Mul; + break; + default: + break; + } + + if (e.opcode != 0) { + // Intrinsic recognized. Grab its args to finish building the expression. + assert(I->getNumArgOperands() == 2 && + "Expect two args for recognised intrinsics."); + e.varargs.push_back(lookupOrAdd(I->getArgOperand(0))); + e.varargs.push_back(lookupOrAdd(I->getArgOperand(1))); + return e; + } + } + + // Not a recognised intrinsic. Fall back to producing an extract value + // expression. 
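+  // (Editorial note, not part of the upstream sources: in the fallback below,
+  // the operands' value numbers come first in varargs; the literal aggregate
+  // indices are appended after them and are not value numbers.)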
+ e.opcode = EI->getOpcode(); + for (Instruction::op_iterator OI = EI->op_begin(), OE = EI->op_end(); + OI != OE; ++OI) + e.varargs.push_back(lookupOrAdd(*OI)); + + for (ExtractValueInst::idx_iterator II = EI->idx_begin(), IE = EI->idx_end(); + II != IE; ++II) + e.varargs.push_back(*II); + + return e; +} + +//===----------------------------------------------------------------------===// +// ValueTable External Functions +//===----------------------------------------------------------------------===// + +GVN::ValueTable::ValueTable() = default; +GVN::ValueTable::ValueTable(const ValueTable &) = default; +GVN::ValueTable::ValueTable(ValueTable &&) = default; +GVN::ValueTable::~ValueTable() = default; + +/// add - Insert a value into the table with a specified value number. +void GVN::ValueTable::add(Value *V, uint32_t num) { + valueNumbering.insert(std::make_pair(V, num)); + if (PHINode *PN = dyn_cast<PHINode>(V)) + NumberingPhi[num] = PN; +} + +uint32_t GVN::ValueTable::lookupOrAddCall(CallInst *C) { + if (AA->doesNotAccessMemory(C)) { + Expression exp = createExpr(C); + uint32_t e = assignExpNewValueNum(exp).first; + valueNumbering[C] = e; + return e; + } else if (AA->onlyReadsMemory(C)) { + Expression exp = createExpr(C); + auto ValNum = assignExpNewValueNum(exp); + if (ValNum.second) { + valueNumbering[C] = ValNum.first; + return ValNum.first; + } + if (!MD) { + uint32_t e = assignExpNewValueNum(exp).first; + valueNumbering[C] = e; + return e; + } + + MemDepResult local_dep = MD->getDependency(C); + + if (!local_dep.isDef() && !local_dep.isNonLocal()) { + valueNumbering[C] = nextValueNumber; + return nextValueNumber++; + } + + if (local_dep.isDef()) { + CallInst* local_cdep = cast<CallInst>(local_dep.getInst()); + + if (local_cdep->getNumArgOperands() != C->getNumArgOperands()) { + valueNumbering[C] = nextValueNumber; + return nextValueNumber++; + } + + for (unsigned i = 0, e = C->getNumArgOperands(); i < e; ++i) { + uint32_t c_vn = lookupOrAdd(C->getArgOperand(i)); + uint32_t cd_vn = lookupOrAdd(local_cdep->getArgOperand(i)); + if (c_vn != cd_vn) { + valueNumbering[C] = nextValueNumber; + return nextValueNumber++; + } + } + + uint32_t v = lookupOrAdd(local_cdep); + valueNumbering[C] = v; + return v; + } + + // Non-local case. + const MemoryDependenceResults::NonLocalDepInfo &deps = + MD->getNonLocalCallDependency(CallSite(C)); + // FIXME: Move the checking logic to MemDep! + CallInst* cdep = nullptr; + + // Check to see if we have a single dominating call instruction that is + // identical to C. + for (unsigned i = 0, e = deps.size(); i != e; ++i) { + const NonLocalDepEntry *I = &deps[i]; + if (I->getResult().isNonLocal()) + continue; + + // We don't handle non-definitions. If we already have a call, reject + // instruction dependencies. + if (!I->getResult().isDef() || cdep != nullptr) { + cdep = nullptr; + break; + } + + CallInst *NonLocalDepCall = dyn_cast<CallInst>(I->getResult().getInst()); + // FIXME: All duplicated with non-local case. 
+ if (NonLocalDepCall && DT->properlyDominates(I->getBB(), C->getParent())){ + cdep = NonLocalDepCall; + continue; + } + + cdep = nullptr; + break; + } + + if (!cdep) { + valueNumbering[C] = nextValueNumber; + return nextValueNumber++; + } + + if (cdep->getNumArgOperands() != C->getNumArgOperands()) { + valueNumbering[C] = nextValueNumber; + return nextValueNumber++; + } + for (unsigned i = 0, e = C->getNumArgOperands(); i < e; ++i) { + uint32_t c_vn = lookupOrAdd(C->getArgOperand(i)); + uint32_t cd_vn = lookupOrAdd(cdep->getArgOperand(i)); + if (c_vn != cd_vn) { + valueNumbering[C] = nextValueNumber; + return nextValueNumber++; + } + } + + uint32_t v = lookupOrAdd(cdep); + valueNumbering[C] = v; + return v; + } else { + valueNumbering[C] = nextValueNumber; + return nextValueNumber++; + } +} + +/// Returns true if a value number exists for the specified value. +bool GVN::ValueTable::exists(Value *V) const { return valueNumbering.count(V) != 0; } + +/// lookup_or_add - Returns the value number for the specified value, assigning +/// it a new number if it did not have one before. +uint32_t GVN::ValueTable::lookupOrAdd(Value *V) { + DenseMap<Value*, uint32_t>::iterator VI = valueNumbering.find(V); + if (VI != valueNumbering.end()) + return VI->second; + + if (!isa<Instruction>(V)) { + valueNumbering[V] = nextValueNumber; + return nextValueNumber++; + } + + Instruction* I = cast<Instruction>(V); + Expression exp; + switch (I->getOpcode()) { + case Instruction::Call: + return lookupOrAddCall(cast<CallInst>(I)); + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::ICmp: + case Instruction::FCmp: + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::BitCast: + case Instruction::Select: + case Instruction::ExtractElement: + case Instruction::InsertElement: + case Instruction::ShuffleVector: + case Instruction::InsertValue: + case Instruction::GetElementPtr: + exp = createExpr(I); + break; + case Instruction::ExtractValue: + exp = createExtractvalueExpr(cast<ExtractValueInst>(I)); + break; + case Instruction::PHI: + valueNumbering[V] = nextValueNumber; + NumberingPhi[nextValueNumber] = cast<PHINode>(V); + return nextValueNumber++; + default: + valueNumbering[V] = nextValueNumber; + return nextValueNumber++; + } + + uint32_t e = assignExpNewValueNum(exp).first; + valueNumbering[V] = e; + return e; +} + +/// Returns the value number of the specified value. Fails if +/// the value has not yet been numbered. +uint32_t GVN::ValueTable::lookup(Value *V, bool Verify) const { + DenseMap<Value*, uint32_t>::const_iterator VI = valueNumbering.find(V); + if (Verify) { + assert(VI != valueNumbering.end() && "Value not numbered?"); + return VI->second; + } + return (VI != valueNumbering.end()) ? 
VI->second : 0; +} + +/// Returns the value number of the given comparison, +/// assigning it a new number if it did not have one before. Useful when +/// we deduced the result of a comparison, but don't immediately have an +/// instruction realizing that comparison to hand. +uint32_t GVN::ValueTable::lookupOrAddCmp(unsigned Opcode, + CmpInst::Predicate Predicate, + Value *LHS, Value *RHS) { + Expression exp = createCmpExpr(Opcode, Predicate, LHS, RHS); + return assignExpNewValueNum(exp).first; +} + +/// Remove all entries from the ValueTable. +void GVN::ValueTable::clear() { + valueNumbering.clear(); + expressionNumbering.clear(); + NumberingPhi.clear(); + PhiTranslateTable.clear(); + nextValueNumber = 1; + Expressions.clear(); + ExprIdx.clear(); + nextExprNumber = 0; +} + +/// Remove a value from the value numbering. +void GVN::ValueTable::erase(Value *V) { + uint32_t Num = valueNumbering.lookup(V); + valueNumbering.erase(V); + // If V is PHINode, V <--> value number is an one-to-one mapping. + if (isa<PHINode>(V)) + NumberingPhi.erase(Num); +} + +/// verifyRemoved - Verify that the value is removed from all internal data +/// structures. +void GVN::ValueTable::verifyRemoved(const Value *V) const { + for (DenseMap<Value*, uint32_t>::const_iterator + I = valueNumbering.begin(), E = valueNumbering.end(); I != E; ++I) { + assert(I->first != V && "Inst still occurs in value numbering map!"); + } +} + +//===----------------------------------------------------------------------===// +// GVN Pass +//===----------------------------------------------------------------------===// + +PreservedAnalyses GVN::run(Function &F, FunctionAnalysisManager &AM) { + // FIXME: The order of evaluation of these 'getResult' calls is very + // significant! Re-ordering these variables will cause GVN when run alone to + // be less effective! We should fix memdep and basic-aa to not exhibit this + // behavior, but until then don't change the order here. + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + auto &MemDep = AM.getResult<MemoryDependenceAnalysis>(F); + auto *LI = AM.getCachedResult<LoopAnalysis>(F); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + bool Changed = runImpl(F, AC, DT, TLI, AA, &MemDep, LI, &ORE); + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<GlobalsAA>(); + PA.preserve<TargetLibraryAnalysis>(); + return PA; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +LLVM_DUMP_METHOD void GVN::dump(DenseMap<uint32_t, Value*>& d) const { + errs() << "{\n"; + for (DenseMap<uint32_t, Value*>::iterator I = d.begin(), + E = d.end(); I != E; ++I) { + errs() << I->first << "\n"; + I->second->dump(); + } + errs() << "}\n"; +} +#endif + +/// Return true if we can prove that the value +/// we're analyzing is fully available in the specified block. As we go, keep +/// track of which blocks we know are fully alive in FullyAvailableBlocks. This +/// map is actually a tri-state map with the following values: +/// 0) we know the block *is not* fully available. +/// 1) we know the block *is* fully available. +/// 2) we do not know whether the block is fully available or not, but we are +/// currently speculating that it will be. +/// 3) we are speculating for this block and have used that to speculate for +/// other blocks. 
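+/// (Editorial note, not part of the upstream sources: the optimistic "2"
+/// state is what lets the recursion terminate on CFG cycles - a loop
+/// back-edge revisits the block, finds the speculative entry, and upgrades
+/// it to "3" instead of recursing forever.)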
+static bool IsValueFullyAvailableInBlock(BasicBlock *BB, + DenseMap<BasicBlock*, char> &FullyAvailableBlocks, + uint32_t RecurseDepth) { + if (RecurseDepth > MaxRecurseDepth) + return false; + + // Optimistically assume that the block is fully available and check to see + // if we already know about this block in one lookup. + std::pair<DenseMap<BasicBlock*, char>::iterator, char> IV = + FullyAvailableBlocks.insert(std::make_pair(BB, 2)); + + // If the entry already existed for this block, return the precomputed value. + if (!IV.second) { + // If this is a speculative "available" value, mark it as being used for + // speculation of other blocks. + if (IV.first->second == 2) + IV.first->second = 3; + return IV.first->second != 0; + } + + // Otherwise, see if it is fully available in all predecessors. + pred_iterator PI = pred_begin(BB), PE = pred_end(BB); + + // If this block has no predecessors, it isn't live-in here. + if (PI == PE) + goto SpeculationFailure; + + for (; PI != PE; ++PI) + // If the value isn't fully available in one of our predecessors, then it + // isn't fully available in this block either. Undo our previous + // optimistic assumption and bail out. + if (!IsValueFullyAvailableInBlock(*PI, FullyAvailableBlocks,RecurseDepth+1)) + goto SpeculationFailure; + + return true; + +// If we get here, we found out that this is not, after +// all, a fully-available block. We have a problem if we speculated on this and +// used the speculation to mark other blocks as available. +SpeculationFailure: + char &BBVal = FullyAvailableBlocks[BB]; + + // If we didn't speculate on this, just return with it set to false. + if (BBVal == 2) { + BBVal = 0; + return false; + } + + // If we did speculate on this value, we could have blocks set to 1 that are + // incorrect. Walk the (transitive) successors of this block and mark them as + // 0 if set to one. + SmallVector<BasicBlock*, 32> BBWorklist; + BBWorklist.push_back(BB); + + do { + BasicBlock *Entry = BBWorklist.pop_back_val(); + // Note that this sets blocks to 0 (unavailable) if they happen to not + // already be in FullyAvailableBlocks. This is safe. + char &EntryVal = FullyAvailableBlocks[Entry]; + if (EntryVal == 0) continue; // Already unavailable. + + // Mark as unavailable. + EntryVal = 0; + + BBWorklist.append(succ_begin(Entry), succ_end(Entry)); + } while (!BBWorklist.empty()); + + return false; +} + +/// Given a set of loads specified by ValuesPerBlock, +/// construct SSA form, allowing us to eliminate LI. This returns the value +/// that should be used at LI's definition site. +static Value *ConstructSSAForLoadSet(LoadInst *LI, + SmallVectorImpl<AvailableValueInBlock> &ValuesPerBlock, + GVN &gvn) { + // Check for the fully redundant, dominating load case. In this case, we can + // just use the dominating value directly. + if (ValuesPerBlock.size() == 1 && + gvn.getDominatorTree().properlyDominates(ValuesPerBlock[0].BB, + LI->getParent())) { + assert(!ValuesPerBlock[0].AV.isUndefValue() && + "Dead BB dominate this block"); + return ValuesPerBlock[0].MaterializeAdjustedValue(LI, gvn); + } + + // Otherwise, we have to construct SSA form. + SmallVector<PHINode*, 8> NewPHIs; + SSAUpdater SSAUpdate(&NewPHIs); + SSAUpdate.Initialize(LI->getType(), LI->getName()); + + for (const AvailableValueInBlock &AV : ValuesPerBlock) { + BasicBlock *BB = AV.BB; + + if (SSAUpdate.HasValueForBlock(BB)) + continue; + + SSAUpdate.AddAvailableValue(BB, AV.MaterializeAdjustedValue(LI, gvn)); + } + + // Perform PHI construction. 
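+  // (Editorial example, not part of the upstream sources: if the loaded
+  // value is available as %a in predecessor %bb1 and as %b in %bb2, this
+  // builds
+  //   %merged = phi i32 [ %a, %bb1 ], [ %b, %bb2 ]
+  // in the load's block, and the load is then replaced by %merged.)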
+  return SSAUpdate.GetValueInMiddleOfBlock(LI->getParent());
+}
+
+Value *AvailableValue::MaterializeAdjustedValue(LoadInst *LI,
+                                                Instruction *InsertPt,
+                                                GVN &gvn) const {
+  Value *Res;
+  Type *LoadTy = LI->getType();
+  const DataLayout &DL = LI->getModule()->getDataLayout();
+  if (isSimpleValue()) {
+    Res = getSimpleValue();
+    if (Res->getType() != LoadTy) {
+      Res = getStoreValueForLoad(Res, Offset, LoadTy, InsertPt, DL);
+
+      DEBUG(dbgs() << "GVN COERCED NONLOCAL VAL:\nOffset: " << Offset << "  "
+                   << *getSimpleValue() << '\n'
+                   << *Res << '\n' << "\n\n\n");
+    }
+  } else if (isCoercedLoadValue()) {
+    LoadInst *Load = getCoercedLoadValue();
+    if (Load->getType() == LoadTy && Offset == 0) {
+      Res = Load;
+    } else {
+      Res = getLoadValueForLoad(Load, Offset, LoadTy, InsertPt, DL);
+      // We would like to use gvn.markInstructionForDeletion here, but we
+      // can't because the load is already memoized into the leader map table
+      // that GVN tracks. It is potentially possible to remove the load from
+      // the table, but then all of the operations based on it would need to
+      // be rehashed. Just leave the dead load around.
+      gvn.getMemDep().removeInstruction(Load);
+      DEBUG(dbgs() << "GVN COERCED NONLOCAL LOAD:\nOffset: " << Offset << "  "
+                   << *getCoercedLoadValue() << '\n'
+                   << *Res << '\n'
+                   << "\n\n\n");
+    }
+  } else if (isMemIntrinValue()) {
+    Res = getMemInstValueForLoad(getMemIntrinValue(), Offset, LoadTy,
+                                 InsertPt, DL);
+    DEBUG(dbgs() << "GVN COERCED NONLOCAL MEM INTRIN:\nOffset: " << Offset
+                 << "  " << *getMemIntrinValue() << '\n'
+                 << *Res << '\n' << "\n\n\n");
+  } else {
+    assert(isUndefValue() && "Should be UndefVal");
+    DEBUG(dbgs() << "GVN COERCED NONLOCAL Undef:\n";);
+    return UndefValue::get(LoadTy);
+  }
+  assert(Res && "failed to materialize?");
+  return Res;
+}
+
+static bool isLifetimeStart(const Instruction *Inst) {
+  if (const IntrinsicInst* II = dyn_cast<IntrinsicInst>(Inst))
+    return II->getIntrinsicID() == Intrinsic::lifetime_start;
+  return false;
+}
+
+/// \brief Try to locate the three instructions involved in a missed
+/// load-elimination case that is due to an intervening store.
+static void reportMayClobberedLoad(LoadInst *LI, MemDepResult DepInfo,
+                                   DominatorTree *DT,
+                                   OptimizationRemarkEmitter *ORE) {
+  using namespace ore;
+
+  User *OtherAccess = nullptr;
+
+  OptimizationRemarkMissed R(DEBUG_TYPE, "LoadClobbered", LI);
+  R << "load of type " << NV("Type", LI->getType()) << " not eliminated"
+    << setExtraArgs();
+
+  for (auto *U : LI->getPointerOperand()->users())
+    if (U != LI && (isa<LoadInst>(U) || isa<StoreInst>(U)) &&
+        DT->dominates(cast<Instruction>(U), LI)) {
+      // FIXME: for now give up if there are multiple memory accesses that
+      // dominate the load. We need further analysis to decide which one we
+      // are forwarding from.
+ if (OtherAccess) + OtherAccess = nullptr; + else + OtherAccess = U; + } + + if (OtherAccess) + R << " in favor of " << NV("OtherAccess", OtherAccess); + + R << " because it is clobbered by " << NV("ClobberedBy", DepInfo.getInst()); + + ORE->emit(R); +} + +bool GVN::AnalyzeLoadAvailability(LoadInst *LI, MemDepResult DepInfo, + Value *Address, AvailableValue &Res) { + assert((DepInfo.isDef() || DepInfo.isClobber()) && + "expected a local dependence"); + assert(LI->isUnordered() && "rules below are incorrect for ordered access"); + + const DataLayout &DL = LI->getModule()->getDataLayout(); + + if (DepInfo.isClobber()) { + // If the dependence is to a store that writes to a superset of the bits + // read by the load, we can extract the bits we need for the load from the + // stored value. + if (StoreInst *DepSI = dyn_cast<StoreInst>(DepInfo.getInst())) { + // Can't forward from non-atomic to atomic without violating memory model. + if (Address && LI->isAtomic() <= DepSI->isAtomic()) { + int Offset = + analyzeLoadFromClobberingStore(LI->getType(), Address, DepSI, DL); + if (Offset != -1) { + Res = AvailableValue::get(DepSI->getValueOperand(), Offset); + return true; + } + } + } + + // Check to see if we have something like this: + // load i32* P + // load i8* (P+1) + // if we have this, replace the later with an extraction from the former. + if (LoadInst *DepLI = dyn_cast<LoadInst>(DepInfo.getInst())) { + // If this is a clobber and L is the first instruction in its block, then + // we have the first instruction in the entry block. + // Can't forward from non-atomic to atomic without violating memory model. + if (DepLI != LI && Address && LI->isAtomic() <= DepLI->isAtomic()) { + int Offset = + analyzeLoadFromClobberingLoad(LI->getType(), Address, DepLI, DL); + + if (Offset != -1) { + Res = AvailableValue::getLoad(DepLI, Offset); + return true; + } + } + } + + // If the clobbering value is a memset/memcpy/memmove, see if we can + // forward a value on from it. + if (MemIntrinsic *DepMI = dyn_cast<MemIntrinsic>(DepInfo.getInst())) { + if (Address && !LI->isAtomic()) { + int Offset = analyzeLoadFromClobberingMemInst(LI->getType(), Address, + DepMI, DL); + if (Offset != -1) { + Res = AvailableValue::getMI(DepMI, Offset); + return true; + } + } + } + // Nothing known about this clobber, have to be conservative + DEBUG( + // fast print dep, using operator<< on instruction is too slow. + dbgs() << "GVN: load "; + LI->printAsOperand(dbgs()); + Instruction *I = DepInfo.getInst(); + dbgs() << " is clobbered by " << *I << '\n'; + ); + if (ORE->allowExtraAnalysis(DEBUG_TYPE)) + reportMayClobberedLoad(LI, DepInfo, DT, ORE); + + return false; + } + assert(DepInfo.isDef() && "follows from above"); + + Instruction *DepInst = DepInfo.getInst(); + + // Loading the allocation -> undef. + if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI) || + // Loading immediately after lifetime begin -> undef. + isLifetimeStart(DepInst)) { + Res = AvailableValue::get(UndefValue::get(LI->getType())); + return true; + } + + // Loading from calloc (which zero initializes memory) -> zero + if (isCallocLikeFn(DepInst, TLI)) { + Res = AvailableValue::get(Constant::getNullValue(LI->getType())); + return true; + } + + if (StoreInst *S = dyn_cast<StoreInst>(DepInst)) { + // Reject loads and stores that are to the same address but are of + // different types if we have to. If the stored value is larger or equal to + // the loaded value, we can reuse it. 
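+    // (Editorial example, not part of the upstream sources:
+    //   store i64 %v, i64* %p
+    //   %x = load i32, i32* %q   ; %q must-aliases the start of %p
+    // can forward the relevant bytes of %v instead of reloading, provided
+    // the types are coercible.)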
+    if (S->getValueOperand()->getType() != LI->getType() &&
+        !canCoerceMustAliasedValueToLoad(S->getValueOperand(),
+                                         LI->getType(), DL))
+      return false;
+
+    // Can't forward from non-atomic to atomic without violating memory model.
+    if (S->isAtomic() < LI->isAtomic())
+      return false;
+
+    Res = AvailableValue::get(S->getValueOperand());
+    return true;
+  }
+
+  if (LoadInst *LD = dyn_cast<LoadInst>(DepInst)) {
+    // If the types mismatch and we can't handle it, reject reuse of the load.
+    // If the stored value is larger or equal to the loaded value, we can reuse
+    // it.
+    if (LD->getType() != LI->getType() &&
+        !canCoerceMustAliasedValueToLoad(LD, LI->getType(), DL))
+      return false;
+
+    // Can't forward from non-atomic to atomic without violating memory model.
+    if (LD->isAtomic() < LI->isAtomic())
+      return false;
+
+    Res = AvailableValue::getLoad(LD);
+    return true;
+  }
+
+  // Unknown def - must be conservative
+  DEBUG(
+      // fast print dep, using operator<< on instruction is too slow.
+      dbgs() << "GVN: load ";
+      LI->printAsOperand(dbgs());
+      dbgs() << " has unknown def " << *DepInst << '\n';
+  );
+  return false;
+}
+
+void GVN::AnalyzeLoadAvailability(LoadInst *LI, LoadDepVect &Deps,
+                                  AvailValInBlkVect &ValuesPerBlock,
+                                  UnavailBlkVect &UnavailableBlocks) {
+  // Filter out useless results (non-locals, etc). Keep track of the blocks
+  // where we have a value available in repl, also keep track of whether we see
+  // dependencies that produce an unknown value for the load (such as a call
+  // that could potentially clobber the load).
+  unsigned NumDeps = Deps.size();
+  for (unsigned i = 0, e = NumDeps; i != e; ++i) {
+    BasicBlock *DepBB = Deps[i].getBB();
+    MemDepResult DepInfo = Deps[i].getResult();
+
+    if (DeadBlocks.count(DepBB)) {
+      // A dead dependent mem-op is disguised as a load evaluating the same
+      // value as the load in question.
+      ValuesPerBlock.push_back(AvailableValueInBlock::getUndef(DepBB));
+      continue;
+    }
+
+    if (!DepInfo.isDef() && !DepInfo.isClobber()) {
+      UnavailableBlocks.push_back(DepBB);
+      continue;
+    }
+
+    // The address being loaded in this non-local block may not be the same as
+    // the pointer operand of the load if PHI translation occurs. Make sure
+    // to consider the right address.
+    Value *Address = Deps[i].getAddress();
+
+    AvailableValue AV;
+    if (AnalyzeLoadAvailability(LI, DepInfo, Address, AV)) {
+      // subtlety: because we know this was a non-local dependency, we know
+      // it's safe to materialize anywhere between the instruction within
+      // DepInfo and the end of its block.
+      ValuesPerBlock.push_back(AvailableValueInBlock::get(DepBB,
+                                                          std::move(AV)));
+    } else {
+      UnavailableBlocks.push_back(DepBB);
+    }
+  }
+
+  assert(NumDeps == ValuesPerBlock.size() + UnavailableBlocks.size() &&
+         "post condition violation");
+}
+
+bool GVN::PerformLoadPRE(LoadInst *LI, AvailValInBlkVect &ValuesPerBlock,
+                         UnavailBlkVect &UnavailableBlocks) {
+  // Okay, we have *some* definitions of the value. This means that the value
+  // is available in some of our (transitive) predecessors. Let's think about
+  // doing PRE of this load. This will involve inserting a new load into the
+  // predecessor when it's not available. We could do this in general, but
+  // prefer to not increase code size. As such, we only do this when we know
+  // that we only have to insert *one* load (which means we're basically moving
+  // the load, not inserting a new one).
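+  // (Editorial sketch, not part of the upstream sources: for
+  //     br i1 %c, label %then, label %else
+  //   then:    ; store i32 1, i32* %p
+  //   else:    ; %p untouched
+  //   merge:   ; %v = load i32, i32* %p
+  // the load is available along %then because of the store but not along
+  // %else, so PRE inserts a single load in %else and phis the two values
+  // together in %merge.)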
+
+  SmallPtrSet<BasicBlock *, 4> Blockers(UnavailableBlocks.begin(),
+                                        UnavailableBlocks.end());
+
+  // Let's find the first basic block with more than one predecessor. Walk
+  // backwards through predecessors if needed.
+  BasicBlock *LoadBB = LI->getParent();
+  BasicBlock *TmpBB = LoadBB;
+  bool IsSafeToSpeculativelyExecute = isSafeToSpeculativelyExecute(LI);
+
+  // Check that there are no implicit control flow instructions above our load
+  // in its block. If there is an instruction that doesn't always pass the
+  // execution to the following instruction, then moving through it may become
+  // invalid. For example:
+  //
+  // int arr[LEN];
+  // int index = ???;
+  // ...
+  // guard(0 <= index && index < LEN);
+  // use(arr[index]);
+  //
+  // It is illegal to move the array access to any point above the guard,
+  // because if the index is out of bounds we should deoptimize rather than
+  // access the array.
+  // Check that there is no guard in this block above our instruction.
+  if (!IsSafeToSpeculativelyExecute) {
+    auto It = FirstImplicitControlFlowInsts.find(TmpBB);
+    if (It != FirstImplicitControlFlowInsts.end()) {
+      assert(It->second->getParent() == TmpBB &&
+             "Implicit control flow map broken?");
+      if (OI->dominates(It->second, LI))
+        return false;
+    }
+  }
+  while (TmpBB->getSinglePredecessor()) {
+    TmpBB = TmpBB->getSinglePredecessor();
+    if (TmpBB == LoadBB) // Infinite (unreachable) loop.
+      return false;
+    if (Blockers.count(TmpBB))
+      return false;
+
+    // If any of these blocks has more than one successor (i.e. if the edge we
+    // just traversed was critical), then there are other paths through this
+    // block along which the load may not be anticipated. Hoisting the load
+    // above this block would be adding the load to execution paths along
+    // which it was not previously executed.
+    if (TmpBB->getTerminator()->getNumSuccessors() != 1)
+      return false;
+
+    // Check that there is no implicit control flow in a block above.
+    if (!IsSafeToSpeculativelyExecute &&
+        FirstImplicitControlFlowInsts.count(TmpBB))
+      return false;
+  }
+
+  assert(TmpBB);
+  LoadBB = TmpBB;
+
+  // Check to see how many predecessors have the loaded value fully
+  // available.
+  MapVector<BasicBlock *, Value *> PredLoads;
+  DenseMap<BasicBlock*, char> FullyAvailableBlocks;
+  for (const AvailableValueInBlock &AV : ValuesPerBlock)
+    FullyAvailableBlocks[AV.BB] = true;
+  for (BasicBlock *UnavailableBB : UnavailableBlocks)
+    FullyAvailableBlocks[UnavailableBB] = false;
+
+  SmallVector<BasicBlock *, 4> CriticalEdgePred;
+  for (BasicBlock *Pred : predecessors(LoadBB)) {
+    // If any predecessor block is an EH pad that does not allow non-PHI
+    // instructions before the terminator, we can't PRE the load.
+    if (Pred->getTerminator()->isEHPad()) {
+      DEBUG(dbgs()
+            << "COULD NOT PRE LOAD BECAUSE OF AN EH PAD PREDECESSOR '"
+            << Pred->getName() << "': " << *LI << '\n');
+      return false;
+    }
+
+    if (IsValueFullyAvailableInBlock(Pred, FullyAvailableBlocks, 0)) {
+      continue;
+    }
+
+    if (Pred->getTerminator()->getNumSuccessors() != 1) {
+      if (isa<IndirectBrInst>(Pred->getTerminator())) {
+        DEBUG(dbgs() << "COULD NOT PRE LOAD BECAUSE OF INDBR CRITICAL EDGE '"
+              << Pred->getName() << "': " << *LI << '\n');
+        return false;
+      }
+
+      if (LoadBB->isEHPad()) {
+        DEBUG(dbgs()
+              << "COULD NOT PRE LOAD BECAUSE OF AN EH PAD CRITICAL EDGE '"
+              << Pred->getName() << "': " << *LI << '\n');
+        return false;
+      }
+
+      CriticalEdgePred.push_back(Pred);
+    } else {
+      // Only add the predecessors that will not be split for now.
+ PredLoads[Pred] = nullptr; + } + } + + // Decide whether PRE is profitable for this load. + unsigned NumUnavailablePreds = PredLoads.size() + CriticalEdgePred.size(); + assert(NumUnavailablePreds != 0 && + "Fully available value should already be eliminated!"); + + // If this load is unavailable in multiple predecessors, reject it. + // FIXME: If we could restructure the CFG, we could make a common pred with + // all the preds that don't have an available LI and insert a new load into + // that one block. + if (NumUnavailablePreds != 1) + return false; + + // Split critical edges, and update the unavailable predecessors accordingly. + for (BasicBlock *OrigPred : CriticalEdgePred) { + BasicBlock *NewPred = splitCriticalEdges(OrigPred, LoadBB); + assert(!PredLoads.count(OrigPred) && "Split edges shouldn't be in map!"); + PredLoads[NewPred] = nullptr; + DEBUG(dbgs() << "Split critical edge " << OrigPred->getName() << "->" + << LoadBB->getName() << '\n'); + } + + // Check if the load can safely be moved to all the unavailable predecessors. + bool CanDoPRE = true; + const DataLayout &DL = LI->getModule()->getDataLayout(); + SmallVector<Instruction*, 8> NewInsts; + for (auto &PredLoad : PredLoads) { + BasicBlock *UnavailablePred = PredLoad.first; + + // Do PHI translation to get its value in the predecessor if necessary. The + // returned pointer (if non-null) is guaranteed to dominate UnavailablePred. + + // If all preds have a single successor, then we know it is safe to insert + // the load on the pred (?!?), so we can insert code to materialize the + // pointer if it is not available. + PHITransAddr Address(LI->getPointerOperand(), DL, AC); + Value *LoadPtr = nullptr; + LoadPtr = Address.PHITranslateWithInsertion(LoadBB, UnavailablePred, + *DT, NewInsts); + + // If we couldn't find or insert a computation of this phi translated value, + // we fail PRE. + if (!LoadPtr) { + DEBUG(dbgs() << "COULDN'T INSERT PHI TRANSLATED VALUE OF: " + << *LI->getPointerOperand() << "\n"); + CanDoPRE = false; + break; + } + + PredLoad.second = LoadPtr; + } + + if (!CanDoPRE) { + while (!NewInsts.empty()) { + Instruction *I = NewInsts.pop_back_val(); + markInstructionForDeletion(I); + } + // HINT: Don't revert the edge-splitting as following transformation may + // also need to split these critical edges. + return !CriticalEdgePred.empty(); + } + + // Okay, we can eliminate this load by inserting a reload in the predecessor + // and using PHI construction to get the value in the other predecessors, do + // it. + DEBUG(dbgs() << "GVN REMOVING PRE LOAD: " << *LI << '\n'); + DEBUG(if (!NewInsts.empty()) + dbgs() << "INSERTED " << NewInsts.size() << " INSTS: " + << *NewInsts.back() << '\n'); + + // Assign value numbers to the new instructions. + for (Instruction *I : NewInsts) { + // Instructions that have been inserted in predecessor(s) to materialize + // the load address do not retain their original debug locations. Doing + // so could lead to confusing (but correct) source attributions. + // FIXME: How do we retain source locations without causing poor debugging + // behavior? + I->setDebugLoc(DebugLoc()); + + // FIXME: We really _ought_ to insert these value numbers into their + // parent's availability map. However, in doing so, we risk getting into + // ordering issues. If a block hasn't been processed yet, we would be + // marking a value as AVAIL-IN, which isn't what we intend. 
+ VN.lookupOrAdd(I); + } + + for (const auto &PredLoad : PredLoads) { + BasicBlock *UnavailablePred = PredLoad.first; + Value *LoadPtr = PredLoad.second; + + auto *NewLoad = new LoadInst(LoadPtr, LI->getName()+".pre", + LI->isVolatile(), LI->getAlignment(), + LI->getOrdering(), LI->getSyncScopeID(), + UnavailablePred->getTerminator()); + NewLoad->setDebugLoc(LI->getDebugLoc()); + + // Transfer the old load's AA tags to the new load. + AAMDNodes Tags; + LI->getAAMetadata(Tags); + if (Tags) + NewLoad->setAAMetadata(Tags); + + if (auto *MD = LI->getMetadata(LLVMContext::MD_invariant_load)) + NewLoad->setMetadata(LLVMContext::MD_invariant_load, MD); + if (auto *InvGroupMD = LI->getMetadata(LLVMContext::MD_invariant_group)) + NewLoad->setMetadata(LLVMContext::MD_invariant_group, InvGroupMD); + if (auto *RangeMD = LI->getMetadata(LLVMContext::MD_range)) + NewLoad->setMetadata(LLVMContext::MD_range, RangeMD); + + // We do not propagate the old load's debug location, because the new + // load now lives in a different BB, and we want to avoid a jumpy line + // table. + // FIXME: How do we retain source locations without causing poor debugging + // behavior? + + // Add the newly created load. + ValuesPerBlock.push_back(AvailableValueInBlock::get(UnavailablePred, + NewLoad)); + MD->invalidateCachedPointerInfo(LoadPtr); + DEBUG(dbgs() << "GVN INSERTED " << *NewLoad << '\n'); + } + + // Perform PHI construction. + Value *V = ConstructSSAForLoadSet(LI, ValuesPerBlock, *this); + LI->replaceAllUsesWith(V); + if (isa<PHINode>(V)) + V->takeName(LI); + if (Instruction *I = dyn_cast<Instruction>(V)) + I->setDebugLoc(LI->getDebugLoc()); + if (V->getType()->isPtrOrPtrVectorTy()) + MD->invalidateCachedPointerInfo(V); + markInstructionForDeletion(LI); + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "LoadPRE", LI) + << "load eliminated by PRE"; + }); + ++NumPRELoad; + return true; +} + +static void reportLoadElim(LoadInst *LI, Value *AvailableValue, + OptimizationRemarkEmitter *ORE) { + using namespace ore; + + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "LoadElim", LI) + << "load of type " << NV("Type", LI->getType()) << " eliminated" + << setExtraArgs() << " in favor of " + << NV("InfavorOfValue", AvailableValue); + }); +} + +/// Attempt to eliminate a load whose dependencies are +/// non-local by performing PHI construction. +bool GVN::processNonLocalLoad(LoadInst *LI) { + // non-local speculations are not allowed under asan. + if (LI->getParent()->getParent()->hasFnAttribute( + Attribute::SanitizeAddress) || + LI->getParent()->getParent()->hasFnAttribute( + Attribute::SanitizeHWAddress)) + return false; + + // Step 1: Find the non-local dependencies of the load. + LoadDepVect Deps; + MD->getNonLocalPointerDependency(LI, Deps); + + // If we had to process more than one hundred blocks to find the + // dependencies, this load isn't worth worrying about. Optimizing + // it will be too expensive. + unsigned NumDeps = Deps.size(); + if (NumDeps > 100) + return false; + + // If we had a phi translation failure, we'll have a single entry which is a + // clobber in the current block. Reject this early. + if (NumDeps == 1 && + !Deps[0].getResult().isDef() && !Deps[0].getResult().isClobber()) { + DEBUG( + dbgs() << "GVN: non-local load "; + LI->printAsOperand(dbgs()); + dbgs() << " has unknown dependencies\n"; + ); + return false; + } + + // If this load follows a GEP, see if we can PRE the indices before analyzing. 
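+  // (Editorial note, not part of the upstream sources: value-numbering the
+  // index computations first can make the address available, or PHI
+  // translatable, in predecessors, which in turn lets the load itself be
+  // PRE'd in step 4 below.)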
+  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LI->getOperand(0))) {
+    for (GetElementPtrInst::op_iterator OI = GEP->idx_begin(),
+                                        OE = GEP->idx_end();
+         OI != OE; ++OI)
+      if (Instruction *I = dyn_cast<Instruction>(OI->get()))
+        performScalarPRE(I);
+  }
+
+  // Step 2: Analyze the availability of the load
+  AvailValInBlkVect ValuesPerBlock;
+  UnavailBlkVect UnavailableBlocks;
+  AnalyzeLoadAvailability(LI, Deps, ValuesPerBlock, UnavailableBlocks);
+
+  // If we have no predecessors that produce a known value for this load, exit
+  // early.
+  if (ValuesPerBlock.empty())
+    return false;
+
+  // Step 3: Eliminate full redundancy.
+  //
+  // If all of the instructions we depend on produce a known value for this
+  // load, then it is fully redundant and we can use PHI insertion to compute
+  // its value. Insert PHIs and remove the fully redundant value now.
+  if (UnavailableBlocks.empty()) {
+    DEBUG(dbgs() << "GVN REMOVING NONLOCAL LOAD: " << *LI << '\n');
+
+    // Perform PHI construction.
+    Value *V = ConstructSSAForLoadSet(LI, ValuesPerBlock, *this);
+    LI->replaceAllUsesWith(V);
+
+    if (isa<PHINode>(V))
+      V->takeName(LI);
+    if (Instruction *I = dyn_cast<Instruction>(V))
+      // If instruction I has debug info, then we should not update it.
+      // Also, if I has a null DebugLoc, then it is still potentially incorrect
+      // to propagate LI's DebugLoc because LI may not post-dominate I.
+      if (LI->getDebugLoc() && LI->getParent() == I->getParent())
+        I->setDebugLoc(LI->getDebugLoc());
+    if (V->getType()->isPtrOrPtrVectorTy())
+      MD->invalidateCachedPointerInfo(V);
+    markInstructionForDeletion(LI);
+    ++NumGVNLoad;
+    reportLoadElim(LI, V, ORE);
+    return true;
+  }
+
+  // Step 4: Eliminate partial redundancy.
+  if (!EnablePRE || !EnableLoadPRE)
+    return false;
+
+  return PerformLoadPRE(LI, ValuesPerBlock, UnavailableBlocks);
+}
+
+bool GVN::processAssumeIntrinsic(IntrinsicInst *IntrinsicI) {
+  assert(IntrinsicI->getIntrinsicID() == Intrinsic::assume &&
+         "This function can only be called with llvm.assume intrinsic");
+  Value *V = IntrinsicI->getArgOperand(0);
+
+  if (ConstantInt *Cond = dyn_cast<ConstantInt>(V)) {
+    if (Cond->isZero()) {
+      Type *Int8Ty = Type::getInt8Ty(V->getContext());
+      // Insert a new store-to-null instruction before the assume to indicate
+      // that this code is not reachable. FIXME: We could insert an
+      // unreachable instruction directly because we can modify the CFG.
+      new StoreInst(UndefValue::get(Int8Ty),
+                    Constant::getNullValue(Int8Ty->getPointerTo()),
+                    IntrinsicI);
+    }
+    markInstructionForDeletion(IntrinsicI);
+    return false;
+  } else if (isa<Constant>(V)) {
+    // If it's not false, and constant, it must evaluate to true. This means
+    // our assume is assume(true), and thus, pointless, and we don't want to
+    // do anything more here.
+    return false;
+  }
+
+  Constant *True = ConstantInt::getTrue(V->getContext());
+  bool Changed = false;
+
+  for (BasicBlock *Successor : successors(IntrinsicI->getParent())) {
+    BasicBlockEdge Edge(IntrinsicI->getParent(), Successor);
+
+    // This property is only true in dominated successors, propagateEquality
+    // will check dominance for us.
+ Changed |= propagateEquality(V, True, Edge, false); + } + + // We can replace assume value with true, which covers cases like this: + // call void @llvm.assume(i1 %cmp) + // br i1 %cmp, label %bb1, label %bb2 ; will change %cmp to true + ReplaceWithConstMap[V] = True; + + // If one of *cmp *eq operand is const, adding it to map will cover this: + // %cmp = fcmp oeq float 3.000000e+00, %0 ; const on lhs could happen + // call void @llvm.assume(i1 %cmp) + // ret float %0 ; will change it to ret float 3.000000e+00 + if (auto *CmpI = dyn_cast<CmpInst>(V)) { + if (CmpI->getPredicate() == CmpInst::Predicate::ICMP_EQ || + CmpI->getPredicate() == CmpInst::Predicate::FCMP_OEQ || + (CmpI->getPredicate() == CmpInst::Predicate::FCMP_UEQ && + CmpI->getFastMathFlags().noNaNs())) { + Value *CmpLHS = CmpI->getOperand(0); + Value *CmpRHS = CmpI->getOperand(1); + if (isa<Constant>(CmpLHS)) + std::swap(CmpLHS, CmpRHS); + auto *RHSConst = dyn_cast<Constant>(CmpRHS); + + // If only one operand is constant. + if (RHSConst != nullptr && !isa<Constant>(CmpLHS)) + ReplaceWithConstMap[CmpLHS] = RHSConst; + } + } + return Changed; +} + +static void patchReplacementInstruction(Instruction *I, Value *Repl) { + auto *ReplInst = dyn_cast<Instruction>(Repl); + if (!ReplInst) + return; + + // Patch the replacement so that it is not more restrictive than the value + // being replaced. + // Note that if 'I' is a load being replaced by some operation, + // for example, by an arithmetic operation, then andIRFlags() + // would just erase all math flags from the original arithmetic + // operation, which is clearly not wanted and not needed. + if (!isa<LoadInst>(I)) + ReplInst->andIRFlags(I); + + // FIXME: If both the original and replacement value are part of the + // same control-flow region (meaning that the execution of one + // guarantees the execution of the other), then we can combine the + // noalias scopes here and do better than the general conservative + // answer used in combineMetadata(). + + // In general, GVN unifies expressions over different control-flow + // regions, and so we need a conservative combination of the noalias + // scopes. + static const unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, LLVMContext::MD_range, + LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, + LLVMContext::MD_invariant_group}; + combineMetadata(ReplInst, I, KnownIDs); +} + +static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) { + patchReplacementInstruction(I, Repl); + I->replaceAllUsesWith(Repl); +} + +/// Attempt to eliminate a load, first by eliminating it +/// locally, and then attempting non-local elimination if that fails. +bool GVN::processLoad(LoadInst *L) { + if (!MD) + return false; + + // This code hasn't been audited for ordered or volatile memory access + if (!L->isUnordered()) + return false; + + if (L->use_empty()) { + markInstructionForDeletion(L); + return true; + } + + // ... to a pointer that has been loaded from before... + MemDepResult Dep = MD->getDependency(L); + + // If it is defined in another block, try harder. + if (Dep.isNonLocal()) + return processNonLocalLoad(L); + + // Only handle the local case below + if (!Dep.isDef() && !Dep.isClobber()) { + // This might be a NonFuncLocal or an Unknown + DEBUG( + // fast print dep, using operator<< on instruction is too slow. 
+ dbgs() << "GVN: load "; + L->printAsOperand(dbgs()); + dbgs() << " has unknown dependence\n"; + ); + return false; + } + + AvailableValue AV; + if (AnalyzeLoadAvailability(L, Dep, L->getPointerOperand(), AV)) { + Value *AvailableValue = AV.MaterializeAdjustedValue(L, L, *this); + + // Replace the load! + patchAndReplaceAllUsesWith(L, AvailableValue); + markInstructionForDeletion(L); + ++NumGVNLoad; + reportLoadElim(L, AvailableValue, ORE); + // Tell MDA to rexamine the reused pointer since we might have more + // information after forwarding it. + if (MD && AvailableValue->getType()->isPtrOrPtrVectorTy()) + MD->invalidateCachedPointerInfo(AvailableValue); + return true; + } + + return false; +} + +/// Return a pair the first field showing the value number of \p Exp and the +/// second field showing whether it is a value number newly created. +std::pair<uint32_t, bool> +GVN::ValueTable::assignExpNewValueNum(Expression &Exp) { + uint32_t &e = expressionNumbering[Exp]; + bool CreateNewValNum = !e; + if (CreateNewValNum) { + Expressions.push_back(Exp); + if (ExprIdx.size() < nextValueNumber + 1) + ExprIdx.resize(nextValueNumber * 2); + e = nextValueNumber; + ExprIdx[nextValueNumber++] = nextExprNumber++; + } + return {e, CreateNewValNum}; +} + +/// Return whether all the values related with the same \p num are +/// defined in \p BB. +bool GVN::ValueTable::areAllValsInBB(uint32_t Num, const BasicBlock *BB, + GVN &Gvn) { + LeaderTableEntry *Vals = &Gvn.LeaderTable[Num]; + while (Vals && Vals->BB == BB) + Vals = Vals->Next; + return !Vals; +} + +/// Wrap phiTranslateImpl to provide caching functionality. +uint32_t GVN::ValueTable::phiTranslate(const BasicBlock *Pred, + const BasicBlock *PhiBlock, uint32_t Num, + GVN &Gvn) { + auto FindRes = PhiTranslateTable.find({Num, Pred}); + if (FindRes != PhiTranslateTable.end()) + return FindRes->second; + uint32_t NewNum = phiTranslateImpl(Pred, PhiBlock, Num, Gvn); + PhiTranslateTable.insert({{Num, Pred}, NewNum}); + return NewNum; +} + +/// Translate value number \p Num using phis, so that it has the values of +/// the phis in BB. +uint32_t GVN::ValueTable::phiTranslateImpl(const BasicBlock *Pred, + const BasicBlock *PhiBlock, + uint32_t Num, GVN &Gvn) { + if (PHINode *PN = NumberingPhi[Num]) { + for (unsigned i = 0; i != PN->getNumIncomingValues(); ++i) { + if (PN->getParent() == PhiBlock && PN->getIncomingBlock(i) == Pred) + if (uint32_t TransVal = lookup(PN->getIncomingValue(i), false)) + return TransVal; + } + return Num; + } + + // If there is any value related with Num is defined in a BB other than + // PhiBlock, it cannot depend on a phi in PhiBlock without going through + // a backedge. We can do an early exit in that case to save compile time. + if (!areAllValsInBB(Num, PhiBlock, Gvn)) + return Num; + + if (Num >= ExprIdx.size() || ExprIdx[Num] == 0) + return Num; + Expression Exp = Expressions[ExprIdx[Num]]; + + for (unsigned i = 0; i < Exp.varargs.size(); i++) { + // For InsertValue and ExtractValue, some varargs are index numbers + // instead of value numbers. Those index numbers should not be + // translated. 
+ if ((i > 1 && Exp.opcode == Instruction::InsertValue) || + (i > 0 && Exp.opcode == Instruction::ExtractValue)) + continue; + Exp.varargs[i] = phiTranslate(Pred, PhiBlock, Exp.varargs[i], Gvn); + } + + if (Exp.commutative) { + assert(Exp.varargs.size() == 2 && "Unsupported commutative expression!"); + if (Exp.varargs[0] > Exp.varargs[1]) { + std::swap(Exp.varargs[0], Exp.varargs[1]); + uint32_t Opcode = Exp.opcode >> 8; + if (Opcode == Instruction::ICmp || Opcode == Instruction::FCmp) + Exp.opcode = (Opcode << 8) | + CmpInst::getSwappedPredicate( + static_cast<CmpInst::Predicate>(Exp.opcode & 255)); + } + } + + if (uint32_t NewNum = expressionNumbering[Exp]) + return NewNum; + return Num; +} + +/// Erase stale entry from phiTranslate cache so phiTranslate can be computed +/// again. +void GVN::ValueTable::eraseTranslateCacheEntry(uint32_t Num, + const BasicBlock &CurrBlock) { + for (const BasicBlock *Pred : predecessors(&CurrBlock)) { + auto FindRes = PhiTranslateTable.find({Num, Pred}); + if (FindRes != PhiTranslateTable.end()) + PhiTranslateTable.erase(FindRes); + } +} + +// In order to find a leader for a given value number at a +// specific basic block, we first obtain the list of all Values for that number, +// and then scan the list to find one whose block dominates the block in +// question. This is fast because dominator tree queries consist of only +// a few comparisons of DFS numbers. +Value *GVN::findLeader(const BasicBlock *BB, uint32_t num) { + LeaderTableEntry Vals = LeaderTable[num]; + if (!Vals.Val) return nullptr; + + Value *Val = nullptr; + if (DT->dominates(Vals.BB, BB)) { + Val = Vals.Val; + if (isa<Constant>(Val)) return Val; + } + + LeaderTableEntry* Next = Vals.Next; + while (Next) { + if (DT->dominates(Next->BB, BB)) { + if (isa<Constant>(Next->Val)) return Next->Val; + if (!Val) Val = Next->Val; + } + + Next = Next->Next; + } + + return Val; +} + +/// There is an edge from 'Src' to 'Dst'. Return +/// true if every path from the entry block to 'Dst' passes via this edge. In +/// particular 'Dst' must not be reachable via another edge from 'Src'. +static bool isOnlyReachableViaThisEdge(const BasicBlockEdge &E, + DominatorTree *DT) { + // While in theory it is interesting to consider the case in which Dst has + // more than one predecessor, because Dst might be part of a loop which is + // only reachable from Src, in practice it is pointless since at the time + // GVN runs all such loops have preheaders, which means that Dst will have + // been changed to have only one predecessor, namely Src. + const BasicBlock *Pred = E.getEnd()->getSinglePredecessor(); + assert((!Pred || Pred == E.getStart()) && + "No edge between these basic blocks!"); + return Pred != nullptr; +} + +void GVN::assignBlockRPONumber(Function &F) { + uint32_t NextBlockNumber = 1; + ReversePostOrderTraversal<Function *> RPOT(&F); + for (BasicBlock *BB : RPOT) + BlockRPONumber[BB] = NextBlockNumber++; +} + +// Tries to replace instruction with const, using information from +// ReplaceWithConstMap. 
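+// For example (illustrative IR): after processAssumeIntrinsic has seen
+//   %cmp = icmp eq i32 %x, 42
+//   call void @llvm.assume(i1 %cmp)
+// ReplaceWithConstMap maps %x to 42, and the remaining instructions in the
+// block get their uses of %x rewritten to that constant.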
+bool GVN::replaceOperandsWithConsts(Instruction *Instr) const { + bool Changed = false; + for (unsigned OpNum = 0; OpNum < Instr->getNumOperands(); ++OpNum) { + Value *Operand = Instr->getOperand(OpNum); + auto it = ReplaceWithConstMap.find(Operand); + if (it != ReplaceWithConstMap.end()) { + assert(!isa<Constant>(Operand) && + "Replacing constants with constants is invalid"); + DEBUG(dbgs() << "GVN replacing: " << *Operand << " with " << *it->second + << " in instruction " << *Instr << '\n'); + Instr->setOperand(OpNum, it->second); + Changed = true; + } + } + return Changed; +} + +/// The given values are known to be equal in every block +/// dominated by 'Root'. Exploit this, for example by replacing 'LHS' with +/// 'RHS' everywhere in the scope. Returns whether a change was made. +/// If DominatesByEdge is false, then it means that we will propagate the RHS +/// value starting from the end of Root.Start. +bool GVN::propagateEquality(Value *LHS, Value *RHS, const BasicBlockEdge &Root, + bool DominatesByEdge) { + SmallVector<std::pair<Value*, Value*>, 4> Worklist; + Worklist.push_back(std::make_pair(LHS, RHS)); + bool Changed = false; + // For speed, compute a conservative fast approximation to + // DT->dominates(Root, Root.getEnd()); + const bool RootDominatesEnd = isOnlyReachableViaThisEdge(Root, DT); + + while (!Worklist.empty()) { + std::pair<Value*, Value*> Item = Worklist.pop_back_val(); + LHS = Item.first; RHS = Item.second; + + if (LHS == RHS) + continue; + assert(LHS->getType() == RHS->getType() && "Equality but unequal types!"); + + // Don't try to propagate equalities between constants. + if (isa<Constant>(LHS) && isa<Constant>(RHS)) + continue; + + // Prefer a constant on the right-hand side, or an Argument if no constants. + if (isa<Constant>(LHS) || (isa<Argument>(LHS) && !isa<Constant>(RHS))) + std::swap(LHS, RHS); + assert((isa<Argument>(LHS) || isa<Instruction>(LHS)) && "Unexpected value!"); + + // If there is no obvious reason to prefer the left-hand side over the + // right-hand side, ensure the longest lived term is on the right-hand side, + // so the shortest lived term will be replaced by the longest lived. + // This tends to expose more simplifications. + uint32_t LVN = VN.lookupOrAdd(LHS); + if ((isa<Argument>(LHS) && isa<Argument>(RHS)) || + (isa<Instruction>(LHS) && isa<Instruction>(RHS))) { + // Move the 'oldest' value to the right-hand side, using the value number + // as a proxy for age. + uint32_t RVN = VN.lookupOrAdd(RHS); + if (LVN < RVN) { + std::swap(LHS, RHS); + LVN = RVN; + } + } + + // If value numbering later sees that an instruction in the scope is equal + // to 'LHS' then ensure it will be turned into 'RHS'. In order to preserve + // the invariant that instructions only occur in the leader table for their + // own value number (this is used by removeFromLeaderTable), do not do this + // if RHS is an instruction (if an instruction in the scope is morphed into + // LHS then it will be turned into RHS by the next GVN iteration anyway, so + // using the leader table is about compiling faster, not optimizing better). + // The leader table only tracks basic blocks, not edges. Only add to if we + // have the simple case where the edge dominates the end. + if (RootDominatesEnd && !isa<Instruction>(RHS)) + addToLeaderTable(LVN, RHS, Root.getEnd()); + + // Replace all occurrences of 'LHS' with 'RHS' everywhere in the scope. As + // LHS always has at least one use that is not dominated by Root, this will + // never do anything if LHS has only one use. 
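+    // For example (illustrative IR): after propagating %x == 7 along the
+    // true edge of "br i1 (icmp eq i32 %x, 7), label %then, label %else",
+    // every use of %x dominated by that edge is folded to the constant 7.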
+ if (!LHS->hasOneUse()) { + unsigned NumReplacements = + DominatesByEdge + ? replaceDominatedUsesWith(LHS, RHS, *DT, Root) + : replaceDominatedUsesWith(LHS, RHS, *DT, Root.getStart()); + + Changed |= NumReplacements > 0; + NumGVNEqProp += NumReplacements; + } + + // Now try to deduce additional equalities from this one. For example, if + // the known equality was "(A != B)" == "false" then it follows that A and B + // are equal in the scope. Only boolean equalities with an explicit true or + // false RHS are currently supported. + if (!RHS->getType()->isIntegerTy(1)) + // Not a boolean equality - bail out. + continue; + ConstantInt *CI = dyn_cast<ConstantInt>(RHS); + if (!CI) + // RHS neither 'true' nor 'false' - bail out. + continue; + // Whether RHS equals 'true'. Otherwise it equals 'false'. + bool isKnownTrue = CI->isMinusOne(); + bool isKnownFalse = !isKnownTrue; + + // If "A && B" is known true then both A and B are known true. If "A || B" + // is known false then both A and B are known false. + Value *A, *B; + if ((isKnownTrue && match(LHS, m_And(m_Value(A), m_Value(B)))) || + (isKnownFalse && match(LHS, m_Or(m_Value(A), m_Value(B))))) { + Worklist.push_back(std::make_pair(A, RHS)); + Worklist.push_back(std::make_pair(B, RHS)); + continue; + } + + // If we are propagating an equality like "(A == B)" == "true" then also + // propagate the equality A == B. When propagating a comparison such as + // "(A >= B)" == "true", replace all instances of "A < B" with "false". + if (CmpInst *Cmp = dyn_cast<CmpInst>(LHS)) { + Value *Op0 = Cmp->getOperand(0), *Op1 = Cmp->getOperand(1); + + // If "A == B" is known true, or "A != B" is known false, then replace + // A with B everywhere in the scope. + if ((isKnownTrue && Cmp->getPredicate() == CmpInst::ICMP_EQ) || + (isKnownFalse && Cmp->getPredicate() == CmpInst::ICMP_NE)) + Worklist.push_back(std::make_pair(Op0, Op1)); + + // Handle the floating point versions of equality comparisons too. + if ((isKnownTrue && Cmp->getPredicate() == CmpInst::FCMP_OEQ) || + (isKnownFalse && Cmp->getPredicate() == CmpInst::FCMP_UNE)) { + + // Floating point -0.0 and 0.0 compare equal, so we can only + // propagate values if we know that we have a constant and that + // its value is non-zero. + + // FIXME: We should do this optimization if 'no signed zeros' is + // applicable via an instruction-level fast-math-flag or some other + // indicator that relaxed FP semantics are being used. + + if (isa<ConstantFP>(Op1) && !cast<ConstantFP>(Op1)->isZero()) + Worklist.push_back(std::make_pair(Op0, Op1)); + } + + // If "A >= B" is known true, replace "A < B" with false everywhere. + CmpInst::Predicate NotPred = Cmp->getInversePredicate(); + Constant *NotVal = ConstantInt::get(Cmp->getType(), isKnownFalse); + // Since we don't have the instruction "A < B" immediately to hand, work + // out the value number that it would have and use that to find an + // appropriate instruction (if any). + uint32_t NextNum = VN.getNextUnusedValueNumber(); + uint32_t Num = VN.lookupOrAddCmp(Cmp->getOpcode(), NotPred, Op0, Op1); + // If the number we were assigned was brand new then there is no point in + // looking for an instruction realizing it: there cannot be one! + if (Num < NextNum) { + Value *NotCmp = findLeader(Root.getEnd(), Num); + if (NotCmp && isa<Instruction>(NotCmp)) { + unsigned NumReplacements = + DominatesByEdge + ? 
replaceDominatedUsesWith(NotCmp, NotVal, *DT, Root) + : replaceDominatedUsesWith(NotCmp, NotVal, *DT, + Root.getStart()); + Changed |= NumReplacements > 0; + NumGVNEqProp += NumReplacements; + } + } + // Ensure that any instruction in scope that gets the "A < B" value number + // is replaced with false. + // The leader table only tracks basic blocks, not edges. Only add to if we + // have the simple case where the edge dominates the end. + if (RootDominatesEnd) + addToLeaderTable(Num, NotVal, Root.getEnd()); + + continue; + } + } + + return Changed; +} + +/// When calculating availability, handle an instruction +/// by inserting it into the appropriate sets +bool GVN::processInstruction(Instruction *I) { + // Ignore dbg info intrinsics. + if (isa<DbgInfoIntrinsic>(I)) + return false; + + // If the instruction can be easily simplified then do so now in preference + // to value numbering it. Value numbering often exposes redundancies, for + // example if it determines that %y is equal to %x then the instruction + // "%z = and i32 %x, %y" becomes "%z = and i32 %x, %x" which we now simplify. + const DataLayout &DL = I->getModule()->getDataLayout(); + if (Value *V = SimplifyInstruction(I, {DL, TLI, DT, AC})) { + bool Changed = false; + if (!I->use_empty()) { + I->replaceAllUsesWith(V); + Changed = true; + } + if (isInstructionTriviallyDead(I, TLI)) { + markInstructionForDeletion(I); + Changed = true; + } + if (Changed) { + if (MD && V->getType()->isPtrOrPtrVectorTy()) + MD->invalidateCachedPointerInfo(V); + ++NumGVNSimpl; + return true; + } + } + + if (IntrinsicInst *IntrinsicI = dyn_cast<IntrinsicInst>(I)) + if (IntrinsicI->getIntrinsicID() == Intrinsic::assume) + return processAssumeIntrinsic(IntrinsicI); + + if (LoadInst *LI = dyn_cast<LoadInst>(I)) { + if (processLoad(LI)) + return true; + + unsigned Num = VN.lookupOrAdd(LI); + addToLeaderTable(Num, LI, LI->getParent()); + return false; + } + + // For conditional branches, we can perform simple conditional propagation on + // the condition value itself. + if (BranchInst *BI = dyn_cast<BranchInst>(I)) { + if (!BI->isConditional()) + return false; + + if (isa<Constant>(BI->getCondition())) + return processFoldableCondBr(BI); + + Value *BranchCond = BI->getCondition(); + BasicBlock *TrueSucc = BI->getSuccessor(0); + BasicBlock *FalseSucc = BI->getSuccessor(1); + // Avoid multiple edges early. + if (TrueSucc == FalseSucc) + return false; + + BasicBlock *Parent = BI->getParent(); + bool Changed = false; + + Value *TrueVal = ConstantInt::getTrue(TrueSucc->getContext()); + BasicBlockEdge TrueE(Parent, TrueSucc); + Changed |= propagateEquality(BranchCond, TrueVal, TrueE, true); + + Value *FalseVal = ConstantInt::getFalse(FalseSucc->getContext()); + BasicBlockEdge FalseE(Parent, FalseSucc); + Changed |= propagateEquality(BranchCond, FalseVal, FalseE, true); + + return Changed; + } + + // For switches, propagate the case values into the case destinations. + if (SwitchInst *SI = dyn_cast<SwitchInst>(I)) { + Value *SwitchCond = SI->getCondition(); + BasicBlock *Parent = SI->getParent(); + bool Changed = false; + + // Remember how many outgoing edges there are to every successor. + SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges; + for (unsigned i = 0, n = SI->getNumSuccessors(); i != n; ++i) + ++SwitchEdges[SI->getSuccessor(i)]; + + for (SwitchInst::CaseIt i = SI->case_begin(), e = SI->case_end(); + i != e; ++i) { + BasicBlock *Dst = i->getCaseSuccessor(); + // If there is only a single edge, propagate the case value into it. 
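+      // For example (illustrative IR): in
+      //   switch i32 %x, label %def [ i32 1, label %bb1 ]
+      // %bb1 is only reached when %x == 1, so uses of %x dominated by that
+      // edge can be replaced with the constant 1.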
+ if (SwitchEdges.lookup(Dst) == 1) { + BasicBlockEdge E(Parent, Dst); + Changed |= propagateEquality(SwitchCond, i->getCaseValue(), E, true); + } + } + return Changed; + } + + // Instructions with void type don't return a value, so there's + // no point in trying to find redundancies in them. + if (I->getType()->isVoidTy()) + return false; + + uint32_t NextNum = VN.getNextUnusedValueNumber(); + unsigned Num = VN.lookupOrAdd(I); + + // Allocations are always uniquely numbered, so we can save time and memory + // by fast failing them. + if (isa<AllocaInst>(I) || isa<TerminatorInst>(I) || isa<PHINode>(I)) { + addToLeaderTable(Num, I, I->getParent()); + return false; + } + + // If the number we were assigned was a brand new VN, then we don't + // need to do a lookup to see if the number already exists + // somewhere in the domtree: it can't! + if (Num >= NextNum) { + addToLeaderTable(Num, I, I->getParent()); + return false; + } + + // Perform fast-path value-number based elimination of values inherited from + // dominators. + Value *Repl = findLeader(I->getParent(), Num); + if (!Repl) { + // Failure, just remember this instance for future use. + addToLeaderTable(Num, I, I->getParent()); + return false; + } else if (Repl == I) { + // If I was the result of a shortcut PRE, it might already be in the table + // and the best replacement for itself. Nothing to do. + return false; + } + + // Remove it! + patchAndReplaceAllUsesWith(I, Repl); + if (MD && Repl->getType()->isPtrOrPtrVectorTy()) + MD->invalidateCachedPointerInfo(Repl); + markInstructionForDeletion(I); + return true; +} + +/// runOnFunction - This is the main transformation entry point for a function. +bool GVN::runImpl(Function &F, AssumptionCache &RunAC, DominatorTree &RunDT, + const TargetLibraryInfo &RunTLI, AAResults &RunAA, + MemoryDependenceResults *RunMD, LoopInfo *LI, + OptimizationRemarkEmitter *RunORE) { + AC = &RunAC; + DT = &RunDT; + VN.setDomTree(DT); + TLI = &RunTLI; + VN.setAliasAnalysis(&RunAA); + MD = RunMD; + OrderedInstructions OrderedInstrs(DT); + OI = &OrderedInstrs; + VN.setMemDep(MD); + ORE = RunORE; + + bool Changed = false; + bool ShouldContinue = true; + + // Merge unconditional branches, allowing PRE to catch more + // optimization opportunities. + for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE; ) { + BasicBlock *BB = &*FI++; + + bool removedBlock = MergeBlockIntoPredecessor(BB, DT, LI, MD); + if (removedBlock) + ++NumGVNBlocks; + + Changed |= removedBlock; + } + + unsigned Iteration = 0; + while (ShouldContinue) { + DEBUG(dbgs() << "GVN iteration: " << Iteration << "\n"); + ShouldContinue = iterateOnFunction(F); + Changed |= ShouldContinue; + ++Iteration; + } + + if (EnablePRE) { + // Fabricate val-num for dead-code in order to suppress assertion in + // performPRE(). + assignValNumForDeadCode(); + assignBlockRPONumber(F); + bool PREChanged = true; + while (PREChanged) { + PREChanged = performPRE(F); + Changed |= PREChanged; + } + } + + // FIXME: Should perform GVN again after PRE does something. PRE can move + // computations into blocks where they become fully redundant. Note that + // we can't do this until PRE's critical edge splitting updates memdep. + // Actually, when this happens, we should just fully integrate PRE into GVN. + + cleanupGlobalSets(); + // Do not cleanup DeadBlocks in cleanupGlobalSets() as it's called for each + // iteration. 
+  DeadBlocks.clear();
+
+  return Changed;
+}
+
+bool GVN::processBlock(BasicBlock *BB) {
+  // FIXME: Kill off InstrsToErase by erasing eagerly in a helper function
+  // (and incrementing BI before processing an instruction).
+  assert(InstrsToErase.empty() &&
+         "We expect InstrsToErase to be empty across iterations");
+  if (DeadBlocks.count(BB))
+    return false;
+
+  // Clear the map before every BB because it is only valid for a single BB.
+  ReplaceWithConstMap.clear();
+  bool ChangedFunction = false;
+
+  for (BasicBlock::iterator BI = BB->begin(), BE = BB->end();
+       BI != BE;) {
+    if (!ReplaceWithConstMap.empty())
+      ChangedFunction |= replaceOperandsWithConsts(&*BI);
+    ChangedFunction |= processInstruction(&*BI);
+
+    if (InstrsToErase.empty()) {
+      ++BI;
+      continue;
+    }
+
+    // If we need some instructions deleted, do it now.
+    NumGVNInstr += InstrsToErase.size();
+
+    // Avoid iterator invalidation.
+    bool AtStart = BI == BB->begin();
+    if (!AtStart)
+      --BI;
+
+    bool InvalidateImplicitCF = false;
+    const Instruction *MaybeFirstICF = FirstImplicitControlFlowInsts.lookup(BB);
+    for (auto *I : InstrsToErase) {
+      assert(I->getParent() == BB && "Removing instruction from wrong block?");
+      DEBUG(dbgs() << "GVN removed: " << *I << '\n');
+      if (MD) MD->removeInstruction(I);
+      DEBUG(verifyRemoved(I));
+      if (MaybeFirstICF == I) {
+        // We have erased the first ICF in the block. The map needs to be
+        // updated.
+        InvalidateImplicitCF = true;
+        // Do not keep a dangling pointer to the erased instruction.
+        MaybeFirstICF = nullptr;
+      }
+      I->eraseFromParent();
+    }
+
+    OI->invalidateBlock(BB);
+    InstrsToErase.clear();
+    if (InvalidateImplicitCF)
+      fillImplicitControlFlowInfo(BB);
+
+    if (AtStart)
+      BI = BB->begin();
+    else
+      ++BI;
+  }
+
+  return ChangedFunction;
+}
+
+// Instantiate an expression in a predecessor that lacked it.
+bool GVN::performScalarPREInsertion(Instruction *Instr, BasicBlock *Pred,
+                                    BasicBlock *Curr, unsigned int ValNo) {
+  // Because we are going top-down through the block, all value numbers
+  // will be available in the predecessor by the time we need them. Any
+  // that weren't originally present will have been instantiated earlier
+  // in this loop.
+  bool success = true;
+  for (unsigned i = 0, e = Instr->getNumOperands(); i != e; ++i) {
+    Value *Op = Instr->getOperand(i);
+    if (isa<Argument>(Op) || isa<Constant>(Op) || isa<GlobalValue>(Op))
+      continue;
+    // This could be a newly inserted instruction, in which case, we won't
+    // find a value number, and should give up before we hurt ourselves.
+    // FIXME: Rewrite the infrastructure to make it easier to value number
+    // and process newly inserted instructions.
+    if (!VN.exists(Op)) {
+      success = false;
+      break;
+    }
+    uint32_t TValNo =
+        VN.phiTranslate(Pred, Curr, VN.lookup(Op), *this);
+    if (Value *V = findLeader(Pred, TValNo)) {
+      Instr->setOperand(i, V);
+    } else {
+      success = false;
+      break;
+    }
+  }
+
+  // Fail out if we encounter an operand that is not available in
+  // the PRE predecessor. This is typically because of loads which
+  // are not value numbered precisely.
+  if (!success)
+    return false;
+
+  Instr->insertBefore(Pred->getTerminator());
+  Instr->setName(Instr->getName() + ".pre");
+  Instr->setDebugLoc(Instr->getDebugLoc());
+
+  unsigned Num = VN.lookupOrAdd(Instr);
+  VN.add(Instr, Num);
+
+  // Update the availability map to include the new instruction.
+  addToLeaderTable(Num, Instr, Pred);
+  return true;
+}
+
+bool GVN::performScalarPRE(Instruction *CurInst) {
+  if (isa<AllocaInst>(CurInst) || isa<TerminatorInst>(CurInst) ||
+      isa<PHINode>(CurInst) || CurInst->getType()->isVoidTy() ||
+      CurInst->mayReadFromMemory() || CurInst->mayHaveSideEffects() ||
+      isa<DbgInfoIntrinsic>(CurInst))
+    return false;
+
+  // Don't do PRE on compares. The PHI would prevent CodeGenPrepare from
+  // sinking the compare again, and it would force the code generator to
+  // move the i1 from processor flags or predicate registers into a general
+  // purpose register.
+  if (isa<CmpInst>(CurInst))
+    return false;
+
+  // We don't currently value number ANY inline asm calls.
+  if (CallInst *CallI = dyn_cast<CallInst>(CurInst))
+    if (CallI->isInlineAsm())
+      return false;
+
+  uint32_t ValNo = VN.lookup(CurInst);
+
+  // Look for the predecessors for PRE opportunities. We're
+  // only trying to solve the basic diamond case, where
+  // a value is computed in the successor and one predecessor,
+  // but not the other. We also explicitly disallow cases
+  // where the successor is its own predecessor, because they're
+  // more complicated to get right.
+  unsigned NumWith = 0;
+  unsigned NumWithout = 0;
+  BasicBlock *PREPred = nullptr;
+  BasicBlock *CurrentBlock = CurInst->getParent();
+
+  SmallVector<std::pair<Value *, BasicBlock *>, 8> predMap;
+  for (BasicBlock *P : predecessors(CurrentBlock)) {
+    // We're not interested in PRE when one of the block's predecessors is
+    // not reachable from the entry.
+    if (!DT->isReachableFromEntry(P)) {
+      NumWithout = 2;
+      break;
+    }
+    // It is not safe to do PRE when P->CurrentBlock is a loop backedge and
+    // CurInst has an operand defined in CurrentBlock (so it may be defined
+    // by a phi in the loop header).
+    if (BlockRPONumber[P] >= BlockRPONumber[CurrentBlock] &&
+        llvm::any_of(CurInst->operands(), [&](const Use &U) {
+          if (auto *Inst = dyn_cast<Instruction>(U.get()))
+            return Inst->getParent() == CurrentBlock;
+          return false;
+        })) {
+      NumWithout = 2;
+      break;
+    }
+
+    uint32_t TValNo = VN.phiTranslate(P, CurrentBlock, ValNo, *this);
+    Value *predV = findLeader(P, TValNo);
+    if (!predV) {
+      predMap.push_back(std::make_pair(static_cast<Value *>(nullptr), P));
+      PREPred = P;
+      ++NumWithout;
+    } else if (predV == CurInst) {
+      /* CurInst dominates this predecessor. */
+      NumWithout = 2;
+      break;
+    } else {
+      predMap.push_back(std::make_pair(predV, P));
+      ++NumWith;
+    }
+  }
+
+  // Don't do PRE when it might increase code size, i.e. when
+  // we would need to insert instructions in more than one pred.
+  if (NumWithout > 1 || NumWith == 0)
+    return false;
+
+  // We may have a case where all predecessors have the instruction,
+  // and we just need to insert a phi node. Otherwise, perform
+  // insertion.
+  Instruction *PREInstr = nullptr;
+
+  if (NumWithout != 0) {
+    if (!isSafeToSpeculativelyExecute(CurInst)) {
+      // It is only valid to insert a new instruction if the current instruction
+      // is always executed. An instruction with implicit control flow could
+      // prevent us from doing it. If we cannot speculate the execution, then
+      // PRE should be prohibited.
+      auto It = FirstImplicitControlFlowInsts.find(CurrentBlock);
+      if (It != FirstImplicitControlFlowInsts.end()) {
+        assert(It->second->getParent() == CurrentBlock &&
+               "Implicit control flow map broken?");
+        if (OI->dominates(It->second, CurInst))
+          return false;
+      }
+    }
+
+    // Don't do PRE across an indirect branch.
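+    // The outgoing edges of an indirectbr cannot be split, so there would be
+    // no safe place to materialize the expression in that predecessor.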
+ if (isa<IndirectBrInst>(PREPred->getTerminator())) + return false; + + // We can't do PRE safely on a critical edge, so instead we schedule + // the edge to be split and perform the PRE the next time we iterate + // on the function. + unsigned SuccNum = GetSuccessorNumber(PREPred, CurrentBlock); + if (isCriticalEdge(PREPred->getTerminator(), SuccNum)) { + toSplit.push_back(std::make_pair(PREPred->getTerminator(), SuccNum)); + return false; + } + // We need to insert somewhere, so let's give it a shot + PREInstr = CurInst->clone(); + if (!performScalarPREInsertion(PREInstr, PREPred, CurrentBlock, ValNo)) { + // If we failed insertion, make sure we remove the instruction. + DEBUG(verifyRemoved(PREInstr)); + PREInstr->deleteValue(); + return false; + } + } + + // Either we should have filled in the PRE instruction, or we should + // not have needed insertions. + assert(PREInstr != nullptr || NumWithout == 0); + + ++NumGVNPRE; + + // Create a PHI to make the value available in this block. + PHINode *Phi = + PHINode::Create(CurInst->getType(), predMap.size(), + CurInst->getName() + ".pre-phi", &CurrentBlock->front()); + for (unsigned i = 0, e = predMap.size(); i != e; ++i) { + if (Value *V = predMap[i].first) { + // If we use an existing value in this phi, we have to patch the original + // value because the phi will be used to replace a later value. + patchReplacementInstruction(CurInst, V); + Phi->addIncoming(V, predMap[i].second); + } else + Phi->addIncoming(PREInstr, PREPred); + } + + VN.add(Phi, ValNo); + // After creating a new PHI for ValNo, the phi translate result for ValNo will + // be changed, so erase the related stale entries in phi translate cache. + VN.eraseTranslateCacheEntry(ValNo, *CurrentBlock); + addToLeaderTable(ValNo, Phi, CurrentBlock); + Phi->setDebugLoc(CurInst->getDebugLoc()); + CurInst->replaceAllUsesWith(Phi); + if (MD && Phi->getType()->isPtrOrPtrVectorTy()) + MD->invalidateCachedPointerInfo(Phi); + VN.erase(CurInst); + removeFromLeaderTable(ValNo, CurInst, CurrentBlock); + + DEBUG(dbgs() << "GVN PRE removed: " << *CurInst << '\n'); + if (MD) + MD->removeInstruction(CurInst); + DEBUG(verifyRemoved(CurInst)); + bool InvalidateImplicitCF = + FirstImplicitControlFlowInsts.lookup(CurInst->getParent()) == CurInst; + // FIXME: Intended to be markInstructionForDeletion(CurInst), but it causes + // some assertion failures. + OI->invalidateBlock(CurrentBlock); + CurInst->eraseFromParent(); + if (InvalidateImplicitCF) + fillImplicitControlFlowInfo(CurrentBlock); + ++NumGVNInstr; + + return true; +} + +/// Perform a purely local form of PRE that looks for diamond +/// control flow patterns and attempts to perform simple PRE at the join point. +bool GVN::performPRE(Function &F) { + bool Changed = false; + for (BasicBlock *CurrentBlock : depth_first(&F.getEntryBlock())) { + // Nothing to PRE in the entry block. + if (CurrentBlock == &F.getEntryBlock()) + continue; + + // Don't perform PRE on an EH pad. + if (CurrentBlock->isEHPad()) + continue; + + for (BasicBlock::iterator BI = CurrentBlock->begin(), + BE = CurrentBlock->end(); + BI != BE;) { + Instruction *CurInst = &*BI++; + Changed |= performScalarPRE(CurInst); + } + } + + if (splitCriticalEdges()) + Changed = true; + + return Changed; +} + +/// Split the critical edge connecting the given two blocks, and return +/// the block inserted to the critical edge. 
+BasicBlock *GVN::splitCriticalEdges(BasicBlock *Pred, BasicBlock *Succ) {
+  BasicBlock *BB =
+      SplitCriticalEdge(Pred, Succ, CriticalEdgeSplittingOptions(DT));
+  if (MD)
+    MD->invalidateCachedPredecessors();
+  return BB;
+}
+
+/// Split critical edges found during the previous
+/// iteration that may enable further optimization.
+bool GVN::splitCriticalEdges() {
+  if (toSplit.empty())
+    return false;
+  do {
+    std::pair<TerminatorInst*, unsigned> Edge = toSplit.pop_back_val();
+    SplitCriticalEdge(Edge.first, Edge.second,
+                      CriticalEdgeSplittingOptions(DT));
+  } while (!toSplit.empty());
+  if (MD) MD->invalidateCachedPredecessors();
+  return true;
+}
+
+/// Executes one iteration of GVN
+bool GVN::iterateOnFunction(Function &F) {
+  cleanupGlobalSets();
+
+  // Top-down walk of the dominator tree
+  bool Changed = false;
+  // Needed for value numbering with phi construction to work.
+  // RPOT walks the graph in its constructor and will not be invalidated during
+  // processBlock.
+  ReversePostOrderTraversal<Function *> RPOT(&F);
+
+  for (BasicBlock *BB : RPOT)
+    fillImplicitControlFlowInfo(BB);
+  for (BasicBlock *BB : RPOT)
+    Changed |= processBlock(BB);
+
+  return Changed;
+}
+
+void GVN::cleanupGlobalSets() {
+  VN.clear();
+  LeaderTable.clear();
+  BlockRPONumber.clear();
+  TableAllocator.Reset();
+  FirstImplicitControlFlowInsts.clear();
+}
+
+void
+GVN::fillImplicitControlFlowInfo(BasicBlock *BB) {
+  // Make sure that all marked instructions are actually deleted by this point,
+  // so that we don't need to care about omitting them.
+  assert(InstrsToErase.empty() && "Filling before removed all marked insns?");
+  auto MayNotTransferExecutionToSuccessor = [&](const Instruction *I) {
+    // If a block's instruction doesn't always pass the control to its successor
+    // instruction, mark the block as having implicit control flow. We use this
+    // to avoid wrong assumptions of the sort "if A is executed and B
+    // post-dominates A, then B is also executed". This is not true if there is
+    // an implicit control flow instruction (e.g. a guard) between them.
+    //
+    // TODO: Currently, isGuaranteedToTransferExecutionToSuccessor returns false
+    // for volatile stores and loads because they can trap. The discussion on
+    // whether or not it is correct is still ongoing. We might want to get rid
+    // of this logic in the future. Anyway, trapping instructions shouldn't
+    // introduce implicit control flow, so we explicitly allow them here. This
+    // must be removed once isGuaranteedToTransferExecutionToSuccessor is fixed.
+    if (isGuaranteedToTransferExecutionToSuccessor(I))
+      return false;
+    if (isa<LoadInst>(I)) {
+      assert(cast<LoadInst>(I)->isVolatile() &&
+             "Non-volatile load should transfer execution to successor!");
+      return false;
+    }
+    if (isa<StoreInst>(I)) {
+      assert(cast<StoreInst>(I)->isVolatile() &&
+             "Non-volatile store should transfer execution to successor!");
+      return false;
+    }
+    return true;
+  };
+  FirstImplicitControlFlowInsts.erase(BB);
+
+  for (auto &I : *BB)
+    if (MayNotTransferExecutionToSuccessor(&I)) {
+      FirstImplicitControlFlowInsts[BB] = &I;
+      break;
+    }
+}
+
+/// Verify that the specified instruction does not occur in our
+/// internal data structures.
+void GVN::verifyRemoved(const Instruction *Inst) const {
+  VN.verifyRemoved(Inst);
+
+  // Walk through the value number scope to make sure the instruction isn't
+  // ferreted away in it.
+  for (DenseMap<uint32_t, LeaderTableEntry>::const_iterator
+       I = LeaderTable.begin(), E = LeaderTable.end(); I != E; ++I) {
+    const LeaderTableEntry *Node = &I->second;
+    assert(Node->Val != Inst && "Inst still in value numbering scope!");
+
+    while (Node->Next) {
+      Node = Node->Next;
+      assert(Node->Val != Inst && "Inst still in value numbering scope!");
+    }
+  }
+}
+
+/// BB is declared dead, which implies that other blocks become dead as well.
+/// This function adds all such blocks to "DeadBlocks". For the dead blocks'
+/// live successors, update their phi nodes by replacing the operands
+/// corresponding to dead blocks with UndefVal.
+void GVN::addDeadBlock(BasicBlock *BB) {
+  SmallVector<BasicBlock *, 4> NewDead;
+  SmallSetVector<BasicBlock *, 4> DF;
+
+  NewDead.push_back(BB);
+  while (!NewDead.empty()) {
+    BasicBlock *D = NewDead.pop_back_val();
+    if (DeadBlocks.count(D))
+      continue;
+
+    // All blocks dominated by D are dead.
+    SmallVector<BasicBlock *, 8> Dom;
+    DT->getDescendants(D, Dom);
+    DeadBlocks.insert(Dom.begin(), Dom.end());
+
+    // Figure out the dominance-frontier(D).
+    for (BasicBlock *B : Dom) {
+      for (BasicBlock *S : successors(B)) {
+        if (DeadBlocks.count(S))
+          continue;
+
+        bool AllPredDead = true;
+        for (BasicBlock *P : predecessors(S))
+          if (!DeadBlocks.count(P)) {
+            AllPredDead = false;
+            break;
+          }
+
+        if (!AllPredDead) {
+          // S could be proved dead later on. That is why we don't update phi
+          // operands at this moment.
+          DF.insert(S);
+        } else {
+          // Although S is not dominated by D, it is dead by now. This can
+          // happen if S already had a dead predecessor before D was declared
+          // dead.
+          NewDead.push_back(S);
+        }
+      }
+    }
+  }
+
+  // For the dead blocks' live successors, update their phi nodes by replacing
+  // the operands corresponding to dead blocks with UndefVal.
+  for(SmallSetVector<BasicBlock *, 4>::iterator I = DF.begin(), E = DF.end();
+      I != E; I++) {
+    BasicBlock *B = *I;
+    if (DeadBlocks.count(B))
+      continue;
+
+    SmallVector<BasicBlock *, 4> Preds(pred_begin(B), pred_end(B));
+    for (BasicBlock *P : Preds) {
+      if (!DeadBlocks.count(P))
+        continue;
+
+      if (isCriticalEdge(P->getTerminator(), GetSuccessorNumber(P, B))) {
+        if (BasicBlock *S = splitCriticalEdges(P, B))
+          DeadBlocks.insert(P = S);
+      }
+
+      for (BasicBlock::iterator II = B->begin(); isa<PHINode>(II); ++II) {
+        PHINode &Phi = cast<PHINode>(*II);
+        Phi.setIncomingValue(Phi.getBasicBlockIndex(P),
+                             UndefValue::get(Phi.getType()));
+      }
+    }
+  }
+}
+
+// If the given branch is recognized as a foldable branch (i.e. a conditional
+// branch with a constant condition), it performs the following analyses and
+// transformations.
+//  1) If the dead outgoing edge is a critical edge, split it. Let
+//     R be the target of the dead outgoing edge.
+//  2) Identify the set of dead blocks implied by the branch's dead outgoing
+//     edge. The result of this step will be {X| X is dominated by R}.
+//  3) Identify those blocks which have at least one dead predecessor. The
+//     result of this step will be dominance-frontier(R).
+//  4) Update the PHIs in DF(R) by replacing the operands corresponding to
+//     dead blocks with "UndefVal" in the hope that these PHIs will be
+//     optimized away.
+//
+// Return true iff *NEW* dead code is found.
+bool GVN::processFoldableCondBr(BranchInst *BI) {
+  if (!BI || BI->isUnconditional())
+    return false;
+
+  // If a branch has two identical successors, we cannot declare either dead.
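+  // For example (illustrative IR): "br i1 %c, label %bb, label %bb" reaches
+  // %bb for either value of %c, so neither outgoing edge can be pronounced
+  // dead.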
+ if (BI->getSuccessor(0) == BI->getSuccessor(1)) + return false; + + ConstantInt *Cond = dyn_cast<ConstantInt>(BI->getCondition()); + if (!Cond) + return false; + + BasicBlock *DeadRoot = + Cond->getZExtValue() ? BI->getSuccessor(1) : BI->getSuccessor(0); + if (DeadBlocks.count(DeadRoot)) + return false; + + if (!DeadRoot->getSinglePredecessor()) + DeadRoot = splitCriticalEdges(BI->getParent(), DeadRoot); + + addDeadBlock(DeadRoot); + return true; +} + +// performPRE() will trigger assert if it comes across an instruction without +// associated val-num. As it normally has far more live instructions than dead +// instructions, it makes more sense just to "fabricate" a val-number for the +// dead code than checking if instruction involved is dead or not. +void GVN::assignValNumForDeadCode() { + for (BasicBlock *BB : DeadBlocks) { + for (Instruction &Inst : *BB) { + unsigned ValNum = VN.lookupOrAdd(&Inst); + addToLeaderTable(ValNum, &Inst, BB); + } + } +} + +class llvm::gvn::GVNLegacyPass : public FunctionPass { +public: + static char ID; // Pass identification, replacement for typeid + + explicit GVNLegacyPass(bool NoLoads = false) + : FunctionPass(ID), NoLoads(NoLoads) { + initializeGVNLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + auto *LIWP = getAnalysisIfAvailable<LoopInfoWrapperPass>(); + + return Impl.runImpl( + F, getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F), + getAnalysis<DominatorTreeWrapperPass>().getDomTree(), + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(), + getAnalysis<AAResultsWrapperPass>().getAAResults(), + NoLoads ? nullptr + : &getAnalysis<MemoryDependenceWrapperPass>().getMemDep(), + LIWP ? &LIWP->getLoopInfo() : nullptr, + &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + if (!NoLoads) + AU.addRequired<MemoryDependenceWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); + + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<TargetLibraryInfoWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); + } + +private: + bool NoLoads; + GVN Impl; +}; + +char GVNLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(GVNLegacyPass, "gvn", "Global Value Numbering", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_END(GVNLegacyPass, "gvn", "Global Value Numbering", false, false) + +// The public interface to this file... 
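+// A typical use (illustrative) with the legacy pass manager:
+//   legacy::PassManager PM;
+//   PM.add(llvm::createGVNPass(/*NoLoads=*/false));
+//   PM.run(M);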
+FunctionPass *llvm::createGVNPass(bool NoLoads) { + return new GVNLegacyPass(NoLoads); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/GVNHoist.cpp b/contrib/llvm/lib/Transforms/Scalar/GVNHoist.cpp new file mode 100644 index 000000000000..026fab5dbd3b --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/GVNHoist.cpp @@ -0,0 +1,1207 @@ +//===- GVNHoist.cpp - Hoist scalar and load expressions -------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass hoists expressions from branches to a common dominator. It uses +// GVN (global value numbering) to discover expressions computing the same +// values. The primary goals of code-hoisting are: +// 1. To reduce the code size. +// 2. In some cases reduce critical path (by exposing more ILP). +// +// The algorithm factors out the reachability of values such that multiple +// queries to find reachability of values are fast. This is based on finding the +// ANTIC points in the CFG which do not change during hoisting. The ANTIC points +// are basically the dominance-frontiers in the inverse graph. So we introduce a +// data structure (CHI nodes) to keep track of values flowing out of a basic +// block. We only do this for values with multiple occurrences in the function +// as they are the potential hoistable candidates. This approach allows us to +// hoist instructions to a basic block with more than two successors, as well as +// deal with infinite loops in a trivial way. +// +// Limitations: This pass does not hoist fully redundant expressions because +// they are already handled by GVN-PRE. It is advisable to run gvn-hoist before +// and after gvn-pre because gvn-pre creates opportunities for more instructions +// to be hoisted. +// +// Hoisting may affect the performance in some cases. To mitigate that, hoisting +// is disabled in the following cases. +// 1. Scalars across calls. +// 2. geps when corresponding load/store cannot be hoisted. 
+//===----------------------------------------------------------------------===// + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/IteratedDominanceFrontier.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/MemorySSAUpdater.h" +#include "llvm/Analysis/PostDominators.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/Utils/Local.h" +#include <algorithm> +#include <cassert> +#include <iterator> +#include <memory> +#include <utility> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "gvn-hoist" + +STATISTIC(NumHoisted, "Number of instructions hoisted"); +STATISTIC(NumRemoved, "Number of instructions removed"); +STATISTIC(NumLoadsHoisted, "Number of loads hoisted"); +STATISTIC(NumLoadsRemoved, "Number of loads removed"); +STATISTIC(NumStoresHoisted, "Number of stores hoisted"); +STATISTIC(NumStoresRemoved, "Number of stores removed"); +STATISTIC(NumCallsHoisted, "Number of calls hoisted"); +STATISTIC(NumCallsRemoved, "Number of calls removed"); + +static cl::opt<int> + MaxHoistedThreshold("gvn-max-hoisted", cl::Hidden, cl::init(-1), + cl::desc("Max number of instructions to hoist " + "(default unlimited = -1)")); + +static cl::opt<int> MaxNumberOfBBSInPath( + "gvn-hoist-max-bbs", cl::Hidden, cl::init(4), + cl::desc("Max number of basic blocks on the path between " + "hoisting locations (default = 4, unlimited = -1)")); + +static cl::opt<int> MaxDepthInBB( + "gvn-hoist-max-depth", cl::Hidden, cl::init(100), + cl::desc("Hoist instructions from the beginning of the BB up to the " + "maximum specified depth (default = 100, unlimited = -1)")); + +static cl::opt<int> + MaxChainLength("gvn-hoist-max-chain-length", cl::Hidden, cl::init(10), + cl::desc("Maximum length of dependent chains to hoist " + "(default = 10, unlimited = -1)")); + +namespace llvm { + +using BBSideEffectsSet = DenseMap<const BasicBlock *, bool>; +using SmallVecInsn = SmallVector<Instruction *, 4>; +using SmallVecImplInsn = SmallVectorImpl<Instruction *>; + +// Each element of a hoisting list contains the basic block where to hoist and +// a list of instructions to be hoisted. +using HoistingPointInfo = std::pair<BasicBlock *, SmallVecInsn>; + +using HoistingPointList = SmallVector<HoistingPointInfo, 4>; + +// A map from a pair of VNs to all the instructions with those VNs. 
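+// For scalars and loads only the first value number of the pair is
+// meaningful and the second is InvalidVN; stores use both fields
+// (address VN, stored-value VN), as the insert() methods below show.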
+using VNType = std::pair<unsigned, unsigned>; + +using VNtoInsns = DenseMap<VNType, SmallVector<Instruction *, 4>>; + +// CHI keeps information about values flowing out of a basic block. It is +// similar to PHI but in the inverse graph, and used for outgoing values on each +// edge. For conciseness, it is computed only for instructions with multiple +// occurrences in the CFG because they are the only hoistable candidates. +// A (CHI[{V, B, I1}, {V, C, I2}] +// / \ +// / \ +// B(I1) C (I2) +// The Value number for both I1 and I2 is V, the CHI node will save the +// instruction as well as the edge where the value is flowing to. +struct CHIArg { + VNType VN; + + // Edge destination (shows the direction of flow), may not be where the I is. + BasicBlock *Dest; + + // The instruction (VN) which uses the values flowing out of CHI. + Instruction *I; + + bool operator==(const CHIArg &A) { return VN == A.VN; } + bool operator!=(const CHIArg &A) { return !(*this == A); } +}; + +using CHIIt = SmallVectorImpl<CHIArg>::iterator; +using CHIArgs = iterator_range<CHIIt>; +using OutValuesType = DenseMap<BasicBlock *, SmallVector<CHIArg, 2>>; +using InValuesType = + DenseMap<BasicBlock *, SmallVector<std::pair<VNType, Instruction *>, 2>>; + +// An invalid value number Used when inserting a single value number into +// VNtoInsns. +enum : unsigned { InvalidVN = ~2U }; + +// Records all scalar instructions candidate for code hoisting. +class InsnInfo { + VNtoInsns VNtoScalars; + +public: + // Inserts I and its value number in VNtoScalars. + void insert(Instruction *I, GVN::ValueTable &VN) { + // Scalar instruction. + unsigned V = VN.lookupOrAdd(I); + VNtoScalars[{V, InvalidVN}].push_back(I); + } + + const VNtoInsns &getVNTable() const { return VNtoScalars; } +}; + +// Records all load instructions candidate for code hoisting. +class LoadInfo { + VNtoInsns VNtoLoads; + +public: + // Insert Load and the value number of its memory address in VNtoLoads. + void insert(LoadInst *Load, GVN::ValueTable &VN) { + if (Load->isSimple()) { + unsigned V = VN.lookupOrAdd(Load->getPointerOperand()); + VNtoLoads[{V, InvalidVN}].push_back(Load); + } + } + + const VNtoInsns &getVNTable() const { return VNtoLoads; } +}; + +// Records all store instructions candidate for code hoisting. +class StoreInfo { + VNtoInsns VNtoStores; + +public: + // Insert the Store and a hash number of the store address and the stored + // value in VNtoStores. + void insert(StoreInst *Store, GVN::ValueTable &VN) { + if (!Store->isSimple()) + return; + // Hash the store address and the stored value. + Value *Ptr = Store->getPointerOperand(); + Value *Val = Store->getValueOperand(); + VNtoStores[{VN.lookupOrAdd(Ptr), VN.lookupOrAdd(Val)}].push_back(Store); + } + + const VNtoInsns &getVNTable() const { return VNtoStores; } +}; + +// Records all call instructions candidate for code hoisting. +class CallInfo { + VNtoInsns VNtoCallsScalars; + VNtoInsns VNtoCallsLoads; + VNtoInsns VNtoCallsStores; + +public: + // Insert Call and its value numbering in one of the VNtoCalls* containers. + void insert(CallInst *Call, GVN::ValueTable &VN) { + // A call that doesNotAccessMemory is handled as a Scalar, + // onlyReadsMemory will be handled as a Load instruction, + // all other calls will be handled as stores. 
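+    // For example (illustrative): a readnone intrinsic such as
+    // @llvm.fabs.f64 lands in VNtoCallsScalars, a readonly call is treated
+    // like a load, and anything that may write memory is conservatively
+    // treated like a store.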
+ unsigned V = VN.lookupOrAdd(Call); + auto Entry = std::make_pair(V, InvalidVN); + + if (Call->doesNotAccessMemory()) + VNtoCallsScalars[Entry].push_back(Call); + else if (Call->onlyReadsMemory()) + VNtoCallsLoads[Entry].push_back(Call); + else + VNtoCallsStores[Entry].push_back(Call); + } + + const VNtoInsns &getScalarVNTable() const { return VNtoCallsScalars; } + const VNtoInsns &getLoadVNTable() const { return VNtoCallsLoads; } + const VNtoInsns &getStoreVNTable() const { return VNtoCallsStores; } +}; + +static void combineKnownMetadata(Instruction *ReplInst, Instruction *I) { + static const unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, LLVMContext::MD_range, + LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, + LLVMContext::MD_invariant_group}; + combineMetadata(ReplInst, I, KnownIDs); +} + +// This pass hoists common computations across branches sharing common +// dominator. The primary goal is to reduce the code size, and in some +// cases reduce critical path (by exposing more ILP). +class GVNHoist { +public: + GVNHoist(DominatorTree *DT, PostDominatorTree *PDT, AliasAnalysis *AA, + MemoryDependenceResults *MD, MemorySSA *MSSA) + : DT(DT), PDT(PDT), AA(AA), MD(MD), MSSA(MSSA), + MSSAUpdater(llvm::make_unique<MemorySSAUpdater>(MSSA)) {} + + bool run(Function &F) { + NumFuncArgs = F.arg_size(); + VN.setDomTree(DT); + VN.setAliasAnalysis(AA); + VN.setMemDep(MD); + bool Res = false; + // Perform DFS Numbering of instructions. + unsigned BBI = 0; + for (const BasicBlock *BB : depth_first(&F.getEntryBlock())) { + DFSNumber[BB] = ++BBI; + unsigned I = 0; + for (auto &Inst : *BB) + DFSNumber[&Inst] = ++I; + } + + int ChainLength = 0; + + // FIXME: use lazy evaluation of VN to avoid the fix-point computation. + while (true) { + if (MaxChainLength != -1 && ++ChainLength >= MaxChainLength) + return Res; + + auto HoistStat = hoistExpressions(F); + if (HoistStat.first + HoistStat.second == 0) + return Res; + + if (HoistStat.second > 0) + // To address a limitation of the current GVN, we need to rerun the + // hoisting after we hoisted loads or stores in order to be able to + // hoist all scalars dependent on the hoisted ld/st. + VN.clear(); + + Res = true; + } + + return Res; + } + + // Copied from NewGVN.cpp + // This function provides global ranking of operations so that we can place + // them in a canonical order. Note that rank alone is not necessarily enough + // for a complete ordering, as constants all have the same rank. However, + // generally, we will simplify an operation with all constants so that it + // doesn't matter what order they appear in. + unsigned int rank(const Value *V) const { + // Prefer constants to undef to anything else + // Undef is a constant, have to check it first. + // Prefer smaller constants to constantexprs + if (isa<ConstantExpr>(V)) + return 2; + if (isa<UndefValue>(V)) + return 1; + if (isa<Constant>(V)) + return 0; + else if (auto *A = dyn_cast<Argument>(V)) + return 3 + A->getArgNo(); + + // Need to shift the instruction DFS by number of arguments + 3 to account + // for the constant and argument ranking above. + auto Result = DFSNumber.lookup(V); + if (Result > 0) + return 4 + NumFuncArgs + Result; + // Unreachable or something else, just return a really large number. 
+    return ~0;
+  }
+
+private:
+  GVN::ValueTable VN;
+  DominatorTree *DT;
+  PostDominatorTree *PDT;
+  AliasAnalysis *AA;
+  MemoryDependenceResults *MD;
+  MemorySSA *MSSA;
+  std::unique_ptr<MemorySSAUpdater> MSSAUpdater;
+  DenseMap<const Value *, unsigned> DFSNumber;
+  BBSideEffectsSet BBSideEffects;
+  DenseSet<const BasicBlock *> HoistBarrier;
+  SmallVector<BasicBlock *, 32> IDFBlocks;
+  unsigned NumFuncArgs;
+  const bool HoistingGeps = false;
+
+  enum InsKind { Unknown, Scalar, Load, Store };
+
+  // Return true when there is exception handling in BB.
+  bool hasEH(const BasicBlock *BB) {
+    auto It = BBSideEffects.find(BB);
+    if (It != BBSideEffects.end())
+      return It->second;
+
+    if (BB->isEHPad() || BB->hasAddressTaken()) {
+      BBSideEffects[BB] = true;
+      return true;
+    }
+
+    if (BB->getTerminator()->mayThrow()) {
+      BBSideEffects[BB] = true;
+      return true;
+    }
+
+    BBSideEffects[BB] = false;
+    return false;
+  }
+
+  // Return true when a successor of BB dominates A.
+  bool successorDominate(const BasicBlock *BB, const BasicBlock *A) {
+    for (const BasicBlock *Succ : BB->getTerminator()->successors())
+      if (DT->dominates(Succ, A))
+        return true;
+
+    return false;
+  }
+
+  // Return true when I1 appears before I2 in the instructions of BB.
+  bool firstInBB(const Instruction *I1, const Instruction *I2) {
+    assert(I1->getParent() == I2->getParent());
+    unsigned I1DFS = DFSNumber.lookup(I1);
+    unsigned I2DFS = DFSNumber.lookup(I2);
+    assert(I1DFS && I2DFS);
+    return I1DFS < I2DFS;
+  }
+
+  // Return true when there are memory uses of Def in BB.
+  bool hasMemoryUse(const Instruction *NewPt, MemoryDef *Def,
+                    const BasicBlock *BB) {
+    const MemorySSA::AccessList *Acc = MSSA->getBlockAccesses(BB);
+    if (!Acc)
+      return false;
+
+    Instruction *OldPt = Def->getMemoryInst();
+    const BasicBlock *OldBB = OldPt->getParent();
+    const BasicBlock *NewBB = NewPt->getParent();
+    bool ReachedNewPt = false;
+
+    for (const MemoryAccess &MA : *Acc)
+      if (const MemoryUse *MU = dyn_cast<MemoryUse>(&MA)) {
+        Instruction *Insn = MU->getMemoryInst();
+
+        // Do not check whether MU aliases Def when MU occurs after OldPt.
+        if (BB == OldBB && firstInBB(OldPt, Insn))
+          break;
+
+        // Do not check whether MU aliases Def when MU occurs before NewPt.
+        if (BB == NewBB) {
+          if (!ReachedNewPt) {
+            if (firstInBB(Insn, NewPt))
+              continue;
+            ReachedNewPt = true;
+          }
+        }
+        if (MemorySSAUtil::defClobbersUseOrDef(Def, MU, *AA))
+          return true;
+      }
+
+    return false;
+  }
+
+  bool hasEHhelper(const BasicBlock *BB, const BasicBlock *SrcBB,
+                   int &NBBsOnAllPaths) {
+    // Stop walk once the limit is reached.
+    if (NBBsOnAllPaths == 0)
+      return true;
+
+    // Impossible to hoist with exceptions on the path.
+    if (hasEH(BB))
+      return true;
+
+    // Instructions after a hoist barrier are never selected for hoisting, so
+    // candidates taken from the barrier block itself (SrcBB) can still be
+    // hoisted; a barrier in any other block on the path blocks hoisting.
+    if ((BB != SrcBB) && HoistBarrier.count(BB))
+      return true;
+
+    return false;
+  }
+
+  // Return true when there is exception handling or a load of memory Def
+  // between Def and NewPt. This function is only called for stores: Def is
+  // the MemoryDef of the store to be hoisted.
+
+  // Decrement NBBsOnAllPaths by 1 for each block between HoistPt and BB, and
+  // return true when the counter NBBsOnAllPaths reaches 0, except when it is
+  // initialized to -1, which means unlimited.
+  bool hasEHOrLoadsOnPath(const Instruction *NewPt, MemoryDef *Def,
+                          int &NBBsOnAllPaths) {
+    const BasicBlock *NewBB = NewPt->getParent();
+    const BasicBlock *OldBB = Def->getBlock();
+    assert(DT->dominates(NewBB, OldBB) && "invalid path");
+    assert(DT->dominates(Def->getDefiningAccess()->getBlock(), NewBB) &&
+           "def does not dominate new hoisting point");
+
+    // Walk all basic blocks reachable in depth-first iteration on the inverse
+    // CFG from OldBB to NewBB. These blocks are all the blocks that may be
+    // executed between the execution of NewBB and OldBB. Hoisting an
+    // expression from OldBB into NewBB has to be safe on all execution paths.
+    for (auto I = idf_begin(OldBB), E = idf_end(OldBB); I != E;) {
+      const BasicBlock *BB = *I;
+      if (BB == NewBB) {
+        // Stop traversal when reaching HoistPt.
+        I.skipChildren();
+        continue;
+      }
+
+      if (hasEHhelper(BB, OldBB, NBBsOnAllPaths))
+        return true;
+
+      // Check that we do not move a store past loads.
+      if (hasMemoryUse(NewPt, Def, BB))
+        return true;
+
+      // -1 is unlimited number of blocks on all paths.
+      if (NBBsOnAllPaths != -1)
+        --NBBsOnAllPaths;
+
+      ++I;
+    }
+
+    return false;
+  }
+
+  // Return true when there is exception handling between HoistPt and BB.
+  // Decrement by 1 NBBsOnAllPaths for each block between HoistPt and BB, and
+  // return true when the counter NBBsOnAllPaths reaches 0, except when it is
+  // initialized to -1 which is unlimited.
+  bool hasEHOnPath(const BasicBlock *HoistPt, const BasicBlock *SrcBB,
+                   int &NBBsOnAllPaths) {
+    assert(DT->dominates(HoistPt, SrcBB) && "Invalid path");
+
+    // Walk all basic blocks reachable in depth-first iteration on
+    // the inverse CFG from BBInsn to NewHoistPt. These blocks are all the
+    // blocks that may be executed between the execution of NewHoistPt and
+    // BBInsn. Hoisting an expression from BBInsn into NewHoistPt has to be
+    // safe on all execution paths.
+    for (auto I = idf_begin(SrcBB), E = idf_end(SrcBB); I != E;) {
+      const BasicBlock *BB = *I;
+      if (BB == HoistPt) {
+        // Stop traversal when reaching NewHoistPt.
+        I.skipChildren();
+        continue;
+      }
+
+      if (hasEHhelper(BB, SrcBB, NBBsOnAllPaths))
+        return true;
+
+      // -1 is unlimited number of blocks on all paths.
+      if (NBBsOnAllPaths != -1)
+        --NBBsOnAllPaths;
+
+      ++I;
+    }
+
+    return false;
+  }
+
+  // Return true when it is safe to hoist a memory load or store U from OldPt
+  // to NewPt.
+  bool safeToHoistLdSt(const Instruction *NewPt, const Instruction *OldPt,
+                       MemoryUseOrDef *U, InsKind K, int &NBBsOnAllPaths) {
+    // In place hoisting is safe.
+    if (NewPt == OldPt)
+      return true;
+
+    const BasicBlock *NewBB = NewPt->getParent();
+    const BasicBlock *OldBB = OldPt->getParent();
+    const BasicBlock *UBB = U->getBlock();
+
+    // Check for dependences on the Memory SSA.
+    MemoryAccess *D = U->getDefiningAccess();
+    BasicBlock *DBB = D->getBlock();
+    if (DT->properlyDominates(NewBB, DBB))
+      // Cannot move the load or store to NewBB above its definition in DBB.
+      return false;
+
+    if (NewBB == DBB && !MSSA->isLiveOnEntryDef(D))
+      if (auto *UD = dyn_cast<MemoryUseOrDef>(D))
+        if (firstInBB(NewPt, UD->getMemoryInst()))
+          // Cannot move the load or store to NewPt above its definition in D.
+          return false;
+
+    // Check for unsafe hoistings due to side effects.
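+    // For stores, crossing a may-aliasing load is also unsafe, so the
+    // stronger hasEHOrLoadsOnPath check is used; loads and scalars only need
+    // the EH check on the path.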
+    if (K == InsKind::Store) {
+      if (hasEHOrLoadsOnPath(NewPt, dyn_cast<MemoryDef>(U), NBBsOnAllPaths))
+        return false;
+    } else if (hasEHOnPath(NewBB, OldBB, NBBsOnAllPaths))
+      return false;
+
+    if (UBB == NewBB) {
+      if (DT->properlyDominates(DBB, NewBB))
+        return true;
+      assert(UBB == DBB);
+      assert(MSSA->locallyDominates(D, U));
+    }
+
+    // No side effects: it is safe to hoist.
+    return true;
+  }
+
+  // Return true when it is safe to hoist scalar instructions from BB to
+  // HoistBB.
+  bool safeToHoistScalar(const BasicBlock *HoistBB, const BasicBlock *BB,
+                         int &NBBsOnAllPaths) {
+    return !hasEHOnPath(HoistBB, BB, NBBsOnAllPaths);
+  }
+
+  // In the inverse CFG, the dominance frontier of basic block (BB) is the
+  // point where ANTIC needs to be computed for instructions which are going
+  // to be hoisted. Since this point does not change during gvn-hoist,
+  // we compute it only once (on demand).
+  // The idea is inspired by:
+  // "Partial Redundancy Elimination in SSA Form"
+  // ROBERT KENNEDY, SUN CHAN, SHIN-MING LIU, RAYMOND LO, PENG TU and FRED CHOW
+  // They use a similar idea in the forward graph to find fully redundant and
+  // partially redundant expressions, here it is used in the inverse graph to
+  // find fully anticipable instructions at merge point (post-dominator in
+  // the inverse CFG).
+  // Returns the edge via which an instruction in BB will get its values.
+
+  // Returns true when the values are flowing out to each edge.
+  bool valueAnticipable(CHIArgs C, TerminatorInst *TI) const {
+    if (TI->getNumSuccessors() > (unsigned)std::distance(C.begin(), C.end()))
+      return false; // Not enough args in this CHI.
+
+    for (auto CHI : C) {
+      BasicBlock *Dest = CHI.Dest;
+      // Find if all the edges have values flowing out of BB.
+      bool Found = llvm::any_of(
+          TI->successors(), [Dest](const BasicBlock *BB) { return BB == Dest; });
+      if (!Found)
+        return false;
+    }
+    return true;
+  }
+
+  // Check if it is safe to hoist values tracked by CHI in the range
+  // [Begin, End) and accumulate them in Safe.
+  void checkSafety(CHIArgs C, BasicBlock *BB, InsKind K,
+                   SmallVectorImpl<CHIArg> &Safe) {
+    int NumBBsOnAllPaths = MaxNumberOfBBSInPath;
+    for (auto CHI : C) {
+      Instruction *Insn = CHI.I;
+      if (!Insn) // No instruction was inserted in this CHI.
+        continue;
+      if (K == InsKind::Scalar) {
+        if (safeToHoistScalar(BB, Insn->getParent(), NumBBsOnAllPaths))
+          Safe.push_back(CHI);
+      } else {
+        MemoryUseOrDef *UD = MSSA->getMemoryAccess(Insn);
+        if (safeToHoistLdSt(BB->getTerminator(), Insn, UD, K, NumBBsOnAllPaths))
+          Safe.push_back(CHI);
+      }
+    }
+  }
+
+  using RenameStackType = DenseMap<VNType, SmallVector<Instruction *, 2>>;
+
+  // Push all the VNs corresponding to BB into RenameStack.
+  void fillRenameStack(BasicBlock *BB, InValuesType &ValueBBs,
+                       RenameStackType &RenameStack) {
+    auto it1 = ValueBBs.find(BB);
+    if (it1 != ValueBBs.end()) {
+      // Iterate in reverse order to keep lower ranked values on the top.
+      for (std::pair<VNType, Instruction *> &VI : reverse(it1->second)) {
+        // Get the value of instruction I
+        DEBUG(dbgs() << "\nPushing on stack: " << *VI.second);
+        RenameStack[VI.first].push_back(VI.second);
+      }
+    }
+  }
+
+  void fillChiArgs(BasicBlock *BB, OutValuesType &CHIBBs,
+                   RenameStackType &RenameStack) {
+    // For each *predecessor* (because Post-DOM) of BB check if it has a CHI
+    for (auto Pred : predecessors(BB)) {
+      auto P = CHIBBs.find(Pred);
+      if (P == CHIBBs.end()) {
+        continue;
+      }
+      DEBUG(dbgs() << "\nLooking at CHIs in: " << Pred->getName(););
+      // A CHI is found (BB -> Pred is an edge in the CFG)
+      // Pop the stack until Top(V) = Ve.
+      auto &VCHI = P->second;
+      for (auto It = VCHI.begin(), E = VCHI.end(); It != E;) {
+        CHIArg &C = *It;
+        if (!C.Dest) {
+          auto si = RenameStack.find(C.VN);
+          // The Basic Block where CHI is must dominate the value we want to
+          // track in a CHI. In the PDom walk, there can be values in the
+          // stack which are not control dependent e.g., nested loop.
+          if (si != RenameStack.end() && si->second.size() &&
+              DT->properlyDominates(Pred, si->second.back()->getParent())) {
+            C.Dest = BB;                     // Assign the edge
+            C.I = si->second.pop_back_val(); // Assign the argument
+            DEBUG(dbgs() << "\nCHI Inserted in BB: " << C.Dest->getName()
+                         << *C.I << ", VN: " << C.VN.first << ", "
+                         << C.VN.second);
+          }
+          // Move to next CHI of a different value
+          It = std::find_if(It, VCHI.end(),
+                            [It](CHIArg &A) { return A != *It; });
+        } else
+          ++It;
+      }
+    }
+  }
+
+  // Walk the post-dominator tree top-down and use a stack for each value to
+  // store the last value you see. When you hit a CHI from a given edge, the
+  // value to use as the argument is at the top of the stack, add the value to
+  // CHI and pop.
+  void insertCHI(InValuesType &ValueBBs, OutValuesType &CHIBBs) {
+    auto Root = PDT->getNode(nullptr);
+    if (!Root)
+      return;
+    // Depth first walk on PDom tree to fill the CHIargs at each PDF.
+    RenameStackType RenameStack;
+    for (auto Node : depth_first(Root)) {
+      BasicBlock *BB = Node->getBlock();
+      if (!BB)
+        continue;
+
+      // Collect all values in BB and push to stack.
+      fillRenameStack(BB, ValueBBs, RenameStack);
+
+      // Fill outgoing values in each CHI corresponding to BB.
+      fillChiArgs(BB, CHIBBs, RenameStack);
+    }
+  }
+
+  // Walk all the CHI-nodes to find ones which have an empty entry and remove
+  // them. Then collect all the instructions which are safe to hoist and see
+  // if they form a list of anticipable values. OutValues contains CHIs
+  // corresponding to each basic block.
+  void findHoistableCandidates(OutValuesType &CHIBBs, InsKind K,
+                               HoistingPointList &HPL) {
+    auto cmpVN = [](const CHIArg &A, const CHIArg &B) { return A.VN < B.VN; };
+
+    // CHIArgs now have the outgoing values, so check for anticipability and
+    // accumulate hoistable candidates in HPL.
+    for (std::pair<BasicBlock *, SmallVector<CHIArg, 2>> &A : CHIBBs) {
+      BasicBlock *BB = A.first;
+      SmallVectorImpl<CHIArg> &CHIs = A.second;
+      // Vector of PHIs contains PHIs for different instructions.
+      // Sort the args according to their VNs, such that identical
+      // instructions are together.
+      std::stable_sort(CHIs.begin(), CHIs.end(), cmpVN);
+      auto TI = BB->getTerminator();
+      auto B = CHIs.begin();
+      // [PrevIt, PHIIt) form a range of CHIs which have identical VNs.
+      auto PHIIt = std::find_if(CHIs.begin(), CHIs.end(),
+                                [B](CHIArg &A) { return A != *B; });
+      auto PrevIt = CHIs.begin();
+      while (PrevIt != PHIIt) {
+        // Collect values which satisfy safety checks.
+        SmallVector<CHIArg, 2> Safe;
+        // We check for safety first because there might be multiple values in
+        // the same path, some of which are not safe to be hoisted, but overall
+        // each edge has at least one value which can be hoisted, making the
+        // value anticipable along that path.
+        checkSafety(make_range(PrevIt, PHIIt), BB, K, Safe);
+
+        // List of safe values should be anticipable at TI.
+        if (valueAnticipable(make_range(Safe.begin(), Safe.end()), TI)) {
+          HPL.push_back({BB, SmallVecInsn()});
+          SmallVecInsn &V = HPL.back().second;
+          for (auto B : Safe)
+            V.push_back(B.I);
+        }
+
+        // Check other VNs
+        PrevIt = PHIIt;
+        PHIIt = std::find_if(PrevIt, CHIs.end(),
+                             [PrevIt](CHIArg &A) { return A != *PrevIt; });
+      }
+    }
+  }
+
+  // Compute insertion points for each value which can be fully anticipated at
+  // a dominator. HPL contains all such values.
+  void computeInsertionPoints(const VNtoInsns &Map, HoistingPointList &HPL,
+                              InsKind K) {
+    // Sort VNs based on their rankings
+    std::vector<VNType> Ranks;
+    for (const auto &Entry : Map) {
+      Ranks.push_back(Entry.first);
+    }
+
+    // TODO: Remove fully-redundant expressions.
+    // Get instruction from the Map, assume that all the Instructions
+    // with same VNs have same rank (this is an approximation).
+    std::sort(Ranks.begin(), Ranks.end(),
+              [this, &Map](const VNType &r1, const VNType &r2) {
+                return (rank(*Map.lookup(r1).begin()) <
+                        rank(*Map.lookup(r2).begin()));
+              });
+
+    // - Sort VNs according to their rank, and start with lowest ranked VN
+    // - Take a VN and for each instruction with same VN
+    //   - Find the dominance frontier in the inverse graph (PDF)
+    //   - Insert the chi-node at PDF
+    // - Remove the chi-nodes with missing entries
+    // - Remove values from CHI-nodes which do not truly flow out, e.g.,
+    //   modified along the path.
+    // - Collect the remaining values that are still anticipable
+    SmallVector<BasicBlock *, 2> IDFBlocks;
+    ReverseIDFCalculator IDFs(*PDT);
+    OutValuesType OutValue;
+    InValuesType InValue;
+    for (const auto &R : Ranks) {
+      const SmallVecInsn &V = Map.lookup(R);
+      if (V.size() < 2)
+        continue;
+      const VNType &VN = R;
+      SmallPtrSet<BasicBlock *, 2> VNBlocks;
+      for (auto &I : V) {
+        BasicBlock *BBI = I->getParent();
+        if (!hasEH(BBI))
+          VNBlocks.insert(BBI);
+      }
+      // Compute the Post Dominance Frontiers of each basic block.
+      // The dominance frontier of a block X in the reverse control graph is
+      // the set of blocks upon which X is control dependent. IDFBlocks
+      // receives the post-dominance frontier of the blocks in VNBlocks,
+      // which is where the CHI nodes have to be inserted.
+      IDFs.setDefiningBlocks(VNBlocks);
+      IDFs.calculate(IDFBlocks);
+
+      // Make a map of BB vs instructions to be hoisted.
+      for (unsigned i = 0; i < V.size(); ++i) {
+        InValue[V[i]->getParent()].push_back(std::make_pair(VN, V[i]));
+      }
+      // Insert empty CHI node for this VN. This is used to factor out
+      // basic blocks where the ANTIC can potentially change.
+      for (auto IDFB : IDFBlocks) { // TODO: Prune out useless CHI insertions.
+        for (unsigned i = 0; i < V.size(); ++i) {
+          CHIArg C = {VN, nullptr, nullptr};
+          // Ignore spurious PDFs.
+          if (DT->properlyDominates(IDFB, V[i]->getParent())) {
+            OutValue[IDFB].push_back(C);
+            DEBUG(dbgs() << "\nInserting a CHI for BB: " << IDFB->getName()
+                         << ", for Insn: " << *V[i]);
+          }
+        }
+      }
+    }
+
+    // Insert CHI args at each PDF to iterate on factored graph of
+    // control dependence.
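+    // Sketch of what follows (hypothetical CFG): if a VN occurs in blocks B1
+    // and B2 whose post-dominance frontier is BB, an empty CHI was recorded
+    // for BB above. insertCHI below walks the post-dominator tree with a
+    // rename stack, filling each CHI argument with the occurrence reaching it
+    // along the corresponding edge, much like PHI renaming in SSA
+    // construction but on the inverse CFG.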
+ insertCHI(InValue, OutValue); + // Using the CHI args inserted at each PDF, find fully anticipable values. + findHoistableCandidates(OutValue, K, HPL); + } + + // Return true when all operands of Instr are available at insertion point + // HoistPt. When limiting the number of hoisted expressions, one could hoist + // a load without hoisting its access function. So before hoisting any + // expression, make sure that all its operands are available at insert point. + bool allOperandsAvailable(const Instruction *I, + const BasicBlock *HoistPt) const { + for (const Use &Op : I->operands()) + if (const auto *Inst = dyn_cast<Instruction>(&Op)) + if (!DT->dominates(Inst->getParent(), HoistPt)) + return false; + + return true; + } + + // Same as allOperandsAvailable with recursive check for GEP operands. + bool allGepOperandsAvailable(const Instruction *I, + const BasicBlock *HoistPt) const { + for (const Use &Op : I->operands()) + if (const auto *Inst = dyn_cast<Instruction>(&Op)) + if (!DT->dominates(Inst->getParent(), HoistPt)) { + if (const GetElementPtrInst *GepOp = + dyn_cast<GetElementPtrInst>(Inst)) { + if (!allGepOperandsAvailable(GepOp, HoistPt)) + return false; + // Gep is available if all operands of GepOp are available. + } else { + // Gep is not available if it has operands other than GEPs that are + // defined in blocks not dominating HoistPt. + return false; + } + } + return true; + } + + // Make all operands of the GEP available. + void makeGepsAvailable(Instruction *Repl, BasicBlock *HoistPt, + const SmallVecInsn &InstructionsToHoist, + Instruction *Gep) const { + assert(allGepOperandsAvailable(Gep, HoistPt) && + "GEP operands not available"); + + Instruction *ClonedGep = Gep->clone(); + for (unsigned i = 0, e = Gep->getNumOperands(); i != e; ++i) + if (Instruction *Op = dyn_cast<Instruction>(Gep->getOperand(i))) { + // Check whether the operand is already available. + if (DT->dominates(Op->getParent(), HoistPt)) + continue; + + // As a GEP can refer to other GEPs, recursively make all the operands + // of this GEP available at HoistPt. + if (GetElementPtrInst *GepOp = dyn_cast<GetElementPtrInst>(Op)) + makeGepsAvailable(ClonedGep, HoistPt, InstructionsToHoist, GepOp); + } + + // Copy Gep and replace its uses in Repl with ClonedGep. + ClonedGep->insertBefore(HoistPt->getTerminator()); + + // Conservatively discard any optimization hints, they may differ on the + // other paths. + ClonedGep->dropUnknownNonDebugMetadata(); + + // If we have optimization hints which agree with each other along different + // paths, preserve them. + for (const Instruction *OtherInst : InstructionsToHoist) { + const GetElementPtrInst *OtherGep; + if (auto *OtherLd = dyn_cast<LoadInst>(OtherInst)) + OtherGep = cast<GetElementPtrInst>(OtherLd->getPointerOperand()); + else + OtherGep = cast<GetElementPtrInst>( + cast<StoreInst>(OtherInst)->getPointerOperand()); + ClonedGep->andIRFlags(OtherGep); + } + + // Replace uses of Gep with ClonedGep in Repl. 
+    Repl->replaceUsesOfWith(Gep, ClonedGep);
+  }
+
+  void updateAlignment(Instruction *I, Instruction *Repl) {
+    if (auto *ReplacementLoad = dyn_cast<LoadInst>(Repl)) {
+      ReplacementLoad->setAlignment(
+          std::min(ReplacementLoad->getAlignment(),
+                   cast<LoadInst>(I)->getAlignment()));
+      ++NumLoadsRemoved;
+    } else if (auto *ReplacementStore = dyn_cast<StoreInst>(Repl)) {
+      ReplacementStore->setAlignment(
+          std::min(ReplacementStore->getAlignment(),
+                   cast<StoreInst>(I)->getAlignment()));
+      ++NumStoresRemoved;
+    } else if (auto *ReplacementAlloca = dyn_cast<AllocaInst>(Repl)) {
+      ReplacementAlloca->setAlignment(
+          std::max(ReplacementAlloca->getAlignment(),
+                   cast<AllocaInst>(I)->getAlignment()));
+    } else if (isa<CallInst>(Repl)) {
+      ++NumCallsRemoved;
+    }
+  }
+
+  // Remove all the instructions in Candidates and replace their usage with
+  // Repl. Returns the number of instructions removed.
+  unsigned rauw(const SmallVecInsn &Candidates, Instruction *Repl,
+                MemoryUseOrDef *NewMemAcc) {
+    unsigned NR = 0;
+    for (Instruction *I : Candidates) {
+      if (I != Repl) {
+        ++NR;
+        updateAlignment(I, Repl);
+        if (NewMemAcc) {
+          // Update the uses of the old MSSA access with NewMemAcc.
+          MemoryAccess *OldMA = MSSA->getMemoryAccess(I);
+          OldMA->replaceAllUsesWith(NewMemAcc);
+          MSSAUpdater->removeMemoryAccess(OldMA);
+        }
+
+        Repl->andIRFlags(I);
+        combineKnownMetadata(Repl, I);
+        I->replaceAllUsesWith(Repl);
+        // Also invalidate the Alias Analysis cache.
+        MD->removeInstruction(I);
+        I->eraseFromParent();
+      }
+    }
+    return NR;
+  }
+
+  // Replace all Memory PHI usage with NewMemAcc.
+  void raMPHIuw(MemoryUseOrDef *NewMemAcc) {
+    SmallPtrSet<MemoryPhi *, 4> UsePhis;
+    for (User *U : NewMemAcc->users())
+      if (MemoryPhi *Phi = dyn_cast<MemoryPhi>(U))
+        UsePhis.insert(Phi);
+
+    for (MemoryPhi *Phi : UsePhis) {
+      auto In = Phi->incoming_values();
+      if (llvm::all_of(In, [&](Use &U) { return U == NewMemAcc; })) {
+        Phi->replaceAllUsesWith(NewMemAcc);
+        MSSAUpdater->removeMemoryAccess(Phi);
+      }
+    }
+  }
+
+  // Remove all other instructions and replace them with Repl.
+  unsigned removeAndReplace(const SmallVecInsn &Candidates, Instruction *Repl,
+                            BasicBlock *DestBB, bool MoveAccess) {
+    MemoryUseOrDef *NewMemAcc = MSSA->getMemoryAccess(Repl);
+    if (MoveAccess && NewMemAcc) {
+      // The definition of this ld/st will not change: ld/st hoisting is
+      // legal when the ld/st is not moved past its current definition.
+      MSSAUpdater->moveToPlace(NewMemAcc, DestBB, MemorySSA::End);
+    }
+
+    // Replace all other instructions with Repl with memory access NewMemAcc.
+    unsigned NR = rauw(Candidates, Repl, NewMemAcc);
+
+    // Remove MemorySSA phi nodes with the same arguments.
+    if (NewMemAcc)
+      raMPHIuw(NewMemAcc);
+    return NR;
+  }
+
+  // When Repl is a load or a store, we make all of its GEPs available: GEPs
+  // are not hoisted by default, to avoid hoisting the address computations
+  // without the associated load or store.
+  bool makeGepOperandsAvailable(Instruction *Repl, BasicBlock *HoistPt,
+                                const SmallVecInsn &InstructionsToHoist) const {
+    // Check whether the GEP of a ld/st can be synthesized at HoistPt.
+    GetElementPtrInst *Gep = nullptr;
+    Instruction *Val = nullptr;
+    if (auto *Ld = dyn_cast<LoadInst>(Repl)) {
+      Gep = dyn_cast<GetElementPtrInst>(Ld->getPointerOperand());
+    } else if (auto *St = dyn_cast<StoreInst>(Repl)) {
+      Gep = dyn_cast<GetElementPtrInst>(St->getPointerOperand());
+      Val = dyn_cast<Instruction>(St->getValueOperand());
+      // Check that the stored value is available.
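+      // E.g. (hypothetical IR), for "store %v, %p" the stored value %v must
+      // either dominate HoistPt already or be a GEP whose operands can all be
+      // made available at HoistPt (both cases are checked below).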
+      if (Val) {
+        if (isa<GetElementPtrInst>(Val)) {
+          // Check whether we can compute the GEP at HoistPt.
+          if (!allGepOperandsAvailable(Val, HoistPt))
+            return false;
+        } else if (!DT->dominates(Val->getParent(), HoistPt))
+          return false;
+      }
+    }
+
+    // Check whether we can compute the Gep at HoistPt.
+    if (!Gep || !allGepOperandsAvailable(Gep, HoistPt))
+      return false;
+
+    makeGepsAvailable(Repl, HoistPt, InstructionsToHoist, Gep);
+
+    if (Val && isa<GetElementPtrInst>(Val))
+      makeGepsAvailable(Repl, HoistPt, InstructionsToHoist, Val);
+
+    return true;
+  }
+
+  std::pair<unsigned, unsigned> hoist(HoistingPointList &HPL) {
+    unsigned NI = 0, NL = 0, NS = 0, NC = 0, NR = 0;
+    for (const HoistingPointInfo &HP : HPL) {
+      // Find out whether we already have one of the instructions in HoistPt,
+      // in which case we do not have to move it.
+      BasicBlock *DestBB = HP.first;
+      const SmallVecInsn &InstructionsToHoist = HP.second;
+      Instruction *Repl = nullptr;
+      for (Instruction *I : InstructionsToHoist)
+        if (I->getParent() == DestBB)
+          // If there are two instructions in HoistPt to be hoisted in place:
+          // update Repl to be the first one, such that we can rename the uses
+          // of the second based on the first.
+          if (!Repl || firstInBB(I, Repl))
+            Repl = I;
+
+      // Keep track of whether we moved the instruction so we know whether we
+      // should move the MemoryAccess.
+      bool MoveAccess = true;
+      if (Repl) {
+        // Repl is already in HoistPt: it remains in place.
+        assert(allOperandsAvailable(Repl, DestBB) &&
+               "instruction depends on operands that are not available");
+        MoveAccess = false;
+      } else {
+        // When we do not find Repl in HoistPt, select the first in the list
+        // and move it to HoistPt.
+        Repl = InstructionsToHoist.front();
+
+        // We can move Repl in HoistPt only when all operands are available.
+        // The order in which hoistings are done may influence the availability
+        // of operands.
+        if (!allOperandsAvailable(Repl, DestBB)) {
+          // When HoistingGeps there is nothing more we can do to make the
+          // operands available: just continue.
+          if (HoistingGeps)
+            continue;
+
+          // When not HoistingGeps we need to copy the GEPs.
+          if (!makeGepOperandsAvailable(Repl, DestBB, InstructionsToHoist))
+            continue;
+        }
+
+        // Move the instruction to the end of HoistPt.
+        Instruction *Last = DestBB->getTerminator();
+        MD->removeInstruction(Repl);
+        Repl->moveBefore(Last);
+
+        DFSNumber[Repl] = DFSNumber[Last]++;
+      }
+
+      NR += removeAndReplace(InstructionsToHoist, Repl, DestBB, MoveAccess);
+
+      if (isa<LoadInst>(Repl))
+        ++NL;
+      else if (isa<StoreInst>(Repl))
+        ++NS;
+      else if (isa<CallInst>(Repl))
+        ++NC;
+      else // Scalar
+        ++NI;
+    }
+
+    NumHoisted += NL + NS + NC + NI;
+    NumRemoved += NR;
+    NumLoadsHoisted += NL;
+    NumStoresHoisted += NS;
+    NumCallsHoisted += NC;
+    return {NI, NL + NC + NS};
+  }
+
+  // Hoist all expressions. Returns the number of scalars hoisted
+  // and the number of non-scalars hoisted.
+  std::pair<unsigned, unsigned> hoistExpressions(Function &F) {
+    InsnInfo II;
+    LoadInfo LI;
+    StoreInfo SI;
+    CallInfo CI;
+    for (BasicBlock *BB : depth_first(&F.getEntryBlock())) {
+      int InstructionNb = 0;
+      for (Instruction &I1 : *BB) {
+        // If I1 cannot guarantee progress, subsequent instructions
+        // in BB cannot be hoisted anyway.
+        if (!isGuaranteedToTransferExecutionToSuccessor(&I1)) {
+          HoistBarrier.insert(BB);
+          break;
+        }
+        // Only hoist the first instructions in BB up to MaxDepthInBB. Hoisting
+        // deeper may increase the register pressure and compilation time.
+ if (MaxDepthInBB != -1 && InstructionNb++ >= MaxDepthInBB) + break; + + // Do not value number terminator instructions. + if (isa<TerminatorInst>(&I1)) + break; + + if (auto *Load = dyn_cast<LoadInst>(&I1)) + LI.insert(Load, VN); + else if (auto *Store = dyn_cast<StoreInst>(&I1)) + SI.insert(Store, VN); + else if (auto *Call = dyn_cast<CallInst>(&I1)) { + if (auto *Intr = dyn_cast<IntrinsicInst>(Call)) { + if (isa<DbgInfoIntrinsic>(Intr) || + Intr->getIntrinsicID() == Intrinsic::assume || + Intr->getIntrinsicID() == Intrinsic::sideeffect) + continue; + } + if (Call->mayHaveSideEffects()) + break; + + if (Call->isConvergent()) + break; + + CI.insert(Call, VN); + } else if (HoistingGeps || !isa<GetElementPtrInst>(&I1)) + // Do not hoist scalars past calls that may write to memory because + // that could result in spills later. geps are handled separately. + // TODO: We can relax this for targets like AArch64 as they have more + // registers than X86. + II.insert(&I1, VN); + } + } + + HoistingPointList HPL; + computeInsertionPoints(II.getVNTable(), HPL, InsKind::Scalar); + computeInsertionPoints(LI.getVNTable(), HPL, InsKind::Load); + computeInsertionPoints(SI.getVNTable(), HPL, InsKind::Store); + computeInsertionPoints(CI.getScalarVNTable(), HPL, InsKind::Scalar); + computeInsertionPoints(CI.getLoadVNTable(), HPL, InsKind::Load); + computeInsertionPoints(CI.getStoreVNTable(), HPL, InsKind::Store); + return hoist(HPL); + } +}; + +class GVNHoistLegacyPass : public FunctionPass { +public: + static char ID; + + GVNHoistLegacyPass() : FunctionPass(ID) { + initializeGVNHoistLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &PDT = getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(); + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + auto &MD = getAnalysis<MemoryDependenceWrapperPass>().getMemDep(); + auto &MSSA = getAnalysis<MemorySSAWrapperPass>().getMSSA(); + + GVNHoist G(&DT, &PDT, &AA, &MD, &MSSA); + return G.run(F); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<MemoryDependenceWrapperPass>(); + AU.addRequired<MemorySSAWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<MemorySSAWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } +}; + +} // end namespace llvm + +PreservedAnalyses GVNHoistPass::run(Function &F, FunctionAnalysisManager &AM) { + DominatorTree &DT = AM.getResult<DominatorTreeAnalysis>(F); + PostDominatorTree &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); + AliasAnalysis &AA = AM.getResult<AAManager>(F); + MemoryDependenceResults &MD = AM.getResult<MemoryDependenceAnalysis>(F); + MemorySSA &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA(); + GVNHoist G(&DT, &PDT, &AA, &MD, &MSSA); + if (!G.run(F)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<MemorySSAAnalysis>(); + PA.preserve<GlobalsAA>(); + return PA; +} + +char GVNHoistLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(GVNHoistLegacyPass, "gvn-hoist", + "Early GVN Hoisting of Expressions", false, false) +INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) 
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(GVNHoistLegacyPass, "gvn-hoist", + "Early GVN Hoisting of Expressions", false, false) + +FunctionPass *llvm::createGVNHoistPass() { return new GVNHoistLegacyPass(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/GVNSink.cpp b/contrib/llvm/lib/Transforms/Scalar/GVNSink.cpp new file mode 100644 index 000000000000..5594c29bbd9f --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/GVNSink.cpp @@ -0,0 +1,922 @@ +//===- GVNSink.cpp - sink expressions into successors ---------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +/// \file GVNSink.cpp +/// This pass attempts to sink instructions into successors, reducing static +/// instruction count and enabling if-conversion. +/// +/// We use a variant of global value numbering to decide what can be sunk. +/// Consider: +/// +/// [ %a1 = add i32 %b, 1 ] [ %c1 = add i32 %d, 1 ] +/// [ %a2 = xor i32 %a1, 1 ] [ %c2 = xor i32 %c1, 1 ] +/// \ / +/// [ %e = phi i32 %a2, %c2 ] +/// [ add i32 %e, 4 ] +/// +/// +/// GVN would number %a1 and %c1 differently because they compute different +/// results - the VN of an instruction is a function of its opcode and the +/// transitive closure of its operands. This is the key property for hoisting +/// and CSE. +/// +/// What we want when sinking however is for a numbering that is a function of +/// the *uses* of an instruction, which allows us to answer the question "if I +/// replace %a1 with %c1, will it contribute in an equivalent way to all +/// successive instructions?". The PostValueTable class in GVN provides this +/// mapping. 
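+///
+/// As an illustrative sketch (not taken from an actual run), sinking the
+/// xors in the example above would produce:
+///
+///       [ %a1 = add i32 %b, 1 ]  [ %c1 = add i32 %d, 1 ]
+///                   \               /
+///           [ %e1 = phi i32 %a1, %c1 ]
+///           [ %e2 = xor i32 %e1, 1 ]
+///           [ ... = add i32 %e2, 4 ]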
+// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Allocator.h" +#include "llvm/Support/ArrayRecycler.h" +#include "llvm/Support/AtomicOrdering.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/Scalar/GVNExpression.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "gvn-sink" + +STATISTIC(NumRemoved, "Number of instructions removed"); + +namespace llvm { +namespace GVNExpression { + +LLVM_DUMP_METHOD void Expression::dump() const { + print(dbgs()); + dbgs() << "\n"; +} + +} // end namespace GVNExpression +} // end namespace llvm + +namespace { + +static bool isMemoryInst(const Instruction *I) { + return isa<LoadInst>(I) || isa<StoreInst>(I) || + (isa<InvokeInst>(I) && !cast<InvokeInst>(I)->doesNotAccessMemory()) || + (isa<CallInst>(I) && !cast<CallInst>(I)->doesNotAccessMemory()); +} + +/// Iterates through instructions in a set of blocks in reverse order from the +/// first non-terminator. For example (assume all blocks have size n): +/// LockstepReverseIterator I([B1, B2, B3]); +/// *I-- = [B1[n], B2[n], B3[n]]; +/// *I-- = [B1[n-1], B2[n-1], B3[n-1]]; +/// *I-- = [B1[n-2], B2[n-2], B3[n-2]]; +/// ... +/// +/// It continues until all blocks have been exhausted. Use \c getActiveBlocks() +/// to +/// determine which blocks are still going and the order they appear in the +/// list returned by operator*. +class LockstepReverseIterator { + ArrayRef<BasicBlock *> Blocks; + SmallSetVector<BasicBlock *, 4> ActiveBlocks; + SmallVector<Instruction *, 4> Insts; + bool Fail; + +public: + LockstepReverseIterator(ArrayRef<BasicBlock *> Blocks) : Blocks(Blocks) { + reset(); + } + + void reset() { + Fail = false; + ActiveBlocks.clear(); + for (BasicBlock *BB : Blocks) + ActiveBlocks.insert(BB); + Insts.clear(); + for (BasicBlock *BB : Blocks) { + if (BB->size() <= 1) { + // Block wasn't big enough - only contained a terminator. + ActiveBlocks.remove(BB); + continue; + } + Insts.push_back(BB->getTerminator()->getPrevNode()); + } + if (Insts.empty()) + Fail = true; + } + + bool isValid() const { return !Fail; } + ArrayRef<Instruction *> operator*() const { return Insts; } + + // Note: This needs to return a SmallSetVector as the elements of + // ActiveBlocks will be later copied to Blocks using std::copy. 
The + // resultant order of elements in Blocks needs to be deterministic. + // Using SmallPtrSet instead causes non-deterministic order while + // copying. And we cannot simply sort Blocks as they need to match the + // corresponding Values. + SmallSetVector<BasicBlock *, 4> &getActiveBlocks() { return ActiveBlocks; } + + void restrictToBlocks(SmallSetVector<BasicBlock *, 4> &Blocks) { + for (auto II = Insts.begin(); II != Insts.end();) { + if (std::find(Blocks.begin(), Blocks.end(), (*II)->getParent()) == + Blocks.end()) { + ActiveBlocks.remove((*II)->getParent()); + II = Insts.erase(II); + } else { + ++II; + } + } + } + + void operator--() { + if (Fail) + return; + SmallVector<Instruction *, 4> NewInsts; + for (auto *Inst : Insts) { + if (Inst == &Inst->getParent()->front()) + ActiveBlocks.remove(Inst->getParent()); + else + NewInsts.push_back(Inst->getPrevNode()); + } + if (NewInsts.empty()) { + Fail = true; + return; + } + Insts = NewInsts; + } +}; + +//===----------------------------------------------------------------------===// + +/// Candidate solution for sinking. There may be different ways to +/// sink instructions, differing in the number of instructions sunk, +/// the number of predecessors sunk from and the number of PHIs +/// required. +struct SinkingInstructionCandidate { + unsigned NumBlocks; + unsigned NumInstructions; + unsigned NumPHIs; + unsigned NumMemoryInsts; + int Cost = -1; + SmallVector<BasicBlock *, 4> Blocks; + + void calculateCost(unsigned NumOrigPHIs, unsigned NumOrigBlocks) { + unsigned NumExtraPHIs = NumPHIs - NumOrigPHIs; + unsigned SplitEdgeCost = (NumOrigBlocks > NumBlocks) ? 2 : 0; + Cost = (NumInstructions * (NumBlocks - 1)) - + (NumExtraPHIs * + NumExtraPHIs) // PHIs are expensive, so make sure they're worth it. + - SplitEdgeCost; + } + + bool operator>(const SinkingInstructionCandidate &Other) const { + return Cost > Other.Cost; + } +}; + +#ifndef NDEBUG +raw_ostream &operator<<(raw_ostream &OS, const SinkingInstructionCandidate &C) { + OS << "<Candidate Cost=" << C.Cost << " #Blocks=" << C.NumBlocks + << " #Insts=" << C.NumInstructions << " #PHIs=" << C.NumPHIs << ">"; + return OS; +} +#endif + +//===----------------------------------------------------------------------===// + +/// Describes a PHI node that may or may not exist. These track the PHIs +/// that must be created if we sunk a sequence of instructions. It provides +/// a hash function for efficient equality comparisons. +class ModelledPHI { + SmallVector<Value *, 4> Values; + SmallVector<BasicBlock *, 4> Blocks; + +public: + ModelledPHI() = default; + + ModelledPHI(const PHINode *PN) { + // BasicBlock comes first so we sort by basic block pointer order, then by value pointer order. + SmallVector<std::pair<BasicBlock *, Value *>, 4> Ops; + for (unsigned I = 0, E = PN->getNumIncomingValues(); I != E; ++I) + Ops.push_back({PN->getIncomingBlock(I), PN->getIncomingValue(I)}); + std::sort(Ops.begin(), Ops.end()); + for (auto &P : Ops) { + Blocks.push_back(P.first); + Values.push_back(P.second); + } + } + + /// Create a dummy ModelledPHI that will compare unequal to any other ModelledPHI + /// without the same ID. + /// \note This is specifically for DenseMapInfo - do not use this! + static ModelledPHI createDummy(size_t ID) { + ModelledPHI M; + M.Values.push_back(reinterpret_cast<Value*>(ID)); + return M; + } + + /// Create a PHI from an array of incoming values and incoming blocks. 
+ template <typename VArray, typename BArray> + ModelledPHI(const VArray &V, const BArray &B) { + std::copy(V.begin(), V.end(), std::back_inserter(Values)); + std::copy(B.begin(), B.end(), std::back_inserter(Blocks)); + } + + /// Create a PHI from [I[OpNum] for I in Insts]. + template <typename BArray> + ModelledPHI(ArrayRef<Instruction *> Insts, unsigned OpNum, const BArray &B) { + std::copy(B.begin(), B.end(), std::back_inserter(Blocks)); + for (auto *I : Insts) + Values.push_back(I->getOperand(OpNum)); + } + + /// Restrict the PHI's contents down to only \c NewBlocks. + /// \c NewBlocks must be a subset of \c this->Blocks. + void restrictToBlocks(const SmallSetVector<BasicBlock *, 4> &NewBlocks) { + auto BI = Blocks.begin(); + auto VI = Values.begin(); + while (BI != Blocks.end()) { + assert(VI != Values.end()); + if (std::find(NewBlocks.begin(), NewBlocks.end(), *BI) == + NewBlocks.end()) { + BI = Blocks.erase(BI); + VI = Values.erase(VI); + } else { + ++BI; + ++VI; + } + } + assert(Blocks.size() == NewBlocks.size()); + } + + ArrayRef<Value *> getValues() const { return Values; } + + bool areAllIncomingValuesSame() const { + return llvm::all_of(Values, [&](Value *V) { return V == Values[0]; }); + } + + bool areAllIncomingValuesSameType() const { + return llvm::all_of( + Values, [&](Value *V) { return V->getType() == Values[0]->getType(); }); + } + + bool areAnyIncomingValuesConstant() const { + return llvm::any_of(Values, [&](Value *V) { return isa<Constant>(V); }); + } + + // Hash functor + unsigned hash() const { + return (unsigned)hash_combine_range(Values.begin(), Values.end()); + } + + bool operator==(const ModelledPHI &Other) const { + return Values == Other.Values && Blocks == Other.Blocks; + } +}; + +template <typename ModelledPHI> struct DenseMapInfo { + static inline ModelledPHI &getEmptyKey() { + static ModelledPHI Dummy = ModelledPHI::createDummy(0); + return Dummy; + } + + static inline ModelledPHI &getTombstoneKey() { + static ModelledPHI Dummy = ModelledPHI::createDummy(1); + return Dummy; + } + + static unsigned getHashValue(const ModelledPHI &V) { return V.hash(); } + + static bool isEqual(const ModelledPHI &LHS, const ModelledPHI &RHS) { + return LHS == RHS; + } +}; + +using ModelledPHISet = DenseSet<ModelledPHI, DenseMapInfo<ModelledPHI>>; + +//===----------------------------------------------------------------------===// +// ValueTable +//===----------------------------------------------------------------------===// +// This is a value number table where the value number is a function of the +// *uses* of a value, rather than its operands. Thus, if VN(A) == VN(B) we know +// that the program would be equivalent if we replaced A with PHI(A, B). +//===----------------------------------------------------------------------===// + +/// A GVN expression describing how an instruction is used. The operands +/// field of BasicExpression is used to store uses, not operands. +/// +/// This class also contains fields for discriminators used when determining +/// equivalence of instructions with sideeffects. 
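+///
+/// For instance (hypothetical IR), in
+///   %a = add i32 %x, 1    ; only user: %u = mul i32 %a, %k
+///   %b = add i32 %y, 1    ; only user: %v = mul i32 %b, %k
+/// the expressions for %a and %b record their users %u and %v, so %a and %b
+/// get the same value number precisely when their uses are (recursively)
+/// equivalent, which is the property sinking needs.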
+class InstructionUseExpr : public GVNExpression::BasicExpression { + unsigned MemoryUseOrder = -1; + bool Volatile = false; + +public: + InstructionUseExpr(Instruction *I, ArrayRecycler<Value *> &R, + BumpPtrAllocator &A) + : GVNExpression::BasicExpression(I->getNumUses()) { + allocateOperands(R, A); + setOpcode(I->getOpcode()); + setType(I->getType()); + + for (auto &U : I->uses()) + op_push_back(U.getUser()); + std::sort(op_begin(), op_end()); + } + + void setMemoryUseOrder(unsigned MUO) { MemoryUseOrder = MUO; } + void setVolatile(bool V) { Volatile = V; } + + hash_code getHashValue() const override { + return hash_combine(GVNExpression::BasicExpression::getHashValue(), + MemoryUseOrder, Volatile); + } + + template <typename Function> hash_code getHashValue(Function MapFn) { + hash_code H = + hash_combine(getOpcode(), getType(), MemoryUseOrder, Volatile); + for (auto *V : operands()) + H = hash_combine(H, MapFn(V)); + return H; + } +}; + +class ValueTable { + DenseMap<Value *, uint32_t> ValueNumbering; + DenseMap<GVNExpression::Expression *, uint32_t> ExpressionNumbering; + DenseMap<size_t, uint32_t> HashNumbering; + BumpPtrAllocator Allocator; + ArrayRecycler<Value *> Recycler; + uint32_t nextValueNumber = 1; + + /// Create an expression for I based on its opcode and its uses. If I + /// touches or reads memory, the expression is also based upon its memory + /// order - see \c getMemoryUseOrder(). + InstructionUseExpr *createExpr(Instruction *I) { + InstructionUseExpr *E = + new (Allocator) InstructionUseExpr(I, Recycler, Allocator); + if (isMemoryInst(I)) + E->setMemoryUseOrder(getMemoryUseOrder(I)); + + if (CmpInst *C = dyn_cast<CmpInst>(I)) { + CmpInst::Predicate Predicate = C->getPredicate(); + E->setOpcode((C->getOpcode() << 8) | Predicate); + } + return E; + } + + /// Helper to compute the value number for a memory instruction + /// (LoadInst/StoreInst), including checking the memory ordering and + /// volatility. + template <class Inst> InstructionUseExpr *createMemoryExpr(Inst *I) { + if (isStrongerThanUnordered(I->getOrdering()) || I->isAtomic()) + return nullptr; + InstructionUseExpr *E = createExpr(I); + E->setVolatile(I->isVolatile()); + return E; + } + +public: + ValueTable() = default; + + /// Returns the value number for the specified value, assigning + /// it a new number if it did not have one before. 
+ uint32_t lookupOrAdd(Value *V) { + auto VI = ValueNumbering.find(V); + if (VI != ValueNumbering.end()) + return VI->second; + + if (!isa<Instruction>(V)) { + ValueNumbering[V] = nextValueNumber; + return nextValueNumber++; + } + + Instruction *I = cast<Instruction>(V); + InstructionUseExpr *exp = nullptr; + switch (I->getOpcode()) { + case Instruction::Load: + exp = createMemoryExpr(cast<LoadInst>(I)); + break; + case Instruction::Store: + exp = createMemoryExpr(cast<StoreInst>(I)); + break; + case Instruction::Call: + case Instruction::Invoke: + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::ICmp: + case Instruction::FCmp: + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::BitCast: + case Instruction::Select: + case Instruction::ExtractElement: + case Instruction::InsertElement: + case Instruction::ShuffleVector: + case Instruction::InsertValue: + case Instruction::GetElementPtr: + exp = createExpr(I); + break; + default: + break; + } + + if (!exp) { + ValueNumbering[V] = nextValueNumber; + return nextValueNumber++; + } + + uint32_t e = ExpressionNumbering[exp]; + if (!e) { + hash_code H = exp->getHashValue([=](Value *V) { return lookupOrAdd(V); }); + auto I = HashNumbering.find(H); + if (I != HashNumbering.end()) { + e = I->second; + } else { + e = nextValueNumber++; + HashNumbering[H] = e; + ExpressionNumbering[exp] = e; + } + } + ValueNumbering[V] = e; + return e; + } + + /// Returns the value number of the specified value. Fails if the value has + /// not yet been numbered. + uint32_t lookup(Value *V) const { + auto VI = ValueNumbering.find(V); + assert(VI != ValueNumbering.end() && "Value not numbered?"); + return VI->second; + } + + /// Removes all value numberings and resets the value table. + void clear() { + ValueNumbering.clear(); + ExpressionNumbering.clear(); + HashNumbering.clear(); + Recycler.clear(Allocator); + nextValueNumber = 1; + } + + /// \c Inst uses or touches memory. Return an ID describing the memory state + /// at \c Inst such that if getMemoryUseOrder(I1) == getMemoryUseOrder(I2), + /// the exact same memory operations happen after I1 and I2. + /// + /// This is a very hard problem in general, so we use domain-specific + /// knowledge that we only ever check for equivalence between blocks sharing a + /// single immediate successor that is common, and when determining if I1 == + /// I2 we will have already determined that next(I1) == next(I2). This + /// inductive property allows us to simply return the value number of the next + /// instruction that defines memory. 
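+  /// For example (hypothetical blocks), if two predecessors both end in
+  ///   store i32 %v, i32* %p
+  ///   br label %succ
+  /// then neither store is followed by another memory-defining instruction,
+  /// both get memory-use order 0, and the two stores may still compare equal.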
+ uint32_t getMemoryUseOrder(Instruction *Inst) { + auto *BB = Inst->getParent(); + for (auto I = std::next(Inst->getIterator()), E = BB->end(); + I != E && !I->isTerminator(); ++I) { + if (!isMemoryInst(&*I)) + continue; + if (isa<LoadInst>(&*I)) + continue; + CallInst *CI = dyn_cast<CallInst>(&*I); + if (CI && CI->onlyReadsMemory()) + continue; + InvokeInst *II = dyn_cast<InvokeInst>(&*I); + if (II && II->onlyReadsMemory()) + continue; + return lookupOrAdd(&*I); + } + return 0; + } +}; + +//===----------------------------------------------------------------------===// + +class GVNSink { +public: + GVNSink() = default; + + bool run(Function &F) { + DEBUG(dbgs() << "GVNSink: running on function @" << F.getName() << "\n"); + + unsigned NumSunk = 0; + ReversePostOrderTraversal<Function*> RPOT(&F); + for (auto *N : RPOT) + NumSunk += sinkBB(N); + + return NumSunk > 0; + } + +private: + ValueTable VN; + + bool isInstructionBlacklisted(Instruction *I) { + // These instructions may change or break semantics if moved. + if (isa<PHINode>(I) || I->isEHPad() || isa<AllocaInst>(I) || + I->getType()->isTokenTy()) + return true; + return false; + } + + /// The main heuristic function. Analyze the set of instructions pointed to by + /// LRI and return a candidate solution if these instructions can be sunk, or + /// None otherwise. + Optional<SinkingInstructionCandidate> analyzeInstructionForSinking( + LockstepReverseIterator &LRI, unsigned &InstNum, unsigned &MemoryInstNum, + ModelledPHISet &NeededPHIs, SmallPtrSetImpl<Value *> &PHIContents); + + /// Create a ModelledPHI for each PHI in BB, adding to PHIs. + void analyzeInitialPHIs(BasicBlock *BB, ModelledPHISet &PHIs, + SmallPtrSetImpl<Value *> &PHIContents) { + for (PHINode &PN : BB->phis()) { + auto MPHI = ModelledPHI(&PN); + PHIs.insert(MPHI); + for (auto *V : MPHI.getValues()) + PHIContents.insert(V); + } + } + + /// The main instruction sinking driver. Set up state and try and sink + /// instructions into BBEnd from its predecessors. + unsigned sinkBB(BasicBlock *BBEnd); + + /// Perform the actual mechanics of sinking an instruction from Blocks into + /// BBEnd, which is their only successor. + void sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, BasicBlock *BBEnd); + + /// Remove PHIs that all have the same incoming value. 
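+  /// E.g. (hypothetical IR), %p = phi i32 [ %x, %bb1 ], [ %x, %bb2 ] is
+  /// replaced by %x; the degenerate self-loop %p = phi i32 [ %p, %bb1 ],
+  /// [ %p, %bb2 ] is replaced by undef.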
+  void foldPointlessPHINodes(BasicBlock *BB) {
+    auto I = BB->begin();
+    while (PHINode *PN = dyn_cast<PHINode>(I++)) {
+      if (!llvm::all_of(PN->incoming_values(), [&](const Value *V) {
+            return V == PN->getIncomingValue(0);
+          }))
+        continue;
+      if (PN->getIncomingValue(0) != PN)
+        PN->replaceAllUsesWith(PN->getIncomingValue(0));
+      else
+        PN->replaceAllUsesWith(UndefValue::get(PN->getType()));
+      PN->eraseFromParent();
+    }
+  }
+};
+
+Optional<SinkingInstructionCandidate> GVNSink::analyzeInstructionForSinking(
+    LockstepReverseIterator &LRI, unsigned &InstNum, unsigned &MemoryInstNum,
+    ModelledPHISet &NeededPHIs, SmallPtrSetImpl<Value *> &PHIContents) {
+  auto Insts = *LRI;
+  DEBUG(dbgs() << " -- Analyzing instruction set: [\n"; for (auto *I : Insts) {
+    I->dump();
+  } dbgs() << " ]\n";);
+
+  DenseMap<uint32_t, unsigned> VNums;
+  for (auto *I : Insts) {
+    uint32_t N = VN.lookupOrAdd(I);
+    DEBUG(dbgs() << " VN=" << Twine::utohexstr(N) << " for" << *I << "\n");
+    if (N == ~0U)
+      return None;
+    VNums[N]++;
+  }
+  unsigned VNumToSink =
+      std::max_element(VNums.begin(), VNums.end(),
+                       [](const std::pair<uint32_t, unsigned> &I,
+                          const std::pair<uint32_t, unsigned> &J) {
+                         return I.second < J.second;
+                       })
+          ->first;
+
+  if (VNums[VNumToSink] == 1)
+    // Can't sink anything!
+    return None;
+
+  // Now restrict the number of incoming blocks down to only those with
+  // VNumToSink.
+  auto &ActivePreds = LRI.getActiveBlocks();
+  unsigned InitialActivePredSize = ActivePreds.size();
+  SmallVector<Instruction *, 4> NewInsts;
+  for (auto *I : Insts) {
+    if (VN.lookup(I) != VNumToSink)
+      ActivePreds.remove(I->getParent());
+    else
+      NewInsts.push_back(I);
+  }
+  for (auto *I : NewInsts)
+    if (isInstructionBlacklisted(I))
+      return None;
+
+  // If we've restricted the incoming blocks, restrict all needed PHIs also
+  // to that set.
+  bool RecomputePHIContents = false;
+  if (ActivePreds.size() != InitialActivePredSize) {
+    ModelledPHISet NewNeededPHIs;
+    for (auto P : NeededPHIs) {
+      P.restrictToBlocks(ActivePreds);
+      NewNeededPHIs.insert(P);
+    }
+    NeededPHIs = NewNeededPHIs;
+    LRI.restrictToBlocks(ActivePreds);
+    RecomputePHIContents = true;
+  }
+
+  // The sunk instruction's results.
+  ModelledPHI NewPHI(NewInsts, ActivePreds);
+
+  // Does sinking this instruction render previous PHIs redundant?
+  if (NeededPHIs.find(NewPHI) != NeededPHIs.end()) {
+    NeededPHIs.erase(NewPHI);
+    RecomputePHIContents = true;
+  }
+
+  if (RecomputePHIContents) {
+    // The needed PHIs have changed, so recompute the set of all needed
+    // values.
+    PHIContents.clear();
+    for (auto &PHI : NeededPHIs)
+      PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
+  }
+
+  // Is this instruction required by a later PHI that doesn't match this PHI?
+  // If so, we can't sink this instruction.
+  for (auto *V : NewPHI.getValues())
+    if (PHIContents.count(V))
+      // V exists in this PHI, but the whole PHI is different to NewPHI
+      // (else it would have been removed earlier). We cannot continue
+      // because this isn't representable.
+      return None;
+
+  // Which operands need PHIs?
+  // FIXME: If any of these fail, we should partition up the candidates to
+  // try and continue making progress.
+  Instruction *I0 = NewInsts[0];
+  for (unsigned OpNum = 0, E = I0->getNumOperands(); OpNum != E; ++OpNum) {
+    ModelledPHI PHI(NewInsts, OpNum, ActivePreds);
+    if (PHI.areAllIncomingValuesSame())
+      continue;
+    if (!canReplaceOperandWithVariable(I0, OpNum))
+      // We can't create a PHI from this instruction!
+      return None;
+    if (NeededPHIs.count(PHI))
+      continue;
+    if (!PHI.areAllIncomingValuesSameType())
+      return None;
+    // Don't create indirect calls! The called value is the final operand.
+    if ((isa<CallInst>(I0) || isa<InvokeInst>(I0)) && OpNum == E - 1 &&
+        PHI.areAnyIncomingValuesConstant())
+      return None;
+
+    NeededPHIs.reserve(NeededPHIs.size());
+    NeededPHIs.insert(PHI);
+    PHIContents.insert(PHI.getValues().begin(), PHI.getValues().end());
+  }
+
+  if (isMemoryInst(NewInsts[0]))
+    ++MemoryInstNum;
+
+  SinkingInstructionCandidate Cand;
+  Cand.NumInstructions = ++InstNum;
+  Cand.NumMemoryInsts = MemoryInstNum;
+  Cand.NumBlocks = ActivePreds.size();
+  Cand.NumPHIs = NeededPHIs.size();
+  for (auto *C : ActivePreds)
+    Cand.Blocks.push_back(C);
+
+  return Cand;
+}
+
+unsigned GVNSink::sinkBB(BasicBlock *BBEnd) {
+  DEBUG(dbgs() << "GVNSink: running on basic block ";
+        BBEnd->printAsOperand(dbgs()); dbgs() << "\n");
+  SmallVector<BasicBlock *, 4> Preds;
+  for (auto *B : predecessors(BBEnd)) {
+    auto *T = B->getTerminator();
+    if (isa<BranchInst>(T) || isa<SwitchInst>(T))
+      Preds.push_back(B);
+    else
+      return 0;
+  }
+  if (Preds.size() < 2)
+    return 0;
+  std::sort(Preds.begin(), Preds.end());
+
+  unsigned NumOrigPreds = Preds.size();
+  // We can only sink instructions through unconditional branches.
+  for (auto I = Preds.begin(); I != Preds.end();) {
+    if ((*I)->getTerminator()->getNumSuccessors() != 1)
+      I = Preds.erase(I);
+    else
+      ++I;
+  }
+
+  LockstepReverseIterator LRI(Preds);
+  SmallVector<SinkingInstructionCandidate, 4> Candidates;
+  unsigned InstNum = 0, MemoryInstNum = 0;
+  ModelledPHISet NeededPHIs;
+  SmallPtrSet<Value *, 4> PHIContents;
+  analyzeInitialPHIs(BBEnd, NeededPHIs, PHIContents);
+  unsigned NumOrigPHIs = NeededPHIs.size();
+
+  while (LRI.isValid()) {
+    auto Cand = analyzeInstructionForSinking(LRI, InstNum, MemoryInstNum,
+                                             NeededPHIs, PHIContents);
+    if (!Cand)
+      break;
+    Cand->calculateCost(NumOrigPHIs, Preds.size());
+    Candidates.emplace_back(*Cand);
+    --LRI;
+  }
+
+  std::stable_sort(
+      Candidates.begin(), Candidates.end(),
+      [](const SinkingInstructionCandidate &A,
+         const SinkingInstructionCandidate &B) { return A > B; });
+  DEBUG(dbgs() << " -- Sinking candidates:\n"; for (auto &C : Candidates)
+            dbgs() << "  " << C << "\n";);
+
+  // Pick the top candidate, as long as it is positive!
+  if (Candidates.empty() || Candidates.front().Cost <= 0)
+    return 0;
+  auto C = Candidates.front();
+
+  DEBUG(dbgs() << " -- Sinking: " << C << "\n");
+  BasicBlock *InsertBB = BBEnd;
+  if (C.Blocks.size() < NumOrigPreds) {
+    DEBUG(dbgs() << " -- Splitting edge to "; BBEnd->printAsOperand(dbgs());
+          dbgs() << "\n");
+    InsertBB = SplitBlockPredecessors(BBEnd, C.Blocks, ".gvnsink.split");
+    if (!InsertBB) {
+      DEBUG(dbgs() << " -- FAILED to split edge!\n");
+      // Edge couldn't be split.
+ return 0; + } + } + + for (unsigned I = 0; I < C.NumInstructions; ++I) + sinkLastInstruction(C.Blocks, InsertBB); + + return C.NumInstructions; +} + +void GVNSink::sinkLastInstruction(ArrayRef<BasicBlock *> Blocks, + BasicBlock *BBEnd) { + SmallVector<Instruction *, 4> Insts; + for (BasicBlock *BB : Blocks) + Insts.push_back(BB->getTerminator()->getPrevNode()); + Instruction *I0 = Insts.front(); + + SmallVector<Value *, 4> NewOperands; + for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) { + bool NeedPHI = llvm::any_of(Insts, [&I0, O](const Instruction *I) { + return I->getOperand(O) != I0->getOperand(O); + }); + if (!NeedPHI) { + NewOperands.push_back(I0->getOperand(O)); + continue; + } + + // Create a new PHI in the successor block and populate it. + auto *Op = I0->getOperand(O); + assert(!Op->getType()->isTokenTy() && "Can't PHI tokens!"); + auto *PN = PHINode::Create(Op->getType(), Insts.size(), + Op->getName() + ".sink", &BBEnd->front()); + for (auto *I : Insts) + PN->addIncoming(I->getOperand(O), I->getParent()); + NewOperands.push_back(PN); + } + + // Arbitrarily use I0 as the new "common" instruction; remap its operands + // and move it to the start of the successor block. + for (unsigned O = 0, E = I0->getNumOperands(); O != E; ++O) + I0->getOperandUse(O).set(NewOperands[O]); + I0->moveBefore(&*BBEnd->getFirstInsertionPt()); + + // Update metadata and IR flags. + for (auto *I : Insts) + if (I != I0) { + combineMetadataForCSE(I0, I); + I0->andIRFlags(I); + } + + for (auto *I : Insts) + if (I != I0) + I->replaceAllUsesWith(I0); + foldPointlessPHINodes(BBEnd); + + // Finally nuke all instructions apart from the common instruction. + for (auto *I : Insts) + if (I != I0) + I->eraseFromParent(); + + NumRemoved += Insts.size() - 1; +} + +//////////////////////////////////////////////////////////////////////////////// +// Pass machinery / boilerplate + +class GVNSinkLegacyPass : public FunctionPass { +public: + static char ID; + + GVNSinkLegacyPass() : FunctionPass(ID) { + initializeGVNSinkLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + GVNSink G; + return G.run(F); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addPreserved<GlobalsAAWrapperPass>(); + } +}; + +} // end anonymous namespace + +PreservedAnalyses GVNSinkPass::run(Function &F, FunctionAnalysisManager &AM) { + GVNSink G; + if (!G.run(F)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + return PA; +} + +char GVNSinkLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(GVNSinkLegacyPass, "gvn-sink", + "Early GVN sinking of Expressions", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_END(GVNSinkLegacyPass, "gvn-sink", + "Early GVN sinking of Expressions", false, false) + +FunctionPass *llvm::createGVNSinkPass() { return new GVNSinkLegacyPass(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/GuardWidening.cpp b/contrib/llvm/lib/Transforms/Scalar/GuardWidening.cpp new file mode 100644 index 000000000000..c4aeccb85ca7 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/GuardWidening.cpp @@ -0,0 +1,696 @@ +//===- GuardWidening.cpp - ---- Guard widening ----------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the guard widening pass. The semantics of the
+// @llvm.experimental.guard intrinsic let LLVM transform it so that it fails
+// more often than it did before the transform. This optimization is called
+// "widening" and can be used to hoist and common runtime checks in situations
+// like these:
+//
+//   %cmp0 = 7 u< Length
+//   call @llvm.experimental.guard(i1 %cmp0) [ "deopt"(...) ]
+//   call @unknown_side_effects()
+//   %cmp1 = 9 u< Length
+//   call @llvm.experimental.guard(i1 %cmp1) [ "deopt"(...) ]
+//   ...
+//
+// =>
+//
+//   %cmp0 = 9 u< Length
+//   call @llvm.experimental.guard(i1 %cmp0) [ "deopt"(...) ]
+//   call @unknown_side_effects()
+//   ...
+//
+// If %cmp0 is false, @llvm.experimental.guard will "deoptimize" back to a
+// generic implementation of the same function, which will have the correct
+// semantics from that point onward. It is always _legal_ to deoptimize (so
+// replacing %cmp0 with false is "correct"), though it may not always be
+// profitable to do so.
+//
+// NB! This pass is a work in progress. It hasn't been tuned to be "production
+// ready" yet. It is known to have quadratic running time and will not scale
+// to large numbers of guards.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/GuardWidening.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/PostDominators.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/ConstantRange.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/KnownBits.h"
+#include "llvm/Transforms/Scalar.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "guard-widening"
+
+namespace {
+
+class GuardWideningImpl {
+  DominatorTree &DT;
+  PostDominatorTree &PDT;
+  LoopInfo &LI;
+
+  /// The set of guards whose conditions have been widened into dominating
+  /// guards.
+  SmallVector<IntrinsicInst *, 16> EliminatedGuards;
+
+  /// The set of guards which have been widened to include conditions from
+  /// other guards.
+  DenseSet<IntrinsicInst *> WidenedGuards;
+
+  /// Try to eliminate guard \p Guard by widening it into an earlier dominating
+  /// guard. \p DFSI is the DFS iterator on the dominator tree that is
+  /// currently visiting the block containing \p Guard, and \p GuardsPerBlock
+  /// maps BasicBlocks to the set of guards seen in that block.
+  bool eliminateGuardViaWidening(
+      IntrinsicInst *Guard, const df_iterator<DomTreeNode *> &DFSI,
+      const DenseMap<BasicBlock *, SmallVector<IntrinsicInst *, 8>> &
+          GuardsPerBlock);
+
+  /// Used to keep track of which widening potential is more effective.
+  enum WideningScore {
+    /// Don't widen.
+    WS_IllegalOrNegative,
+
+    /// Widening is performance neutral as far as the cycles spent in check
+    /// conditions go (but can still help, e.g., code layout, having less
+    /// deopt state).
+    WS_Neutral,
+
+    /// Widening is profitable.
+    WS_Positive,
+
+    /// Widening is very profitable. Not significantly different from \c
+    /// WS_Positive, except by the order.
+    WS_VeryPositive
+  };
+
+  static StringRef scoreTypeToString(WideningScore WS);
+
+  /// Compute the score for widening the condition in \p DominatedGuard
+  /// (contained in \p DominatedGuardLoop) into \p DominatingGuard (contained
+  /// in \p DominatingGuardLoop).
+  WideningScore computeWideningScore(IntrinsicInst *DominatedGuard,
+                                     Loop *DominatedGuardLoop,
+                                     IntrinsicInst *DominatingGuard,
+                                     Loop *DominatingGuardLoop);
+
+  /// Helper to check if \p V can be hoisted to \p InsertPos.
+  bool isAvailableAt(Value *V, Instruction *InsertPos) {
+    SmallPtrSet<Instruction *, 8> Visited;
+    return isAvailableAt(V, InsertPos, Visited);
+  }
+
+  bool isAvailableAt(Value *V, Instruction *InsertPos,
+                     SmallPtrSetImpl<Instruction *> &Visited);
+
+  /// Helper to hoist \p V to \p InsertPos. Guaranteed to succeed if \c
+  /// isAvailableAt returned true.
+  void makeAvailableAt(Value *V, Instruction *InsertPos);
+
+  /// Common helper used by \c widenGuard and \c isWideningCondProfitable. Try
+  /// to generate an expression computing the logical AND of \p Cond0 and \p
+  /// Cond1. Return true if the expression computing the AND is only as
+  /// expensive as computing one of the two. If \p InsertPt is non-null then
+  /// actually generate the resulting expression, make it available at \p
+  /// InsertPt and return it in \p Result (else no change to the IR is made).
+  bool widenCondCommon(Value *Cond0, Value *Cond1, Instruction *InsertPt,
+                       Value *&Result);
+
+  /// Represents a range check of the form \c Base + \c Offset u< \c Length,
+  /// with the constraint that \c Length is not negative. \c CheckInst is the
+  /// pre-existing instruction in the IR that computes the result of this range
+  /// check.
+  class RangeCheck {
+    Value *Base;
+    ConstantInt *Offset;
+    Value *Length;
+    ICmpInst *CheckInst;
+
+  public:
+    explicit RangeCheck(Value *Base, ConstantInt *Offset, Value *Length,
+                        ICmpInst *CheckInst)
+        : Base(Base), Offset(Offset), Length(Length), CheckInst(CheckInst) {}
+
+    void setBase(Value *NewBase) { Base = NewBase; }
+    void setOffset(ConstantInt *NewOffset) { Offset = NewOffset; }
+
+    Value *getBase() const { return Base; }
+    ConstantInt *getOffset() const { return Offset; }
+    const APInt &getOffsetValue() const { return getOffset()->getValue(); }
+    Value *getLength() const { return Length; }
+    ICmpInst *getCheckInst() const { return CheckInst; }
+
+    void print(raw_ostream &OS, bool PrintTypes = false) {
+      OS << "Base: ";
+      Base->printAsOperand(OS, PrintTypes);
+      OS << " Offset: ";
+      Offset->printAsOperand(OS, PrintTypes);
+      OS << " Length: ";
+      Length->printAsOperand(OS, PrintTypes);
+    }
+
+    LLVM_DUMP_METHOD void dump() {
+      print(dbgs());
+      dbgs() << "\n";
+    }
+  };
+
+  /// Parse \p CheckCond into a conjunction (logical-and) of range checks, and
+  /// append them to \p Checks. Returns true on success; may clobber \c Checks
+  /// on failure.
+  bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks) {
+    SmallPtrSet<Value *, 8> Visited;
+    return parseRangeChecks(CheckCond, Checks, Visited);
+  }
+
+  bool parseRangeChecks(Value *CheckCond, SmallVectorImpl<RangeCheck> &Checks,
+                        SmallPtrSetImpl<Value *> &Visited);
+
+  /// Combine the checks in \p Checks into a smaller set of checks and append
+  /// them into \p CombinedChecks. Return true on success (i.e. all of the
+  /// checks in \p Checks were combined into \p CombinedChecks). Clobbers \p
+  /// Checks and \p CombinedChecks on success and on failure.
+ bool combineRangeChecks(SmallVectorImpl<RangeCheck> &Checks, + SmallVectorImpl<RangeCheck> &CombinedChecks); + + /// Can we compute the logical AND of \p Cond0 and \p Cond1 for the price of + /// computing only one of the two expressions? + bool isWideningCondProfitable(Value *Cond0, Value *Cond1) { + Value *ResultUnused; + return widenCondCommon(Cond0, Cond1, /*InsertPt=*/nullptr, ResultUnused); + } + + /// Widen \p ToWiden to fail if \p NewCondition is false (in addition to + /// whatever it is already checking). + void widenGuard(IntrinsicInst *ToWiden, Value *NewCondition) { + Value *Result; + widenCondCommon(ToWiden->getArgOperand(0), NewCondition, ToWiden, Result); + ToWiden->setArgOperand(0, Result); + } + +public: + explicit GuardWideningImpl(DominatorTree &DT, PostDominatorTree &PDT, + LoopInfo &LI) + : DT(DT), PDT(PDT), LI(LI) {} + + /// The entry point for this pass. + bool run(); +}; + +struct GuardWideningLegacyPass : public FunctionPass { + static char ID; + GuardWideningPass Impl; + + GuardWideningLegacyPass() : FunctionPass(ID) { + initializeGuardWideningLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + return GuardWideningImpl( + getAnalysis<DominatorTreeWrapperPass>().getDomTree(), + getAnalysis<PostDominatorTreeWrapperPass>().getPostDomTree(), + getAnalysis<LoopInfoWrapperPass>().getLoopInfo()).run(); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<PostDominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + } +}; + +} + +bool GuardWideningImpl::run() { + using namespace llvm::PatternMatch; + + DenseMap<BasicBlock *, SmallVector<IntrinsicInst *, 8>> GuardsInBlock; + bool Changed = false; + + for (auto DFI = df_begin(DT.getRootNode()), DFE = df_end(DT.getRootNode()); + DFI != DFE; ++DFI) { + auto *BB = (*DFI)->getBlock(); + auto &CurrentList = GuardsInBlock[BB]; + + for (auto &I : *BB) + if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>())) + CurrentList.push_back(cast<IntrinsicInst>(&I)); + + for (auto *II : CurrentList) + Changed |= eliminateGuardViaWidening(II, DFI, GuardsInBlock); + } + + for (auto *II : EliminatedGuards) + if (!WidenedGuards.count(II)) + II->eraseFromParent(); + + return Changed; +} + +bool GuardWideningImpl::eliminateGuardViaWidening( + IntrinsicInst *GuardInst, const df_iterator<DomTreeNode *> &DFSI, + const DenseMap<BasicBlock *, SmallVector<IntrinsicInst *, 8>> & + GuardsInBlock) { + IntrinsicInst *BestSoFar = nullptr; + auto BestScoreSoFar = WS_IllegalOrNegative; + auto *GuardInstLoop = LI.getLoopFor(GuardInst->getParent()); + + // In the set of dominating guards, find the one we can merge GuardInst with + // for the most profit. 
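+  //
+  // To illustrate the search (with a made-up CFG, not from any test case):
+  // if the dominator-tree DFS path to GuardInst's block is entry -> A -> B,
+  // then DFSI.getPath() lets us walk entry, A and B in order, and every guard
+  // already recorded for those blocks (up to, but excluding, GuardInst
+  // itself) is a candidate; the loop below scores each one and remembers the
+  // best.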
+ for (unsigned i = 0, e = DFSI.getPathLength(); i != e; ++i) { + auto *CurBB = DFSI.getPath(i)->getBlock(); + auto *CurLoop = LI.getLoopFor(CurBB); + assert(GuardsInBlock.count(CurBB) && "Must have been populated by now!"); + const auto &GuardsInCurBB = GuardsInBlock.find(CurBB)->second; + + auto I = GuardsInCurBB.begin(); + auto E = GuardsInCurBB.end(); + +#ifndef NDEBUG + { + unsigned Index = 0; + for (auto &I : *CurBB) { + if (Index == GuardsInCurBB.size()) + break; + if (GuardsInCurBB[Index] == &I) + Index++; + } + assert(Index == GuardsInCurBB.size() && + "Guards expected to be in order!"); + } +#endif + + assert((i == (e - 1)) == (GuardInst->getParent() == CurBB) && "Bad DFS?"); + + if (i == (e - 1)) { + // Corner case: make sure we're only looking at guards strictly dominating + // GuardInst when visiting GuardInst->getParent(). + auto NewEnd = std::find(I, E, GuardInst); + assert(NewEnd != E && "GuardInst not in its own block?"); + E = NewEnd; + } + + for (auto *Candidate : make_range(I, E)) { + auto Score = + computeWideningScore(GuardInst, GuardInstLoop, Candidate, CurLoop); + DEBUG(dbgs() << "Score between " << *GuardInst->getArgOperand(0) + << " and " << *Candidate->getArgOperand(0) << " is " + << scoreTypeToString(Score) << "\n"); + if (Score > BestScoreSoFar) { + BestScoreSoFar = Score; + BestSoFar = Candidate; + } + } + } + + if (BestScoreSoFar == WS_IllegalOrNegative) { + DEBUG(dbgs() << "Did not eliminate guard " << *GuardInst << "\n"); + return false; + } + + assert(BestSoFar != GuardInst && "Should have never visited same guard!"); + assert(DT.dominates(BestSoFar, GuardInst) && "Should be!"); + + DEBUG(dbgs() << "Widening " << *GuardInst << " into " << *BestSoFar + << " with score " << scoreTypeToString(BestScoreSoFar) << "\n"); + widenGuard(BestSoFar, GuardInst->getArgOperand(0)); + GuardInst->setArgOperand(0, ConstantInt::getTrue(GuardInst->getContext())); + EliminatedGuards.push_back(GuardInst); + WidenedGuards.insert(BestSoFar); + return true; +} + +GuardWideningImpl::WideningScore GuardWideningImpl::computeWideningScore( + IntrinsicInst *DominatedGuard, Loop *DominatedGuardLoop, + IntrinsicInst *DominatingGuard, Loop *DominatingGuardLoop) { + bool HoistingOutOfLoop = false; + + if (DominatingGuardLoop != DominatedGuardLoop) { + if (DominatingGuardLoop && + !DominatingGuardLoop->contains(DominatedGuardLoop)) + return WS_IllegalOrNegative; + + HoistingOutOfLoop = true; + } + + if (!isAvailableAt(DominatedGuard->getArgOperand(0), DominatingGuard)) + return WS_IllegalOrNegative; + + bool HoistingOutOfIf = + !PDT.dominates(DominatedGuard->getParent(), DominatingGuard->getParent()); + + if (isWideningCondProfitable(DominatedGuard->getArgOperand(0), + DominatingGuard->getArgOperand(0))) + return HoistingOutOfLoop ? WS_VeryPositive : WS_Positive; + + if (HoistingOutOfLoop) + return WS_Positive; + + return HoistingOutOfIf ? WS_IllegalOrNegative : WS_Neutral; +} + +bool GuardWideningImpl::isAvailableAt(Value *V, Instruction *Loc, + SmallPtrSetImpl<Instruction *> &Visited) { + auto *Inst = dyn_cast<Instruction>(V); + if (!Inst || DT.dominates(Inst, Loc) || Visited.count(Inst)) + return true; + + if (!isSafeToSpeculativelyExecute(Inst, Loc, &DT) || + Inst->mayReadFromMemory()) + return false; + + Visited.insert(Inst); + + // We only want to go _up_ the dominance chain when recursing. 
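+  //
+  // As a sketch of the recursion's intent (operand names are hypothetical):
+  // for V defined as
+  //   %v = add i32 %a, %b
+  // where %v does not already dominate Loc, %v is available at Loc only if it
+  // is speculatable, does not read memory, and %a and %b are themselves
+  // available there. The asserts below document the invariants this relies on.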
+ assert(!isa<PHINode>(Loc) && + "PHIs should return false for isSafeToSpeculativelyExecute"); + assert(DT.isReachableFromEntry(Inst->getParent()) && + "We did a DFS from the block entry!"); + return all_of(Inst->operands(), + [&](Value *Op) { return isAvailableAt(Op, Loc, Visited); }); +} + +void GuardWideningImpl::makeAvailableAt(Value *V, Instruction *Loc) { + auto *Inst = dyn_cast<Instruction>(V); + if (!Inst || DT.dominates(Inst, Loc)) + return; + + assert(isSafeToSpeculativelyExecute(Inst, Loc, &DT) && + !Inst->mayReadFromMemory() && "Should've checked with isAvailableAt!"); + + for (Value *Op : Inst->operands()) + makeAvailableAt(Op, Loc); + + Inst->moveBefore(Loc); +} + +bool GuardWideningImpl::widenCondCommon(Value *Cond0, Value *Cond1, + Instruction *InsertPt, Value *&Result) { + using namespace llvm::PatternMatch; + + { + // L >u C0 && L >u C1 -> L >u max(C0, C1) + ConstantInt *RHS0, *RHS1; + Value *LHS; + ICmpInst::Predicate Pred0, Pred1; + if (match(Cond0, m_ICmp(Pred0, m_Value(LHS), m_ConstantInt(RHS0))) && + match(Cond1, m_ICmp(Pred1, m_Specific(LHS), m_ConstantInt(RHS1)))) { + + ConstantRange CR0 = + ConstantRange::makeExactICmpRegion(Pred0, RHS0->getValue()); + ConstantRange CR1 = + ConstantRange::makeExactICmpRegion(Pred1, RHS1->getValue()); + + // SubsetIntersect is a subset of the actual mathematical intersection of + // CR0 and CR1, while SupersetIntersect is a superset of the actual + // mathematical intersection. If these two ConstantRanges are equal, then + // we know we were able to represent the actual mathematical intersection + // of CR0 and CR1, and can use the same to generate an icmp instruction. + // + // Given what we're doing here and the semantics of guards, it would + // actually be correct to just use SubsetIntersect, but that may be too + // aggressive in cases we care about. + auto SubsetIntersect = CR0.inverse().unionWith(CR1.inverse()).inverse(); + auto SupersetIntersect = CR0.intersectWith(CR1); + + APInt NewRHSAP; + CmpInst::Predicate Pred; + if (SubsetIntersect == SupersetIntersect && + SubsetIntersect.getEquivalentICmp(Pred, NewRHSAP)) { + if (InsertPt) { + ConstantInt *NewRHS = ConstantInt::get(Cond0->getContext(), NewRHSAP); + Result = new ICmpInst(InsertPt, Pred, LHS, NewRHS, "wide.chk"); + } + return true; + } + } + } + + { + SmallVector<GuardWideningImpl::RangeCheck, 4> Checks, CombinedChecks; + if (parseRangeChecks(Cond0, Checks) && parseRangeChecks(Cond1, Checks) && + combineRangeChecks(Checks, CombinedChecks)) { + if (InsertPt) { + Result = nullptr; + for (auto &RC : CombinedChecks) { + makeAvailableAt(RC.getCheckInst(), InsertPt); + if (Result) + Result = BinaryOperator::CreateAnd(RC.getCheckInst(), Result, "", + InsertPt); + else + Result = RC.getCheckInst(); + } + + Result->setName("wide.chk"); + } + return true; + } + } + + // Base case -- just logical-and the two conditions together. + + if (InsertPt) { + makeAvailableAt(Cond0, InsertPt); + makeAvailableAt(Cond1, InsertPt); + + Result = BinaryOperator::CreateAnd(Cond0, Cond1, "wide.chk", InsertPt); + } + + // We were not able to compute Cond0 AND Cond1 for the price of one. 
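+  //
+  // A worked (hypothetical) instance of the constant-range fast path above:
+  // for Cond0 = (%x u< 10) and Cond1 = (%x u< 20), CR0 = [0, 10) and
+  // CR1 = [0, 20); both the subset and superset approximations of the
+  // intersection are [0, 10), so getEquivalentICmp recovers the single check
+  // (%x u< 10) and the two guards merge for free. When no such exact form
+  // exists, we reach this point and report failure.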
+ return false; +} + +bool GuardWideningImpl::parseRangeChecks( + Value *CheckCond, SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks, + SmallPtrSetImpl<Value *> &Visited) { + if (!Visited.insert(CheckCond).second) + return true; + + using namespace llvm::PatternMatch; + + { + Value *AndLHS, *AndRHS; + if (match(CheckCond, m_And(m_Value(AndLHS), m_Value(AndRHS)))) + return parseRangeChecks(AndLHS, Checks) && + parseRangeChecks(AndRHS, Checks); + } + + auto *IC = dyn_cast<ICmpInst>(CheckCond); + if (!IC || !IC->getOperand(0)->getType()->isIntegerTy() || + (IC->getPredicate() != ICmpInst::ICMP_ULT && + IC->getPredicate() != ICmpInst::ICMP_UGT)) + return false; + + Value *CmpLHS = IC->getOperand(0), *CmpRHS = IC->getOperand(1); + if (IC->getPredicate() == ICmpInst::ICMP_UGT) + std::swap(CmpLHS, CmpRHS); + + auto &DL = IC->getModule()->getDataLayout(); + + GuardWideningImpl::RangeCheck Check( + CmpLHS, cast<ConstantInt>(ConstantInt::getNullValue(CmpRHS->getType())), + CmpRHS, IC); + + if (!isKnownNonNegative(Check.getLength(), DL)) + return false; + + // What we have in \c Check now is a correct interpretation of \p CheckCond. + // Try to see if we can move some constant offsets into the \c Offset field. + + bool Changed; + auto &Ctx = CheckCond->getContext(); + + do { + Value *OpLHS; + ConstantInt *OpRHS; + Changed = false; + +#ifndef NDEBUG + auto *BaseInst = dyn_cast<Instruction>(Check.getBase()); + assert((!BaseInst || DT.isReachableFromEntry(BaseInst->getParent())) && + "Unreachable instruction?"); +#endif + + if (match(Check.getBase(), m_Add(m_Value(OpLHS), m_ConstantInt(OpRHS)))) { + Check.setBase(OpLHS); + APInt NewOffset = Check.getOffsetValue() + OpRHS->getValue(); + Check.setOffset(ConstantInt::get(Ctx, NewOffset)); + Changed = true; + } else if (match(Check.getBase(), + m_Or(m_Value(OpLHS), m_ConstantInt(OpRHS)))) { + KnownBits Known = computeKnownBits(OpLHS, DL); + if ((OpRHS->getValue() & Known.Zero) == OpRHS->getValue()) { + Check.setBase(OpLHS); + APInt NewOffset = Check.getOffsetValue() + OpRHS->getValue(); + Check.setOffset(ConstantInt::get(Ctx, NewOffset)); + Changed = true; + } + } + } while (Changed); + + Checks.push_back(Check); + return true; +} + +bool GuardWideningImpl::combineRangeChecks( + SmallVectorImpl<GuardWideningImpl::RangeCheck> &Checks, + SmallVectorImpl<GuardWideningImpl::RangeCheck> &RangeChecksOut) { + unsigned OldCount = Checks.size(); + while (!Checks.empty()) { + // Pick all of the range checks with a specific base and length, and try to + // merge them. + Value *CurrentBase = Checks.front().getBase(); + Value *CurrentLength = Checks.front().getLength(); + + SmallVector<GuardWideningImpl::RangeCheck, 3> CurrentChecks; + + auto IsCurrentCheck = [&](GuardWideningImpl::RangeCheck &RC) { + return RC.getBase() == CurrentBase && RC.getLength() == CurrentLength; + }; + + copy_if(Checks, std::back_inserter(CurrentChecks), IsCurrentCheck); + Checks.erase(remove_if(Checks, IsCurrentCheck), Checks.end()); + + assert(CurrentChecks.size() != 0 && "We know we have at least one!"); + + if (CurrentChecks.size() < 3) { + RangeChecksOut.insert(RangeChecksOut.end(), CurrentChecks.begin(), + CurrentChecks.end()); + continue; + } + + // CurrentChecks.size() will typically be 3 here, but so far there has been + // no need to hard-code that fact. 
+
+    std::sort(CurrentChecks.begin(), CurrentChecks.end(),
+              [&](const GuardWideningImpl::RangeCheck &LHS,
+                  const GuardWideningImpl::RangeCheck &RHS) {
+      return LHS.getOffsetValue().slt(RHS.getOffsetValue());
+    });
+
+    // Note: std::sort does not invalidate the iterators into CurrentChecks.
+
+    ConstantInt *MinOffset = CurrentChecks.front().getOffset(),
+                *MaxOffset = CurrentChecks.back().getOffset();
+
+    unsigned BitWidth = MaxOffset->getValue().getBitWidth();
+    if ((MaxOffset->getValue() - MinOffset->getValue())
+            .ugt(APInt::getSignedMinValue(BitWidth)))
+      return false;
+
+    APInt MaxDiff = MaxOffset->getValue() - MinOffset->getValue();
+    const APInt &HighOffset = MaxOffset->getValue();
+    auto OffsetOK = [&](const GuardWideningImpl::RangeCheck &RC) {
+      return (HighOffset - RC.getOffsetValue()).ult(MaxDiff);
+    };
+
+    if (MaxDiff.isMinValue() ||
+        !std::all_of(std::next(CurrentChecks.begin()), CurrentChecks.end(),
+                     OffsetOK))
+      return false;
+
+    // We have a series of f+1 checks of the form:
+    //
+    //   I+k_0 u< L   ... Chk_0
+    //   I+k_1 u< L   ... Chk_1
+    //   ...
+    //   I+k_f u< L   ... Chk_f
+    //
+    //   with forall i in [0,f]: k_f-k_i u< k_f-k_0   ... Precond_0
+    //                           k_f-k_0 u< INT_MIN+k_f ... Precond_1
+    //                           k_f != k_0             ... Precond_2
+    //
+    // Claim:
+    //   Chk_0 AND Chk_f implies all the other checks
+    //
+    // Informal proof sketch:
+    //
+    //   We will show that the integer range [I+k_0,I+k_f] does not
+    //   unsigned-wrap (i.e. going from I+k_0 to I+k_f does not cross the -1,0
+    //   boundary) and thus I+k_f is the greatest unsigned value in that range.
+    //
+    //   This combined with Chk_f shows that everything in that range is u< L.
+    //   Via Precond_0 we know that all of the indices in Chk_0 through Chk_f
+    //   lie in [I+k_0,I+k_f], thus proving our claim.
+    //
+    //   To see that [I+k_0,I+k_f] is not a wrapping range, note that there are
+    //   two possibilities: I+k_0 u< I+k_f or I+k_0 >u I+k_f (they can't be
+    //   equal since k_0 != k_f). In the former case, [I+k_0,I+k_f] is not a
+    //   wrapping range by definition, and the latter case is impossible:
+    //
+    //     0-----I+k_f---I+k_0----L---INT_MAX,INT_MIN------------------(-1)
+    //     xxxxxx                                    xxxxxxxxxxxxxxxxxxxxxxxxx
+    //
+    //   For Chk_0 to succeed, we'd need k_f-k_0 (the range highlighted with
+    //   'x' above) to be at least >u INT_MIN.
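+    //
+    // A concrete (hypothetical) instance of the above: for the three checks
+    // I+0 u< L, I+1 u< L and I+3 u< L we have k_0 = 0 and k_f = 3. MaxDiff is
+    // 3, the middle offset passes OffsetOK (3-1 u< 3), and 3-0 stays below
+    // INT_MIN, so emitting just Chk_0 and Chk_f below covers all three.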
+ + RangeChecksOut.emplace_back(CurrentChecks.front()); + RangeChecksOut.emplace_back(CurrentChecks.back()); + } + + assert(RangeChecksOut.size() <= OldCount && "We pessimized!"); + return RangeChecksOut.size() != OldCount; +} + +PreservedAnalyses GuardWideningPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &PDT = AM.getResult<PostDominatorTreeAnalysis>(F); + if (!GuardWideningImpl(DT, PDT, LI).run()) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; +} + +#ifndef NDEBUG +StringRef GuardWideningImpl::scoreTypeToString(WideningScore WS) { + switch (WS) { + case WS_IllegalOrNegative: + return "IllegalOrNegative"; + case WS_Neutral: + return "Neutral"; + case WS_Positive: + return "Positive"; + case WS_VeryPositive: + return "VeryPositive"; + } + + llvm_unreachable("Fully covered switch above!"); +} +#endif + +char GuardWideningLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(GuardWideningLegacyPass, "guard-widening", "Widen guards", + false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(PostDominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_END(GuardWideningLegacyPass, "guard-widening", "Widen guards", + false, false) + +FunctionPass *llvm::createGuardWideningPass() { + return new GuardWideningLegacyPass(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/IVUsersPrinter.cpp b/contrib/llvm/lib/Transforms/Scalar/IVUsersPrinter.cpp new file mode 100644 index 000000000000..807593379283 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/IVUsersPrinter.cpp @@ -0,0 +1,22 @@ +//===- IVUsersPrinter.cpp - Induction Variable Users Printer ----*- C++ -*-===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/IVUsersPrinter.h" +#include "llvm/Analysis/IVUsers.h" +#include "llvm/Support/Debug.h" +using namespace llvm; + +#define DEBUG_TYPE "iv-users" + +PreservedAnalyses IVUsersPrinterPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &U) { + AM.getResult<IVUsersAnalysis>(L, AR).print(OS); + return PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp new file mode 100644 index 000000000000..221fe57581ca --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -0,0 +1,2597 @@ +//===- IndVarSimplify.cpp - Induction Variable Elimination ----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This transformation analyzes and transforms the induction variables (and +// computations derived from them) into simpler forms suitable for subsequent +// analysis and transformation. +// +// If the trip count of a loop is computable, this pass also makes the following +// changes: +// 1. The exit condition for the loop is canonicalized to compare the +// induction value against the exit value. 
This turns loops like: +// 'for (i = 7; i*i < 1000; ++i)' into 'for (i = 0; i != 25; ++i)' +// 2. Any use outside of the loop of an expression derived from the indvar +// is changed to compute the derived value outside of the loop, eliminating +// the dependence on the exit value of the induction variable. If the only +// purpose of the loop is to compute the exit value of some derived +// expression, this transformation will make the loop dead. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/IndVarSimplify.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/SimplifyIndVar.h" +#include <cassert> +#include <cstdint> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "indvars" + +STATISTIC(NumWidened , "Number of indvars widened"); +STATISTIC(NumReplaced , "Number of exit values replaced"); +STATISTIC(NumLFTR , "Number of loop exit tests replaced"); +STATISTIC(NumElimExt , "Number of IV sign/zero extends eliminated"); +STATISTIC(NumElimIV , "Number of congruent IVs eliminated"); + +// Trip count verification can be enabled by default under NDEBUG if we +// implement a strong expression equivalence checker in SCEV. Until then, we +// use the verify-indvars flag, which may assert in some cases. 
+static cl::opt<bool> VerifyIndvars( + "verify-indvars", cl::Hidden, + cl::desc("Verify the ScalarEvolution result after running indvars")); + +enum ReplaceExitVal { NeverRepl, OnlyCheapRepl, AlwaysRepl }; + +static cl::opt<ReplaceExitVal> ReplaceExitValue( + "replexitval", cl::Hidden, cl::init(OnlyCheapRepl), + cl::desc("Choose the strategy to replace exit value in IndVarSimplify"), + cl::values(clEnumValN(NeverRepl, "never", "never replace exit value"), + clEnumValN(OnlyCheapRepl, "cheap", + "only replace exit value when the cost is cheap"), + clEnumValN(AlwaysRepl, "always", + "always replace exit value whenever possible"))); + +static cl::opt<bool> UsePostIncrementRanges( + "indvars-post-increment-ranges", cl::Hidden, + cl::desc("Use post increment control-dependent ranges in IndVarSimplify"), + cl::init(true)); + +static cl::opt<bool> +DisableLFTR("disable-lftr", cl::Hidden, cl::init(false), + cl::desc("Disable Linear Function Test Replace optimization")); + +namespace { + +struct RewritePhi; + +class IndVarSimplify { + LoopInfo *LI; + ScalarEvolution *SE; + DominatorTree *DT; + const DataLayout &DL; + TargetLibraryInfo *TLI; + const TargetTransformInfo *TTI; + + SmallVector<WeakTrackingVH, 16> DeadInsts; + bool Changed = false; + + bool isValidRewrite(Value *FromVal, Value *ToVal); + + void handleFloatingPointIV(Loop *L, PHINode *PH); + void rewriteNonIntegerIVs(Loop *L); + + void simplifyAndExtend(Loop *L, SCEVExpander &Rewriter, LoopInfo *LI); + + bool canLoopBeDeleted(Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet); + void rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter); + void rewriteFirstIterationLoopExitValues(Loop *L); + + Value *linearFunctionTestReplace(Loop *L, const SCEV *BackedgeTakenCount, + PHINode *IndVar, SCEVExpander &Rewriter); + + void sinkUnusedInvariants(Loop *L); + + Value *expandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S, Loop *L, + Instruction *InsertPt, Type *Ty); + +public: + IndVarSimplify(LoopInfo *LI, ScalarEvolution *SE, DominatorTree *DT, + const DataLayout &DL, TargetLibraryInfo *TLI, + TargetTransformInfo *TTI) + : LI(LI), SE(SE), DT(DT), DL(DL), TLI(TLI), TTI(TTI) {} + + bool run(Loop *L); +}; + +} // end anonymous namespace + +/// Return true if the SCEV expansion generated by the rewriter can replace the +/// original value. SCEV guarantees that it produces the same value, but the way +/// it is produced may be illegal IR. Ideally, this function will only be +/// called for verification. +bool IndVarSimplify::isValidRewrite(Value *FromVal, Value *ToVal) { + // If an SCEV expression subsumed multiple pointers, its expansion could + // reassociate the GEP changing the base pointer. This is illegal because the + // final address produced by a GEP chain must be inbounds relative to its + // underlying object. Otherwise basic alias analysis, among other things, + // could fail in a dangerous way. Ultimately, SCEV will be improved to avoid + // producing an expression involving multiple pointers. Until then, we must + // bail out here. + // + // Retrieve the pointer operand of the GEP. Don't use GetUnderlyingObject + // because it understands lcssa phis while SCEV does not. 
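+  //
+  // To illustrate the hazard (with hypothetical IR): if %p and %q are
+  // distinct underlying objects, re-expressing
+  //   getelementptr inbounds i8, i8* %p, i64 %i
+  // as %q plus some offset keeps the same numeric address but states the
+  // inbounds requirement against the wrong object. The base-pointer
+  // comparison below is what rejects such rewrites.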
+  Value *FromPtr = FromVal;
+  Value *ToPtr = ToVal;
+  if (auto *GEP = dyn_cast<GEPOperator>(FromVal)) {
+    FromPtr = GEP->getPointerOperand();
+  }
+  if (auto *GEP = dyn_cast<GEPOperator>(ToVal)) {
+    ToPtr = GEP->getPointerOperand();
+  }
+  if (FromPtr != FromVal || ToPtr != ToVal) {
+    // Quickly check the common case.
+    if (FromPtr == ToPtr)
+      return true;
+
+    // SCEV may have rewritten an expression that produces the GEP's pointer
+    // operand. That's ok as long as the pointer operand has the same base
+    // pointer. Unlike GetUnderlyingObject(), getPointerBase() will find the
+    // base of a recurrence. This handles the case in which SCEV expansion
+    // converts a pointer type recurrence into a nonrecurrent pointer base
+    // indexed by an integer recurrence.
+
+    // If the GEP base pointer is a vector of pointers, abort.
+    if (!FromPtr->getType()->isPointerTy() || !ToPtr->getType()->isPointerTy())
+      return false;
+
+    const SCEV *FromBase = SE->getPointerBase(SE->getSCEV(FromPtr));
+    const SCEV *ToBase = SE->getPointerBase(SE->getSCEV(ToPtr));
+    if (FromBase == ToBase)
+      return true;
+
+    DEBUG(dbgs() << "INDVARS: GEP rewrite bail out "
+                 << *FromBase << " != " << *ToBase << "\n");
+
+    return false;
+  }
+  return true;
+}
+
+/// Determine the insertion point for this user. By default, insert immediately
+/// before the user. SCEVExpander or LICM will hoist loop invariants out of the
+/// loop. For PHI nodes, there may be multiple uses, so compute the nearest
+/// common dominator for the incoming blocks.
+static Instruction *getInsertPointForUses(Instruction *User, Value *Def,
+                                          DominatorTree *DT, LoopInfo *LI) {
+  PHINode *PHI = dyn_cast<PHINode>(User);
+  if (!PHI)
+    return User;
+
+  Instruction *InsertPt = nullptr;
+  for (unsigned i = 0, e = PHI->getNumIncomingValues(); i != e; ++i) {
+    if (PHI->getIncomingValue(i) != Def)
+      continue;
+
+    BasicBlock *InsertBB = PHI->getIncomingBlock(i);
+    if (!InsertPt) {
+      InsertPt = InsertBB->getTerminator();
+      continue;
+    }
+    InsertBB = DT->findNearestCommonDominator(InsertPt->getParent(), InsertBB);
+    InsertPt = InsertBB->getTerminator();
+  }
+  assert(InsertPt && "Missing phi operand");
+
+  auto *DefI = dyn_cast<Instruction>(Def);
+  if (!DefI)
+    return InsertPt;
+
+  assert(DT->dominates(DefI, InsertPt) && "def does not dominate all uses");
+
+  auto *L = LI->getLoopFor(DefI->getParent());
+  assert(!L || L->contains(LI->getLoopFor(InsertPt->getParent())));
+
+  for (auto *DTN = (*DT)[InsertPt->getParent()]; DTN; DTN = DTN->getIDom())
+    if (LI->getLoopFor(DTN->getBlock()) == L)
+      return DTN->getBlock()->getTerminator();
+
+  llvm_unreachable("DefI dominates InsertPt!");
+}
+
+//===----------------------------------------------------------------------===//
+// rewriteNonIntegerIVs and helpers. Prefer integer IVs.
+//===----------------------------------------------------------------------===//
+
+/// Convert APF to an integer, if possible.
+static bool ConvertToSInt(const APFloat &APF, int64_t &IntVal) {
+  bool isExact = false;
+  // See if we can convert this to an int64_t.
+  uint64_t UIntVal;
+  if (APF.convertToInteger(makeMutableArrayRef(UIntVal), 64, true,
+                           APFloat::rmTowardZero, &isExact) != APFloat::opOK ||
+      !isExact)
+    return false;
+  IntVal = UIntVal;
+  return true;
+}
+
+/// If the loop has a floating-point induction variable, then insert a
+/// corresponding integer induction variable if possible.
+/// For example,
+/// for(double i = 0; i < 10000; ++i)
+///   bar(i)
+/// is converted into
+/// for(int i = 0; i < 10000; ++i)
+///   bar((double)i);
+void IndVarSimplify::handleFloatingPointIV(Loop *L, PHINode *PN) {
+  unsigned IncomingEdge = L->contains(PN->getIncomingBlock(0));
+  unsigned BackEdge = IncomingEdge^1;
+
+  // Check the incoming value.
+  auto *InitValueVal = dyn_cast<ConstantFP>(PN->getIncomingValue(IncomingEdge));
+
+  int64_t InitValue;
+  if (!InitValueVal || !ConvertToSInt(InitValueVal->getValueAPF(), InitValue))
+    return;
+
+  // Check the IV increment. Reject this PN if the increment operation is not
+  // an add or if the increment value cannot be represented by an integer.
+  auto *Incr = dyn_cast<BinaryOperator>(PN->getIncomingValue(BackEdge));
+  if (Incr == nullptr || Incr->getOpcode() != Instruction::FAdd) return;
+
+  // If this is not an add of the PHI with a constantfp, or if the constant fp
+  // is not an integer, bail out.
+  ConstantFP *IncValueVal = dyn_cast<ConstantFP>(Incr->getOperand(1));
+  int64_t IncValue;
+  if (IncValueVal == nullptr || Incr->getOperand(0) != PN ||
+      !ConvertToSInt(IncValueVal->getValueAPF(), IncValue))
+    return;
+
+  // Check Incr uses. One user is PN and the other user is an exit condition
+  // used by the conditional terminator.
+  Value::user_iterator IncrUse = Incr->user_begin();
+  Instruction *U1 = cast<Instruction>(*IncrUse++);
+  if (IncrUse == Incr->user_end()) return;
+  Instruction *U2 = cast<Instruction>(*IncrUse++);
+  if (IncrUse != Incr->user_end()) return;
+
+  // Find the exit condition, which is an fcmp. If it doesn't exist, or if it
+  // is used by anything other than a single branch, we can't transform it.
+  FCmpInst *Compare = dyn_cast<FCmpInst>(U1);
+  if (!Compare)
+    Compare = dyn_cast<FCmpInst>(U2);
+  if (!Compare || !Compare->hasOneUse() ||
+      !isa<BranchInst>(Compare->user_back()))
+    return;
+
+  BranchInst *TheBr = cast<BranchInst>(Compare->user_back());
+
+  // We need to verify that the branch actually controls the iteration count
+  // of the loop. If not, the new IV can overflow and no one will notice.
+  // The branch block must be in the loop and one of the successors must be out
+  // of the loop.
+  assert(TheBr->isConditional() && "Can't use fcmp if not conditional");
+  if (!L->contains(TheBr->getParent()) ||
+      (L->contains(TheBr->getSuccessor(0)) &&
+       L->contains(TheBr->getSuccessor(1))))
+    return;
+
+  // If it isn't a comparison with an integer-as-fp (the exit value), we can't
+  // transform it.
+  ConstantFP *ExitValueVal = dyn_cast<ConstantFP>(Compare->getOperand(1));
+  int64_t ExitValue;
+  if (ExitValueVal == nullptr ||
+      !ConvertToSInt(ExitValueVal->getValueAPF(), ExitValue))
+    return;
+
+  // Find the new predicate for the integer comparison.
+  CmpInst::Predicate NewPred = CmpInst::BAD_ICMP_PREDICATE;
+  switch (Compare->getPredicate()) {
+  default: return;  // Unknown comparison.
+  case CmpInst::FCMP_OEQ:
+  case CmpInst::FCMP_UEQ: NewPred = CmpInst::ICMP_EQ; break;
+  case CmpInst::FCMP_ONE:
+  case CmpInst::FCMP_UNE: NewPred = CmpInst::ICMP_NE; break;
+  case CmpInst::FCMP_OGT:
+  case CmpInst::FCMP_UGT: NewPred = CmpInst::ICMP_SGT; break;
+  case CmpInst::FCMP_OGE:
+  case CmpInst::FCMP_UGE: NewPred = CmpInst::ICMP_SGE; break;
+  case CmpInst::FCMP_OLT:
+  case CmpInst::FCMP_ULT: NewPred = CmpInst::ICMP_SLT; break;
+  case CmpInst::FCMP_OLE:
+  case CmpInst::FCMP_ULE: NewPred = CmpInst::ICMP_SLE; break;
+  }
+
+  // We convert the floating point induction variable to a signed i32 value if
+  // we can.
This is only safe if the comparison will not overflow in a way + // that won't be trapped by the integer equivalent operations. Check for this + // now. + // TODO: We could use i64 if it is native and the range requires it. + + // The start/stride/exit values must all fit in signed i32. + if (!isInt<32>(InitValue) || !isInt<32>(IncValue) || !isInt<32>(ExitValue)) + return; + + // If not actually striding (add x, 0.0), avoid touching the code. + if (IncValue == 0) + return; + + // Positive and negative strides have different safety conditions. + if (IncValue > 0) { + // If we have a positive stride, we require the init to be less than the + // exit value. + if (InitValue >= ExitValue) + return; + + uint32_t Range = uint32_t(ExitValue-InitValue); + // Check for infinite loop, either: + // while (i <= Exit) or until (i > Exit) + if (NewPred == CmpInst::ICMP_SLE || NewPred == CmpInst::ICMP_SGT) { + if (++Range == 0) return; // Range overflows. + } + + unsigned Leftover = Range % uint32_t(IncValue); + + // If this is an equality comparison, we require that the strided value + // exactly land on the exit value, otherwise the IV condition will wrap + // around and do things the fp IV wouldn't. + if ((NewPred == CmpInst::ICMP_EQ || NewPred == CmpInst::ICMP_NE) && + Leftover != 0) + return; + + // If the stride would wrap around the i32 before exiting, we can't + // transform the IV. + if (Leftover != 0 && int32_t(ExitValue+IncValue) < ExitValue) + return; + } else { + // If we have a negative stride, we require the init to be greater than the + // exit value. + if (InitValue <= ExitValue) + return; + + uint32_t Range = uint32_t(InitValue-ExitValue); + // Check for infinite loop, either: + // while (i >= Exit) or until (i < Exit) + if (NewPred == CmpInst::ICMP_SGE || NewPred == CmpInst::ICMP_SLT) { + if (++Range == 0) return; // Range overflows. + } + + unsigned Leftover = Range % uint32_t(-IncValue); + + // If this is an equality comparison, we require that the strided value + // exactly land on the exit value, otherwise the IV condition will wrap + // around and do things the fp IV wouldn't. + if ((NewPred == CmpInst::ICMP_EQ || NewPred == CmpInst::ICMP_NE) && + Leftover != 0) + return; + + // If the stride would wrap around the i32 before exiting, we can't + // transform the IV. + if (Leftover != 0 && int32_t(ExitValue+IncValue) > ExitValue) + return; + } + + IntegerType *Int32Ty = Type::getInt32Ty(PN->getContext()); + + // Insert new integer induction variable. + PHINode *NewPHI = PHINode::Create(Int32Ty, 2, PN->getName()+".int", PN); + NewPHI->addIncoming(ConstantInt::get(Int32Ty, InitValue), + PN->getIncomingBlock(IncomingEdge)); + + Value *NewAdd = + BinaryOperator::CreateAdd(NewPHI, ConstantInt::get(Int32Ty, IncValue), + Incr->getName()+".int", Incr); + NewPHI->addIncoming(NewAdd, PN->getIncomingBlock(BackEdge)); + + ICmpInst *NewCompare = new ICmpInst(TheBr, NewPred, NewAdd, + ConstantInt::get(Int32Ty, ExitValue), + Compare->getName()); + + // In the following deletions, PN may become dead and may be deleted. + // Use a WeakTrackingVH to observe whether this happens. + WeakTrackingVH WeakPH = PN; + + // Delete the old floating point exit comparison. The branch starts using the + // new comparison. + NewCompare->takeName(Compare); + Compare->replaceAllUsesWith(NewCompare); + RecursivelyDeleteTriviallyDeadInstructions(Compare, TLI); + + // Delete the old floating point increment. 
+  Incr->replaceAllUsesWith(UndefValue::get(Incr->getType()));
+  RecursivelyDeleteTriviallyDeadInstructions(Incr, TLI);
+
+  // If the FP induction variable still has uses, this is because something
+  // else in the loop uses its value. In order to canonicalize the induction
+  // variable, we choose to eliminate the IV and rewrite it in terms of an
+  // int->fp cast.
+  //
+  // We give preference to sitofp over uitofp because it is faster on most
+  // platforms.
+  if (WeakPH) {
+    Value *Conv = new SIToFPInst(NewPHI, PN->getType(), "indvar.conv",
+                                 &*PN->getParent()->getFirstInsertionPt());
+    PN->replaceAllUsesWith(Conv);
+    RecursivelyDeleteTriviallyDeadInstructions(PN, TLI);
+  }
+  Changed = true;
+}
+
+void IndVarSimplify::rewriteNonIntegerIVs(Loop *L) {
+  // First step. Check to see if there are any floating-point recurrences.
+  // If there are, change them into integer recurrences, permitting analysis by
+  // the SCEV routines.
+  BasicBlock *Header = L->getHeader();
+
+  SmallVector<WeakTrackingVH, 8> PHIs;
+  for (PHINode &PN : Header->phis())
+    PHIs.push_back(&PN);
+
+  for (unsigned i = 0, e = PHIs.size(); i != e; ++i)
+    if (PHINode *PN = dyn_cast_or_null<PHINode>(&*PHIs[i]))
+      handleFloatingPointIV(L, PN);
+
+  // If the loop previously had a floating-point IV, ScalarEvolution
+  // may not have been able to compute a trip count. Now that we've done some
+  // rewriting, the trip count may be computable.
+  if (Changed)
+    SE->forgetLoop(L);
+}
+
+namespace {
+
+// Collect information about PHI nodes which can be transformed in
+// rewriteLoopExitValues.
+struct RewritePhi {
+  PHINode *PN;
+
+  // Ith incoming value.
+  unsigned Ith;
+
+  // Exit value after expansion.
+  Value *Val;
+
+  // True if expanding the exit value is high cost.
+  bool HighCost;
+
+  RewritePhi(PHINode *P, unsigned I, Value *V, bool H)
+      : PN(P), Ith(I), Val(V), HighCost(H) {}
+};
+
+} // end anonymous namespace
+
+Value *IndVarSimplify::expandSCEVIfNeeded(SCEVExpander &Rewriter, const SCEV *S,
+                                          Loop *L, Instruction *InsertPt,
+                                          Type *ResultTy) {
+  // Before expanding S into an expensive LLVM expression, see if we can use an
+  // already existing value as the expansion for S.
+  if (Value *ExistingValue = Rewriter.getExactExistingExpansion(S, InsertPt, L))
+    if (ExistingValue->getType() == ResultTy)
+      return ExistingValue;
+
+  // We didn't find anything; fall back to using SCEVExpander.
+  return Rewriter.expandCodeFor(S, ResultTy, InsertPt);
+}
+
+//===----------------------------------------------------------------------===//
+// rewriteLoopExitValues - Optimize IV users outside the loop.
+// As a side effect, reduces the amount of IV processing within the loop.
+//===----------------------------------------------------------------------===//
+
+/// Check to see if this loop has a computable loop-invariant execution count.
+/// If so, this means that we can compute the final value of any expressions
+/// that are recurrent in the loop, and substitute the exit values from the
+/// loop into any instructions outside of the loop that use the final values
+/// of the current expressions.
+///
+/// This is mostly redundant with the regular IndVarSimplify activities that
+/// happen later, except that it's more powerful in some cases, because it's
+/// able to brute-force evaluate arbitrary instructions as long as they have
+/// constant operands at the beginning of the loop.
+void IndVarSimplify::rewriteLoopExitValues(Loop *L, SCEVExpander &Rewriter) {
+  // Check a pre-condition.
+  assert(L->isRecursivelyLCSSAForm(*DT, *LI) &&
+         "Indvars did not preserve LCSSA!");
+
+  SmallVector<BasicBlock*, 8> ExitBlocks;
+  L->getUniqueExitBlocks(ExitBlocks);
+
+  SmallVector<RewritePhi, 8> RewritePhiSet;
+  // Find all values that are computed inside the loop, but used outside of it.
+  // Because of LCSSA, these values will only occur in LCSSA PHI nodes. Scan
+  // the exit blocks of the loop to find them.
+  for (BasicBlock *ExitBB : ExitBlocks) {
+    // If there are no PHI nodes in this exit block, then no values defined
+    // inside the loop are used on this path; skip it.
+    PHINode *PN = dyn_cast<PHINode>(ExitBB->begin());
+    if (!PN) continue;
+
+    unsigned NumPreds = PN->getNumIncomingValues();
+
+    // Iterate over all of the PHI nodes.
+    BasicBlock::iterator BBI = ExitBB->begin();
+    while ((PN = dyn_cast<PHINode>(BBI++))) {
+      if (PN->use_empty())
+        continue; // dead use, don't replace it
+
+      if (!SE->isSCEVable(PN->getType()))
+        continue;
+
+      // It's necessary to tell ScalarEvolution about this explicitly so that
+      // it can walk the def-use list and forget all SCEVs, as it may not be
+      // watching the PHI itself. Once the new exit value is in place, there
+      // may not be a def-use connection between the loop and every instruction
+      // which got a SCEVAddRecExpr for that loop.
+      SE->forgetValue(PN);
+
+      // Iterate over all of the values in all the PHI nodes.
+      for (unsigned i = 0; i != NumPreds; ++i) {
+        // If the value being merged in is not integer or is not defined
+        // in the loop, skip it.
+        Value *InVal = PN->getIncomingValue(i);
+        if (!isa<Instruction>(InVal))
+          continue;
+
+        // If this pred is for a subloop, not L itself, skip it.
+        if (LI->getLoopFor(PN->getIncomingBlock(i)) != L)
+          continue; // The Block is in a subloop, skip it.
+
+        // Check that InVal is defined in the loop.
+        Instruction *Inst = cast<Instruction>(InVal);
+        if (!L->contains(Inst))
+          continue;
+
+        // Okay, this instruction has a user outside of the current loop
+        // and varies predictably *inside* the loop. Evaluate the value it
+        // contains when the loop exits, if possible.
+        const SCEV *ExitValue = SE->getSCEVAtScope(Inst, L->getParentLoop());
+        if (!SE->isLoopInvariant(ExitValue, L) ||
+            !isSafeToExpand(ExitValue, *SE))
+          continue;
+
+        // Computing the value outside of the loop brings no benefit if:
+        //  - it is definitely used inside the loop in a way which cannot be
+        //    optimized away.
+        //  - no use outside of the loop can take advantage of hoisting the
+        //    computation out of the loop.
+        if (ExitValue->getSCEVType() >= scMulExpr) {
+          unsigned NumHardInternalUses = 0;
+          unsigned NumSoftExternalUses = 0;
+          unsigned NumUses = 0;
+          for (auto IB = Inst->user_begin(), IE = Inst->user_end();
+               IB != IE && NumUses <= 6; ++IB) {
+            Instruction *UseInstr = cast<Instruction>(*IB);
+            unsigned Opc = UseInstr->getOpcode();
+            NumUses++;
+            if (L->contains(UseInstr)) {
+              if (Opc == Instruction::Call || Opc == Instruction::Ret)
+                NumHardInternalUses++;
+            } else {
+              if (Opc == Instruction::PHI) {
+                // Do not count the Phi as a use. LCSSA may have inserted
+                // plenty of trivial ones.
+                NumUses--;
+                for (auto PB = UseInstr->user_begin(),
+                          PE = UseInstr->user_end();
+                     PB != PE && NumUses <= 6; ++PB, ++NumUses) {
+                  unsigned PhiOpc = cast<Instruction>(*PB)->getOpcode();
+                  if (PhiOpc != Instruction::Call && PhiOpc != Instruction::Ret)
+                    NumSoftExternalUses++;
+                }
+                continue;
+              }
+              if (Opc != Instruction::Call && Opc != Instruction::Ret)
+                NumSoftExternalUses++;
+            }
+          }
+          if (NumUses <= 6 && NumHardInternalUses && !NumSoftExternalUses)
+            continue;
+        }
+
+        bool HighCost = Rewriter.isHighCostExpansion(ExitValue, L, Inst);
+        Value *ExitVal =
+            expandSCEVIfNeeded(Rewriter, ExitValue, L, Inst, PN->getType());
+
+        DEBUG(dbgs() << "INDVARS: RLEV: AfterLoopVal = " << *ExitVal << '\n'
+                     << "  LoopVal = " << *Inst << "\n");
+
+        if (!isValidRewrite(Inst, ExitVal)) {
+          DeadInsts.push_back(ExitVal);
+          continue;
+        }
+
+        // Collect all the candidate PHINodes to be rewritten.
+        RewritePhiSet.emplace_back(PN, i, ExitVal, HighCost);
+      }
+    }
+  }
+
+  bool LoopCanBeDel = canLoopBeDeleted(L, RewritePhiSet);
+
+  // Transformation.
+  for (const RewritePhi &Phi : RewritePhiSet) {
+    PHINode *PN = Phi.PN;
+    Value *ExitVal = Phi.Val;
+
+    // Only do the rewrite when the ExitValue can be expanded cheaply.
+    // If LoopCanBeDel is true, rewrite the exit value aggressively.
+    if (ReplaceExitValue == OnlyCheapRepl && !LoopCanBeDel && Phi.HighCost) {
+      DeadInsts.push_back(ExitVal);
+      continue;
+    }
+
+    Changed = true;
+    ++NumReplaced;
+    Instruction *Inst = cast<Instruction>(PN->getIncomingValue(Phi.Ith));
+    PN->setIncomingValue(Phi.Ith, ExitVal);
+
+    // If this instruction is dead now, schedule it for deletion. Don't delete
+    // it immediately, to avoid invalidating iterators.
+    if (isInstructionTriviallyDead(Inst, TLI))
+      DeadInsts.push_back(Inst);
+
+    // Replace PN with ExitVal if that is legal and does not break LCSSA.
+    if (PN->getNumIncomingValues() == 1 &&
+        LI->replacementPreservesLCSSAForm(PN, ExitVal)) {
+      PN->replaceAllUsesWith(ExitVal);
+      PN->eraseFromParent();
+    }
+  }
+
+  // The insertion point instruction may have been deleted; clear it out
+  // so that the rewriter doesn't trip over it later.
+  Rewriter.clearInsertPoint();
+}
+
+//===---------------------------------------------------------------------===//
+// rewriteFirstIterationLoopExitValues: Rewrite loop exit values when we know
+// the exit must be taken on the first iteration.
+//===---------------------------------------------------------------------===//
+
+/// Check to see if this loop has loop-invariant conditions which lead to loop
+/// exits. If so, we know that if the exit path is taken, it is at the first
+/// loop iteration. This lets us predict exit values of PHI nodes that live in
+/// the loop header.
+void IndVarSimplify::rewriteFirstIterationLoopExitValues(Loop *L) {
+  // Verify the input to the pass is already in LCSSA form.
+  assert(L->isLCSSAForm(*DT));
+
+  SmallVector<BasicBlock *, 8> ExitBlocks;
+  L->getUniqueExitBlocks(ExitBlocks);
+  auto *LoopHeader = L->getHeader();
+  assert(LoopHeader && "Invalid loop");
+
+  for (auto *ExitBB : ExitBlocks) {
+    // If there are no more PHI nodes in this exit block, then no more
+    // values defined inside the loop are used on this path.
+    for (PHINode &PN : ExitBB->phis()) {
+      for (unsigned IncomingValIdx = 0, E = PN.getNumIncomingValues();
+           IncomingValIdx != E; ++IncomingValIdx) {
+        auto *IncomingBB = PN.getIncomingBlock(IncomingValIdx);
+
+        // We currently only support loop exits from the loop header. If the
+        // incoming block is not the loop header, we would need to recursively
+        // check that all conditions starting from the loop header are loop
+        // invariant. Additional support might be added in the future.
+        if (IncomingBB != LoopHeader)
+          continue;
+
+        // Get the condition that leads to the exit path.
+        auto *TermInst = IncomingBB->getTerminator();
+
+        Value *Cond = nullptr;
+        if (auto *BI = dyn_cast<BranchInst>(TermInst)) {
+          // Must be a conditional branch, otherwise the block
+          // should not be in the loop.
+          Cond = BI->getCondition();
+        } else if (auto *SI = dyn_cast<SwitchInst>(TermInst))
+          Cond = SI->getCondition();
+        else
+          continue;
+
+        if (!L->isLoopInvariant(Cond))
+          continue;
+
+        auto *ExitVal = dyn_cast<PHINode>(PN.getIncomingValue(IncomingValIdx));
+
+        // Only deal with PHIs.
+        if (!ExitVal)
+          continue;
+
+        // If ExitVal is a PHI on the loop header, then we know its
+        // value along this exit because the exit can only be taken
+        // on the first iteration.
+        auto *LoopPreheader = L->getLoopPreheader();
+        assert(LoopPreheader && "Invalid loop");
+        int PreheaderIdx = ExitVal->getBasicBlockIndex(LoopPreheader);
+        if (PreheaderIdx != -1) {
+          assert(ExitVal->getParent() == LoopHeader &&
+                 "ExitVal must be in loop header");
+          PN.setIncomingValue(IncomingValIdx,
+                              ExitVal->getIncomingValue(PreheaderIdx));
+        }
+      }
+    }
+  }
+}
+
+/// Check whether it is possible to delete the loop after rewriting the exit
+/// value. If it is possible, ignore ReplaceExitValue and do the rewriting
+/// aggressively.
+bool IndVarSimplify::canLoopBeDeleted(
+    Loop *L, SmallVector<RewritePhi, 8> &RewritePhiSet) {
+  BasicBlock *Preheader = L->getLoopPreheader();
+  // If there is no preheader, the loop will not be deleted.
+  if (!Preheader)
+    return false;
+
+  // The LoopDeletion pass can delete loops even when ExitingBlocks.size() > 1;
+  // we forgo the multiple-ExitingBlocks case here for simplicity.
+  // TODO: If we see a testcase where a loop with multiple ExitingBlocks can be
+  // deleted after exit value rewriting, we can enhance the logic here.
+  SmallVector<BasicBlock *, 4> ExitingBlocks;
+  L->getExitingBlocks(ExitingBlocks);
+  SmallVector<BasicBlock *, 8> ExitBlocks;
+  L->getUniqueExitBlocks(ExitBlocks);
+  if (ExitBlocks.size() > 1 || ExitingBlocks.size() > 1)
+    return false;
+
+  BasicBlock *ExitBlock = ExitBlocks[0];
+  BasicBlock::iterator BI = ExitBlock->begin();
+  while (PHINode *P = dyn_cast<PHINode>(BI)) {
+    Value *Incoming = P->getIncomingValueForBlock(ExitingBlocks[0]);
+
+    // If the Incoming value of P is found in RewritePhiSet, we know it
+    // could be rewritten to use a loop-invariant value in the transformation
+    // phase later. Skip it in the loop-invariant check below.
+    bool found = false;
+    for (const RewritePhi &Phi : RewritePhiSet) {
+      unsigned i = Phi.Ith;
+      if (Phi.PN == P && (Phi.PN)->getIncomingValue(i) == Incoming) {
+        found = true;
+        break;
+      }
+    }
+
+    Instruction *I;
+    if (!found && (I = dyn_cast<Instruction>(Incoming)))
+      if (!L->hasLoopInvariantOperands(I))
+        return false;
+
+    ++BI;
+  }
+
+  for (auto *BB : L->blocks())
+    if (llvm::any_of(*BB, [](Instruction &I) {
+          return I.mayHaveSideEffects();
+        }))
+      return false;
+
+  return true;
+}
+
+//===----------------------------------------------------------------------===//
+// IV Widening - Extend the width of an IV to cover its widest uses.
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+// Collect information about induction variables that are used by sign/zero
+// extend operations.
+// This information is recorded by CollectExtend and provides the input to
+// WidenIV.
+struct WideIVInfo {
+  PHINode *NarrowIV = nullptr;
+
+  // Widest integer type created by a [sz]ext.
+  Type *WidestNativeType = nullptr;
+
+  // Was a sext user seen before a zext?
+  bool IsSigned = false;
+};
+
+} // end anonymous namespace
+
+/// Update information about the induction variable that is extended by this
+/// sign or zero extend operation. This is used to determine the final width of
+/// the IV before actually widening it.
+static void visitIVCast(CastInst *Cast, WideIVInfo &WI, ScalarEvolution *SE,
+                        const TargetTransformInfo *TTI) {
+  bool IsSigned = Cast->getOpcode() == Instruction::SExt;
+  if (!IsSigned && Cast->getOpcode() != Instruction::ZExt)
+    return;
+
+  Type *Ty = Cast->getType();
+  uint64_t Width = SE->getTypeSizeInBits(Ty);
+  if (!Cast->getModule()->getDataLayout().isLegalInteger(Width))
+    return;
+
+  // Check that `Cast` actually extends the induction variable (we rely on this
+  // later). This takes care of cases where `Cast` is extending a truncation of
+  // the narrow induction variable, and thus can end up being narrower than the
+  // "narrow" induction variable.
+  uint64_t NarrowIVWidth = SE->getTypeSizeInBits(WI.NarrowIV->getType());
+  if (NarrowIVWidth >= Width)
+    return;
+
+  // Cast is either an sext or zext up to this point.
+  // We should not widen an indvar if arithmetic on the wider indvar is more
+  // expensive than that on the narrower indvar. We check only the cost of ADD
+  // because at least an ADD is required to increment the induction variable.
+  // We could compute the cost of all instructions on the induction variable
+  // more comprehensively when necessary.
+  if (TTI &&
+      TTI->getArithmeticInstrCost(Instruction::Add, Ty) >
+          TTI->getArithmeticInstrCost(Instruction::Add,
+                                      Cast->getOperand(0)->getType())) {
+    return;
+  }
+
+  if (!WI.WidestNativeType) {
+    WI.WidestNativeType = SE->getEffectiveSCEVType(Ty);
+    WI.IsSigned = IsSigned;
+    return;
+  }
+
+  // We extend the IV to satisfy the sign of its first user, arbitrarily.
+  if (WI.IsSigned != IsSigned)
+    return;
+
+  if (Width > SE->getTypeSizeInBits(WI.WidestNativeType))
+    WI.WidestNativeType = SE->getEffectiveSCEVType(Ty);
+}
+
+namespace {
+
+/// Record a link in the Narrow IV def-use chain along with the WideIV that
+/// computes the same value as the Narrow IV def. This avoids caching Use*
+/// pointers.
+struct NarrowIVDefUse {
+  Instruction *NarrowDef = nullptr;
+  Instruction *NarrowUse = nullptr;
+  Instruction *WideDef = nullptr;
+
+  // True if the narrow def is never negative. Tracking this information lets
+  // us use a sign extension instead of a zero extension or vice versa, when
+  // profitable and legal.
+  bool NeverNegative = false;
+
+  NarrowIVDefUse(Instruction *ND, Instruction *NU, Instruction *WD,
+                 bool NeverNegative)
+      : NarrowDef(ND), NarrowUse(NU), WideDef(WD),
+        NeverNegative(NeverNegative) {}
+};
+
+/// The goal of this transform is to remove sign and zero extends without
+/// creating any new induction variables. To do this, it creates a new phi of
+/// the wider type and redirects all users, either removing extends or
+/// inserting truncs whenever we stop propagating the type.
+class WidenIV {
+  // Parameters
+  PHINode *OrigPhi;
+  Type *WideType;
+
+  // Context
+  LoopInfo *LI;
+  Loop *L;
+  ScalarEvolution *SE;
+  DominatorTree *DT;
+
+  // Does the module have any calls to the llvm.experimental.guard intrinsic
+  // at all? If not, we can avoid scanning instructions looking for guards.
+ bool HasGuards; + + // Result + PHINode *WidePhi = nullptr; + Instruction *WideInc = nullptr; + const SCEV *WideIncExpr = nullptr; + SmallVectorImpl<WeakTrackingVH> &DeadInsts; + + SmallPtrSet<Instruction *,16> Widened; + SmallVector<NarrowIVDefUse, 8> NarrowIVUsers; + + enum ExtendKind { ZeroExtended, SignExtended, Unknown }; + + // A map tracking the kind of extension used to widen each narrow IV + // and narrow IV user. + // Key: pointer to a narrow IV or IV user. + // Value: the kind of extension used to widen this Instruction. + DenseMap<AssertingVH<Instruction>, ExtendKind> ExtendKindMap; + + using DefUserPair = std::pair<AssertingVH<Value>, AssertingVH<Instruction>>; + + // A map with control-dependent ranges for post increment IV uses. The key is + // a pair of IV def and a use of this def denoting the context. The value is + // a ConstantRange representing possible values of the def at the given + // context. + DenseMap<DefUserPair, ConstantRange> PostIncRangeInfos; + + Optional<ConstantRange> getPostIncRangeInfo(Value *Def, + Instruction *UseI) { + DefUserPair Key(Def, UseI); + auto It = PostIncRangeInfos.find(Key); + return It == PostIncRangeInfos.end() + ? Optional<ConstantRange>(None) + : Optional<ConstantRange>(It->second); + } + + void calculatePostIncRanges(PHINode *OrigPhi); + void calculatePostIncRange(Instruction *NarrowDef, Instruction *NarrowUser); + + void updatePostIncRangeInfo(Value *Def, Instruction *UseI, ConstantRange R) { + DefUserPair Key(Def, UseI); + auto It = PostIncRangeInfos.find(Key); + if (It == PostIncRangeInfos.end()) + PostIncRangeInfos.insert({Key, R}); + else + It->second = R.intersectWith(It->second); + } + +public: + WidenIV(const WideIVInfo &WI, LoopInfo *LInfo, ScalarEvolution *SEv, + DominatorTree *DTree, SmallVectorImpl<WeakTrackingVH> &DI, + bool HasGuards) + : OrigPhi(WI.NarrowIV), WideType(WI.WidestNativeType), LI(LInfo), + L(LI->getLoopFor(OrigPhi->getParent())), SE(SEv), DT(DTree), + HasGuards(HasGuards), DeadInsts(DI) { + assert(L->getHeader() == OrigPhi->getParent() && "Phi must be an IV"); + ExtendKindMap[OrigPhi] = WI.IsSigned ? SignExtended : ZeroExtended; + } + + PHINode *createWideIV(SCEVExpander &Rewriter); + +protected: + Value *createExtendInst(Value *NarrowOper, Type *WideType, bool IsSigned, + Instruction *Use); + + Instruction *cloneIVUser(NarrowIVDefUse DU, const SCEVAddRecExpr *WideAR); + Instruction *cloneArithmeticIVUser(NarrowIVDefUse DU, + const SCEVAddRecExpr *WideAR); + Instruction *cloneBitwiseIVUser(NarrowIVDefUse DU); + + ExtendKind getExtendKind(Instruction *I); + + using WidenedRecTy = std::pair<const SCEVAddRecExpr *, ExtendKind>; + + WidenedRecTy getWideRecurrence(NarrowIVDefUse DU); + + WidenedRecTy getExtendedOperandRecurrence(NarrowIVDefUse DU); + + const SCEV *getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS, + unsigned OpCode) const; + + Instruction *widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter); + + bool widenLoopCompare(NarrowIVDefUse DU); + + void pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef); +}; + +} // end anonymous namespace + +/// Perform a quick domtree based check for loop invariance assuming that V is +/// used within the loop. LoopInfo::isLoopInvariant() seems gratuitous for this +/// purpose. 
+static bool isLoopInvariant(Value *V, const Loop *L, const DominatorTree *DT) {
+  Instruction *Inst = dyn_cast<Instruction>(V);
+  if (!Inst)
+    return true;
+
+  return DT->properlyDominates(Inst->getParent(), L->getHeader());
+}
+
+Value *WidenIV::createExtendInst(Value *NarrowOper, Type *WideType,
+                                 bool IsSigned, Instruction *Use) {
+  // Set the debug location and conservative insertion point.
+  IRBuilder<> Builder(Use);
+  // Hoist the insertion point into loop preheaders as far as possible.
+  for (const Loop *L = LI->getLoopFor(Use->getParent());
+       L && L->getLoopPreheader() && isLoopInvariant(NarrowOper, L, DT);
+       L = L->getParentLoop())
+    Builder.SetInsertPoint(L->getLoopPreheader()->getTerminator());
+
+  return IsSigned ? Builder.CreateSExt(NarrowOper, WideType) :
+                    Builder.CreateZExt(NarrowOper, WideType);
+}
+
+/// Instantiate a wide operation to replace a narrow operation. This only needs
+/// to handle operations that can evaluate to SCEVAddRec. It can safely return
+/// 0 for any operation we decide not to clone.
+Instruction *WidenIV::cloneIVUser(NarrowIVDefUse DU,
+                                  const SCEVAddRecExpr *WideAR) {
+  unsigned Opcode = DU.NarrowUse->getOpcode();
+  switch (Opcode) {
+  default:
+    return nullptr;
+  case Instruction::Add:
+  case Instruction::Mul:
+  case Instruction::UDiv:
+  case Instruction::Sub:
+    return cloneArithmeticIVUser(DU, WideAR);
+
+  case Instruction::And:
+  case Instruction::Or:
+  case Instruction::Xor:
+  case Instruction::Shl:
+  case Instruction::LShr:
+  case Instruction::AShr:
+    return cloneBitwiseIVUser(DU);
+  }
+}
+
+Instruction *WidenIV::cloneBitwiseIVUser(NarrowIVDefUse DU) {
+  Instruction *NarrowUse = DU.NarrowUse;
+  Instruction *NarrowDef = DU.NarrowDef;
+  Instruction *WideDef = DU.WideDef;
+
+  DEBUG(dbgs() << "Cloning bitwise IVUser: " << *NarrowUse << "\n");
+
+  // Replace NarrowDef operands with WideDef. Otherwise, we don't know anything
+  // about the narrow operand yet so must insert a [sz]ext. It is probably loop
+  // invariant and will be folded or hoisted. If it actually comes from a
+  // widened IV, it should be removed during a future call to widenIVUse.
+  bool IsSigned = getExtendKind(NarrowDef) == SignExtended;
+  Value *LHS = (NarrowUse->getOperand(0) == NarrowDef)
+                   ? WideDef
+                   : createExtendInst(NarrowUse->getOperand(0), WideType,
+                                      IsSigned, NarrowUse);
+  Value *RHS = (NarrowUse->getOperand(1) == NarrowDef)
+                   ? WideDef
+                   : createExtendInst(NarrowUse->getOperand(1), WideType,
+                                      IsSigned, NarrowUse);
+
+  auto *NarrowBO = cast<BinaryOperator>(NarrowUse);
+  auto *WideBO = BinaryOperator::Create(NarrowBO->getOpcode(), LHS, RHS,
+                                        NarrowBO->getName());
+  IRBuilder<> Builder(NarrowUse);
+  Builder.Insert(WideBO);
+  WideBO->copyIRFlags(NarrowBO);
+  return WideBO;
+}
+
+Instruction *WidenIV::cloneArithmeticIVUser(NarrowIVDefUse DU,
+                                            const SCEVAddRecExpr *WideAR) {
+  Instruction *NarrowUse = DU.NarrowUse;
+  Instruction *NarrowDef = DU.NarrowDef;
+  Instruction *WideDef = DU.WideDef;
+
+  DEBUG(dbgs() << "Cloning arithmetic IVUser: " << *NarrowUse << "\n");
+
+  unsigned IVOpIdx = (NarrowUse->getOperand(0) == NarrowDef) ? 0 : 1;
+
+  // We're trying to find X such that
+  //
+  //  Widen(NarrowDef `op` NonIVNarrowDef) == WideAR == WideDef `op.wide` X
+  //
+  // We guess two solutions to X, sext(NonIVNarrowDef) and zext(NonIVNarrowDef),
+  // and check using SCEV if any of them are correct.
+
+  // Returns true if extending NonIVNarrowDef according to `SignExt` is a
+  // correct solution to X.
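+  //
+  // For instance (illustrative values), if the narrow use is
+  //   %use = add nsw i32 %narrow, %n
+  // and WideAR is the AddRec {sext(%init),+,sext(%step)}<%loop> over i64, we
+  // try X = sext(%n) and X = zext(%n), starting with the extension kind used
+  // to widen the def, and accept whichever guess makes WideDef `op.wide` X
+  // equal to WideAR according to SCEV.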
+ auto GuessNonIVOperand = [&](bool SignExt) { + const SCEV *WideLHS; + const SCEV *WideRHS; + + auto GetExtend = [this, SignExt](const SCEV *S, Type *Ty) { + if (SignExt) + return SE->getSignExtendExpr(S, Ty); + return SE->getZeroExtendExpr(S, Ty); + }; + + if (IVOpIdx == 0) { + WideLHS = SE->getSCEV(WideDef); + const SCEV *NarrowRHS = SE->getSCEV(NarrowUse->getOperand(1)); + WideRHS = GetExtend(NarrowRHS, WideType); + } else { + const SCEV *NarrowLHS = SE->getSCEV(NarrowUse->getOperand(0)); + WideLHS = GetExtend(NarrowLHS, WideType); + WideRHS = SE->getSCEV(WideDef); + } + + // WideUse is "WideDef `op.wide` X" as described in the comment. + const SCEV *WideUse = nullptr; + + switch (NarrowUse->getOpcode()) { + default: + llvm_unreachable("No other possibility!"); + + case Instruction::Add: + WideUse = SE->getAddExpr(WideLHS, WideRHS); + break; + + case Instruction::Mul: + WideUse = SE->getMulExpr(WideLHS, WideRHS); + break; + + case Instruction::UDiv: + WideUse = SE->getUDivExpr(WideLHS, WideRHS); + break; + + case Instruction::Sub: + WideUse = SE->getMinusSCEV(WideLHS, WideRHS); + break; + } + + return WideUse == WideAR; + }; + + bool SignExtend = getExtendKind(NarrowDef) == SignExtended; + if (!GuessNonIVOperand(SignExtend)) { + SignExtend = !SignExtend; + if (!GuessNonIVOperand(SignExtend)) + return nullptr; + } + + Value *LHS = (NarrowUse->getOperand(0) == NarrowDef) + ? WideDef + : createExtendInst(NarrowUse->getOperand(0), WideType, + SignExtend, NarrowUse); + Value *RHS = (NarrowUse->getOperand(1) == NarrowDef) + ? WideDef + : createExtendInst(NarrowUse->getOperand(1), WideType, + SignExtend, NarrowUse); + + auto *NarrowBO = cast<BinaryOperator>(NarrowUse); + auto *WideBO = BinaryOperator::Create(NarrowBO->getOpcode(), LHS, RHS, + NarrowBO->getName()); + + IRBuilder<> Builder(NarrowUse); + Builder.Insert(WideBO); + WideBO->copyIRFlags(NarrowBO); + return WideBO; +} + +WidenIV::ExtendKind WidenIV::getExtendKind(Instruction *I) { + auto It = ExtendKindMap.find(I); + assert(It != ExtendKindMap.end() && "Instruction not yet extended!"); + return It->second; +} + +const SCEV *WidenIV::getSCEVByOpCode(const SCEV *LHS, const SCEV *RHS, + unsigned OpCode) const { + if (OpCode == Instruction::Add) + return SE->getAddExpr(LHS, RHS); + if (OpCode == Instruction::Sub) + return SE->getMinusSCEV(LHS, RHS); + if (OpCode == Instruction::Mul) + return SE->getMulExpr(LHS, RHS); + + llvm_unreachable("Unsupported opcode."); +} + +/// No-wrap operations can transfer sign extension of their result to their +/// operands. Generate the SCEV value for the widened operation without +/// actually modifying the IR yet. If the expression after extending the +/// operands is an AddRec for this loop, return the AddRec and the kind of +/// extension used. +WidenIV::WidenedRecTy WidenIV::getExtendedOperandRecurrence(NarrowIVDefUse DU) { + // Handle the common case of add<nsw/nuw> + const unsigned OpCode = DU.NarrowUse->getOpcode(); + // Only Add/Sub/Mul instructions supported yet. + if (OpCode != Instruction::Add && OpCode != Instruction::Sub && + OpCode != Instruction::Mul) + return {nullptr, Unknown}; + + // One operand (NarrowDef) has already been extended to WideDef. Now determine + // if extending the other will lead to a recurrence. + const unsigned ExtendOperIdx = + DU.NarrowUse->getOperand(0) == DU.NarrowDef ? 
1 : 0;
+  assert(DU.NarrowUse->getOperand(1-ExtendOperIdx) == DU.NarrowDef && "bad DU");
+
+  const SCEV *ExtendOperExpr = nullptr;
+  const OverflowingBinaryOperator *OBO =
+    cast<OverflowingBinaryOperator>(DU.NarrowUse);
+  ExtendKind ExtKind = getExtendKind(DU.NarrowDef);
+  if (ExtKind == SignExtended && OBO->hasNoSignedWrap())
+    ExtendOperExpr = SE->getSignExtendExpr(
+      SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx)), WideType);
+  else if (ExtKind == ZeroExtended && OBO->hasNoUnsignedWrap())
+    ExtendOperExpr = SE->getZeroExtendExpr(
+      SE->getSCEV(DU.NarrowUse->getOperand(ExtendOperIdx)), WideType);
+  else
+    return {nullptr, Unknown};
+
+  // When creating this SCEV expr, don't apply the current operation's NSW or
+  // NUW flags. This instruction may be guarded by control flow that the
+  // no-wrap behavior depends on. Non-control-equivalent instructions can be
+  // mapped to the same SCEV expression, and it would be incorrect to transfer
+  // NSW/NUW semantics to those operations.
+  const SCEV *lhs = SE->getSCEV(DU.WideDef);
+  const SCEV *rhs = ExtendOperExpr;
+
+  // Let's swap operands to the initial order for the case of non-commutative
+  // operations, like SUB. See PR21014.
+  if (ExtendOperIdx == 0)
+    std::swap(lhs, rhs);
+  const SCEVAddRecExpr *AddRec =
+      dyn_cast<SCEVAddRecExpr>(getSCEVByOpCode(lhs, rhs, OpCode));
+
+  if (!AddRec || AddRec->getLoop() != L)
+    return {nullptr, Unknown};
+
+  return {AddRec, ExtKind};
+}
+
+/// Is this instruction potentially interesting for further simplification after
+/// widening its type? In other words, can the extend be safely hoisted out of
+/// the loop with SCEV reducing the value to a recurrence on the same loop. If
+/// so, return the extended recurrence and the kind of extension used. Otherwise
+/// return {nullptr, Unknown}.
+WidenIV::WidenedRecTy WidenIV::getWideRecurrence(NarrowIVDefUse DU) {
+  if (!SE->isSCEVable(DU.NarrowUse->getType()))
+    return {nullptr, Unknown};
+
+  const SCEV *NarrowExpr = SE->getSCEV(DU.NarrowUse);
+  if (SE->getTypeSizeInBits(NarrowExpr->getType()) >=
+      SE->getTypeSizeInBits(WideType)) {
+    // NarrowUse implicitly widens its operand. e.g. a gep with a narrow
+    // index. So don't follow this use.
+    return {nullptr, Unknown};
+  }
+
+  const SCEV *WideExpr;
+  ExtendKind ExtKind;
+  if (DU.NeverNegative) {
+    WideExpr = SE->getSignExtendExpr(NarrowExpr, WideType);
+    if (isa<SCEVAddRecExpr>(WideExpr))
+      ExtKind = SignExtended;
+    else {
+      WideExpr = SE->getZeroExtendExpr(NarrowExpr, WideType);
+      ExtKind = ZeroExtended;
+    }
+  } else if (getExtendKind(DU.NarrowDef) == SignExtended) {
+    WideExpr = SE->getSignExtendExpr(NarrowExpr, WideType);
+    ExtKind = SignExtended;
+  } else {
+    WideExpr = SE->getZeroExtendExpr(NarrowExpr, WideType);
+    ExtKind = ZeroExtended;
+  }
+  const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(WideExpr);
+  if (!AddRec || AddRec->getLoop() != L)
+    return {nullptr, Unknown};
+  return {AddRec, ExtKind};
+}
+
+/// This IV user cannot be widened. Replace this use of the original narrow IV
+/// with a truncation of the new wide IV to isolate and eliminate the narrow IV.
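+///
+/// (Illustrative shape of the rewrite: for a use such as
+///   %u = udiv i32 %narrow, %x
+/// we emit %t = trunc i64 %wide to i32 and replace the %narrow operand with
+/// %t, moving this use off the narrow IV without widening the division.)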
+static void truncateIVUse(NarrowIVDefUse DU, DominatorTree *DT, LoopInfo *LI) { + DEBUG(dbgs() << "INDVARS: Truncate IV " << *DU.WideDef + << " for user " << *DU.NarrowUse << "\n"); + IRBuilder<> Builder( + getInsertPointForUses(DU.NarrowUse, DU.NarrowDef, DT, LI)); + Value *Trunc = Builder.CreateTrunc(DU.WideDef, DU.NarrowDef->getType()); + DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, Trunc); +} + +/// If the narrow use is a compare instruction, then widen the compare +// (and possibly the other operand). The extend operation is hoisted into the +// loop preheader as far as possible. +bool WidenIV::widenLoopCompare(NarrowIVDefUse DU) { + ICmpInst *Cmp = dyn_cast<ICmpInst>(DU.NarrowUse); + if (!Cmp) + return false; + + // We can legally widen the comparison in the following two cases: + // + // - The signedness of the IV extension and comparison match + // + // - The narrow IV is always positive (and thus its sign extension is equal + // to its zero extension). For instance, let's say we're zero extending + // %narrow for the following use + // + // icmp slt i32 %narrow, %val ... (A) + // + // and %narrow is always positive. Then + // + // (A) == icmp slt i32 sext(%narrow), sext(%val) + // == icmp slt i32 zext(%narrow), sext(%val) + bool IsSigned = getExtendKind(DU.NarrowDef) == SignExtended; + if (!(DU.NeverNegative || IsSigned == Cmp->isSigned())) + return false; + + Value *Op = Cmp->getOperand(Cmp->getOperand(0) == DU.NarrowDef ? 1 : 0); + unsigned CastWidth = SE->getTypeSizeInBits(Op->getType()); + unsigned IVWidth = SE->getTypeSizeInBits(WideType); + assert(CastWidth <= IVWidth && "Unexpected width while widening compare."); + + // Widen the compare instruction. + IRBuilder<> Builder( + getInsertPointForUses(DU.NarrowUse, DU.NarrowDef, DT, LI)); + DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, DU.WideDef); + + // Widen the other operand of the compare, if necessary. + if (CastWidth < IVWidth) { + Value *ExtOp = createExtendInst(Op, WideType, Cmp->isSigned(), Cmp); + DU.NarrowUse->replaceUsesOfWith(Op, ExtOp); + } + return true; +} + +/// Determine whether an individual user of the narrow IV can be widened. If so, +/// return the wide clone of the user. +Instruction *WidenIV::widenIVUse(NarrowIVDefUse DU, SCEVExpander &Rewriter) { + assert(ExtendKindMap.count(DU.NarrowDef) && + "Should already know the kind of extension used to widen NarrowDef"); + + // Stop traversing the def-use chain at inner-loop phis or post-loop phis. + if (PHINode *UsePhi = dyn_cast<PHINode>(DU.NarrowUse)) { + if (LI->getLoopFor(UsePhi->getParent()) != L) { + // For LCSSA phis, sink the truncate outside the loop. + // After SimplifyCFG most loop exit targets have a single predecessor. + // Otherwise fall back to a truncate within the loop. + if (UsePhi->getNumOperands() != 1) + truncateIVUse(DU, DT, LI); + else { + // Widening the PHI requires us to insert a trunc. The logical place + // for this trunc is in the same BB as the PHI. This is not possible if + // the BB is terminated by a catchswitch. 
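+        // (A catchswitch block cannot contain ordinary instructions, so there
+        // is no legal insertion point for that trunc; we conservatively give
+        // up on widening through this phi, as checked below.)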
+      if (isa<CatchSwitchInst>(UsePhi->getParent()->getTerminator()))
+        return nullptr;
+
+      PHINode *WidePhi =
+          PHINode::Create(DU.WideDef->getType(), 1, UsePhi->getName() + ".wide",
+                          UsePhi);
+      WidePhi->addIncoming(DU.WideDef, UsePhi->getIncomingBlock(0));
+      IRBuilder<> Builder(&*WidePhi->getParent()->getFirstInsertionPt());
+      Value *Trunc = Builder.CreateTrunc(WidePhi, DU.NarrowDef->getType());
+      UsePhi->replaceAllUsesWith(Trunc);
+      DeadInsts.emplace_back(UsePhi);
+      DEBUG(dbgs() << "INDVARS: Widen lcssa phi " << *UsePhi
+                   << " to " << *WidePhi << "\n");
+    }
+    return nullptr;
+  }
+}
+
+  // This narrow use can be widened by a sext if it's non-negative or its
+  // narrow def was widened by a sext. Same for zext.
+  auto canWidenBySExt = [&]() {
+    return DU.NeverNegative || getExtendKind(DU.NarrowDef) == SignExtended;
+  };
+  auto canWidenByZExt = [&]() {
+    return DU.NeverNegative || getExtendKind(DU.NarrowDef) == ZeroExtended;
+  };
+
+  // Our raison d'etre! Eliminate sign and zero extension.
+  if ((isa<SExtInst>(DU.NarrowUse) && canWidenBySExt()) ||
+      (isa<ZExtInst>(DU.NarrowUse) && canWidenByZExt())) {
+    Value *NewDef = DU.WideDef;
+    if (DU.NarrowUse->getType() != WideType) {
+      unsigned CastWidth = SE->getTypeSizeInBits(DU.NarrowUse->getType());
+      unsigned IVWidth = SE->getTypeSizeInBits(WideType);
+      if (CastWidth < IVWidth) {
+        // The cast isn't as wide as the IV, so insert a Trunc.
+        IRBuilder<> Builder(DU.NarrowUse);
+        NewDef = Builder.CreateTrunc(DU.WideDef, DU.NarrowUse->getType());
+      }
+      else {
+        // A wider extend was hidden behind a narrower one. This may induce
+        // another round of IV widening in which the intermediate IV becomes
+        // dead. It should be very rare.
+        DEBUG(dbgs() << "INDVARS: New IV " << *WidePhi
+              << " not wide enough to subsume " << *DU.NarrowUse << "\n");
+        DU.NarrowUse->replaceUsesOfWith(DU.NarrowDef, DU.WideDef);
+        NewDef = DU.NarrowUse;
+      }
+    }
+    if (NewDef != DU.NarrowUse) {
+      DEBUG(dbgs() << "INDVARS: eliminating " << *DU.NarrowUse
+            << " replaced by " << *DU.WideDef << "\n");
+      ++NumElimExt;
+      DU.NarrowUse->replaceAllUsesWith(NewDef);
+      DeadInsts.emplace_back(DU.NarrowUse);
+    }
+    // Now that the extend is gone, we want to expose its uses for potential
+    // further simplification. We don't need to directly inform SimplifyIVUsers
+    // of the new users, because their parent IV will be processed later as a
+    // new loop phi. If we preserved IVUsers analysis, we would also want to
+    // push the uses of WideDef here.
+
+    // No further widening is needed. The deceased [sz]ext had done it for us.
+    return nullptr;
+  }
+
+  // Does this user itself evaluate to a recurrence after widening?
+  WidenedRecTy WideAddRec = getExtendedOperandRecurrence(DU);
+  if (!WideAddRec.first)
+    WideAddRec = getWideRecurrence(DU);
+
+  assert((WideAddRec.first == nullptr) == (WideAddRec.second == Unknown));
+  if (!WideAddRec.first) {
+    // If use is a loop condition, try to promote the condition instead of
+    // truncating the IV first.
+    if (widenLoopCompare(DU))
+      return nullptr;
+
+    // This user does not evaluate to a recurrence after widening, so don't
+    // follow it. Instead insert a Trunc to kill off the original use,
+    // eventually isolating the original narrow IV so it can be removed.
+    truncateIVUse(DU, DT, LI);
+    return nullptr;
+  }
+  // Assume block terminators cannot evaluate to a recurrence. We can't insert
+  // a Trunc after a terminator if there happens to be a critical edge.
+ assert(DU.NarrowUse != DU.NarrowUse->getParent()->getTerminator() && + "SCEV is not expected to evaluate a block terminator"); + + // Reuse the IV increment that SCEVExpander created as long as it dominates + // NarrowUse. + Instruction *WideUse = nullptr; + if (WideAddRec.first == WideIncExpr && + Rewriter.hoistIVInc(WideInc, DU.NarrowUse)) + WideUse = WideInc; + else { + WideUse = cloneIVUser(DU, WideAddRec.first); + if (!WideUse) + return nullptr; + } + // Evaluation of WideAddRec ensured that the narrow expression could be + // extended outside the loop without overflow. This suggests that the wide use + // evaluates to the same expression as the extended narrow use, but doesn't + // absolutely guarantee it. Hence the following failsafe check. In rare cases + // where it fails, we simply throw away the newly created wide use. + if (WideAddRec.first != SE->getSCEV(WideUse)) { + DEBUG(dbgs() << "Wide use expression mismatch: " << *WideUse + << ": " << *SE->getSCEV(WideUse) << " != " << *WideAddRec.first << "\n"); + DeadInsts.emplace_back(WideUse); + return nullptr; + } + + ExtendKindMap[DU.NarrowUse] = WideAddRec.second; + // Returning WideUse pushes it on the worklist. + return WideUse; +} + +/// Add eligible users of NarrowDef to NarrowIVUsers. +void WidenIV::pushNarrowIVUsers(Instruction *NarrowDef, Instruction *WideDef) { + const SCEV *NarrowSCEV = SE->getSCEV(NarrowDef); + bool NonNegativeDef = + SE->isKnownPredicate(ICmpInst::ICMP_SGE, NarrowSCEV, + SE->getConstant(NarrowSCEV->getType(), 0)); + for (User *U : NarrowDef->users()) { + Instruction *NarrowUser = cast<Instruction>(U); + + // Handle data flow merges and bizarre phi cycles. + if (!Widened.insert(NarrowUser).second) + continue; + + bool NonNegativeUse = false; + if (!NonNegativeDef) { + // We might have a control-dependent range information for this context. + if (auto RangeInfo = getPostIncRangeInfo(NarrowDef, NarrowUser)) + NonNegativeUse = RangeInfo->getSignedMin().isNonNegative(); + } + + NarrowIVUsers.emplace_back(NarrowDef, NarrowUser, WideDef, + NonNegativeDef || NonNegativeUse); + } +} + +/// Process a single induction variable. First use the SCEVExpander to create a +/// wide induction variable that evaluates to the same recurrence as the +/// original narrow IV. Then use a worklist to forward traverse the narrow IV's +/// def-use chain. After widenIVUse has processed all interesting IV users, the +/// narrow IV will be isolated for removal by DeleteDeadPHIs. +/// +/// It would be simpler to delete uses as they are processed, but we must avoid +/// invalidating SCEV expressions. +PHINode *WidenIV::createWideIV(SCEVExpander &Rewriter) { + // Is this phi an induction variable? + const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(OrigPhi)); + if (!AddRec) + return nullptr; + + // Widen the induction variable expression. + const SCEV *WideIVExpr = getExtendKind(OrigPhi) == SignExtended + ? SE->getSignExtendExpr(AddRec, WideType) + : SE->getZeroExtendExpr(AddRec, WideType); + + assert(SE->getEffectiveSCEVType(WideIVExpr->getType()) == WideType && + "Expect the new IV expression to preserve its type"); + + // Can the IV be extended outside the loop without overflow? + AddRec = dyn_cast<SCEVAddRecExpr>(WideIVExpr); + if (!AddRec || AddRec->getLoop() != L) + return nullptr; + + // An AddRec must have loop-invariant operands. Since this AddRec is + // materialized by a loop header phi, the expression cannot have any post-loop + // operands, so they must dominate the loop header. 
+  assert(
+      SE->properlyDominates(AddRec->getStart(), L->getHeader()) &&
+      SE->properlyDominates(AddRec->getStepRecurrence(*SE), L->getHeader()) &&
+      "Loop header phi recurrence inputs do not dominate the loop");
+
+  // Iterate over IV uses (including transitive ones) looking for IV increments
+  // of the form 'add nsw %iv, <const>'. For each increment and each use of
+  // the increment calculate control-dependent range information based on
+  // dominating conditions inside of the loop (e.g. a range check inside of the
+  // loop). Calculated ranges are stored in PostIncRangeInfos map.
+  //
+  // Control-dependent range information is later used to prove that a narrow
+  // definition is not negative (see pushNarrowIVUsers). It's difficult to do
+  // this on demand because when pushNarrowIVUsers needs this information some
+  // of the dominating conditions might be already widened.
+  if (UsePostIncrementRanges)
+    calculatePostIncRanges(OrigPhi);
+
+  // The rewriter provides a value for the desired IV expression. This may
+  // either find an existing phi or materialize a new one. Either way, we
+  // expect a well-formed cyclic phi-with-increments. i.e. any operand not part
+  // of the phi-SCC dominates the loop entry.
+  Instruction *InsertPt = &L->getHeader()->front();
+  WidePhi = cast<PHINode>(Rewriter.expandCodeFor(AddRec, WideType, InsertPt));
+
+  // Remembering the WideIV increment generated by SCEVExpander allows
+  // widenIVUse to reuse it when widening the narrow IV's increment. We don't
+  // employ a general reuse mechanism because the call above is the only call
+  // to SCEVExpander. Henceforth, we produce 1-to-1 narrow to wide uses.
+  if (BasicBlock *LatchBlock = L->getLoopLatch()) {
+    WideInc =
+        cast<Instruction>(WidePhi->getIncomingValueForBlock(LatchBlock));
+    WideIncExpr = SE->getSCEV(WideInc);
+    // Propagate the debug location associated with the original loop increment
+    // to the new (widened) increment.
+    auto *OrigInc =
+        cast<Instruction>(OrigPhi->getIncomingValueForBlock(LatchBlock));
+    WideInc->setDebugLoc(OrigInc->getDebugLoc());
+  }
+
+  DEBUG(dbgs() << "Wide IV: " << *WidePhi << "\n");
+  ++NumWidened;
+
+  // Traverse the def-use chain using a worklist starting at the original IV.
+  assert(Widened.empty() && NarrowIVUsers.empty() && "expect initial state");
+
+  Widened.insert(OrigPhi);
+  pushNarrowIVUsers(OrigPhi, WidePhi);
+
+  while (!NarrowIVUsers.empty()) {
+    NarrowIVDefUse DU = NarrowIVUsers.pop_back_val();
+
+    // Process a def-use edge. This may replace the use, so don't hold a
+    // use_iterator across it.
+    Instruction *WideUse = widenIVUse(DU, Rewriter);
+
+    // Follow all def-use edges from the previous narrow use.
+    if (WideUse)
+      pushNarrowIVUsers(DU.NarrowUse, WideUse);
+
+    // widenIVUse may have removed the def-use edge.
+    if (DU.NarrowDef->use_empty())
+      DeadInsts.emplace_back(DU.NarrowDef);
+  }
+
+  // Attach any debug information to the new PHI. Since OrigPhi and WidePhi
+  // evaluate the same recurrence, we can just copy the debug info over.
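+  // (Roughly how: llvm.dbg.value intrinsics describing OrigPhi are located
+  // via findDbgValues, and their value operand is redirected to WidePhi,
+  // wrapped as metadata, so the debugger keeps tracking the variable.)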
+ SmallVector<DbgValueInst *, 1> DbgValues; + llvm::findDbgValues(DbgValues, OrigPhi); + auto *MDPhi = MetadataAsValue::get(WidePhi->getContext(), + ValueAsMetadata::get(WidePhi)); + for (auto &DbgValue : DbgValues) + DbgValue->setOperand(0, MDPhi); + return WidePhi; +} + +/// Calculates control-dependent range for the given def at the given context +/// by looking at dominating conditions inside of the loop +void WidenIV::calculatePostIncRange(Instruction *NarrowDef, + Instruction *NarrowUser) { + using namespace llvm::PatternMatch; + + Value *NarrowDefLHS; + const APInt *NarrowDefRHS; + if (!match(NarrowDef, m_NSWAdd(m_Value(NarrowDefLHS), + m_APInt(NarrowDefRHS))) || + !NarrowDefRHS->isNonNegative()) + return; + + auto UpdateRangeFromCondition = [&] (Value *Condition, + bool TrueDest) { + CmpInst::Predicate Pred; + Value *CmpRHS; + if (!match(Condition, m_ICmp(Pred, m_Specific(NarrowDefLHS), + m_Value(CmpRHS)))) + return; + + CmpInst::Predicate P = + TrueDest ? Pred : CmpInst::getInversePredicate(Pred); + + auto CmpRHSRange = SE->getSignedRange(SE->getSCEV(CmpRHS)); + auto CmpConstrainedLHSRange = + ConstantRange::makeAllowedICmpRegion(P, CmpRHSRange); + auto NarrowDefRange = + CmpConstrainedLHSRange.addWithNoSignedWrap(*NarrowDefRHS); + + updatePostIncRangeInfo(NarrowDef, NarrowUser, NarrowDefRange); + }; + + auto UpdateRangeFromGuards = [&](Instruction *Ctx) { + if (!HasGuards) + return; + + for (Instruction &I : make_range(Ctx->getIterator().getReverse(), + Ctx->getParent()->rend())) { + Value *C = nullptr; + if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>(m_Value(C)))) + UpdateRangeFromCondition(C, /*TrueDest=*/true); + } + }; + + UpdateRangeFromGuards(NarrowUser); + + BasicBlock *NarrowUserBB = NarrowUser->getParent(); + // If NarrowUserBB is statically unreachable asking dominator queries may + // yield surprising results. (e.g. the block may not have a dom tree node) + if (!DT->isReachableFromEntry(NarrowUserBB)) + return; + + for (auto *DTB = (*DT)[NarrowUserBB]->getIDom(); + L->contains(DTB->getBlock()); + DTB = DTB->getIDom()) { + auto *BB = DTB->getBlock(); + auto *TI = BB->getTerminator(); + UpdateRangeFromGuards(TI); + + auto *BI = dyn_cast<BranchInst>(TI); + if (!BI || !BI->isConditional()) + continue; + + auto *TrueSuccessor = BI->getSuccessor(0); + auto *FalseSuccessor = BI->getSuccessor(1); + + auto DominatesNarrowUser = [this, NarrowUser] (BasicBlockEdge BBE) { + return BBE.isSingleEdge() && + DT->dominates(BBE, NarrowUser->getParent()); + }; + + if (DominatesNarrowUser(BasicBlockEdge(BB, TrueSuccessor))) + UpdateRangeFromCondition(BI->getCondition(), /*TrueDest=*/true); + + if (DominatesNarrowUser(BasicBlockEdge(BB, FalseSuccessor))) + UpdateRangeFromCondition(BI->getCondition(), /*TrueDest=*/false); + } +} + +/// Calculates PostIncRangeInfos map for the given IV +void WidenIV::calculatePostIncRanges(PHINode *OrigPhi) { + SmallPtrSet<Instruction *, 16> Visited; + SmallVector<Instruction *, 6> Worklist; + Worklist.push_back(OrigPhi); + Visited.insert(OrigPhi); + + while (!Worklist.empty()) { + Instruction *NarrowDef = Worklist.pop_back_val(); + + for (Use &U : NarrowDef->uses()) { + auto *NarrowUser = cast<Instruction>(U.getUser()); + + // Don't go looking outside the current loop. 
+ auto *NarrowUserLoop = (*LI)[NarrowUser->getParent()]; + if (!NarrowUserLoop || !L->contains(NarrowUserLoop)) + continue; + + if (!Visited.insert(NarrowUser).second) + continue; + + Worklist.push_back(NarrowUser); + + calculatePostIncRange(NarrowDef, NarrowUser); + } + } +} + +//===----------------------------------------------------------------------===// +// Live IV Reduction - Minimize IVs live across the loop. +//===----------------------------------------------------------------------===// + +//===----------------------------------------------------------------------===// +// Simplification of IV users based on SCEV evaluation. +//===----------------------------------------------------------------------===// + +namespace { + +class IndVarSimplifyVisitor : public IVVisitor { + ScalarEvolution *SE; + const TargetTransformInfo *TTI; + PHINode *IVPhi; + +public: + WideIVInfo WI; + + IndVarSimplifyVisitor(PHINode *IV, ScalarEvolution *SCEV, + const TargetTransformInfo *TTI, + const DominatorTree *DTree) + : SE(SCEV), TTI(TTI), IVPhi(IV) { + DT = DTree; + WI.NarrowIV = IVPhi; + } + + // Implement the interface used by simplifyUsersOfIV. + void visitCast(CastInst *Cast) override { visitIVCast(Cast, WI, SE, TTI); } +}; + +} // end anonymous namespace + +/// Iteratively perform simplification on a worklist of IV users. Each +/// successive simplification may push more users which may themselves be +/// candidates for simplification. +/// +/// Sign/Zero extend elimination is interleaved with IV simplification. +void IndVarSimplify::simplifyAndExtend(Loop *L, + SCEVExpander &Rewriter, + LoopInfo *LI) { + SmallVector<WideIVInfo, 8> WideIVs; + + auto *GuardDecl = L->getBlocks()[0]->getModule()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + bool HasGuards = GuardDecl && !GuardDecl->use_empty(); + + SmallVector<PHINode*, 8> LoopPhis; + for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) { + LoopPhis.push_back(cast<PHINode>(I)); + } + // Each round of simplification iterates through the SimplifyIVUsers worklist + // for all current phis, then determines whether any IVs can be + // widened. Widening adds new phis to LoopPhis, inducing another round of + // simplification on the wide IVs. + while (!LoopPhis.empty()) { + // Evaluate as many IV expressions as possible before widening any IVs. This + // forces SCEV to set no-wrap flags before evaluating sign/zero + // extension. The first time SCEV attempts to normalize sign/zero extension, + // the result becomes final. So for the most predictable results, we delay + // evaluation of sign/zero extend evaluation until needed, and avoid running + // other SCEV based analysis prior to simplifyAndExtend. + do { + PHINode *CurrIV = LoopPhis.pop_back_val(); + + // Information about sign/zero extensions of CurrIV. + IndVarSimplifyVisitor Visitor(CurrIV, SE, TTI, DT); + + Changed |= + simplifyUsersOfIV(CurrIV, SE, DT, LI, DeadInsts, Rewriter, &Visitor); + + if (Visitor.WI.WidestNativeType) { + WideIVs.push_back(Visitor.WI); + } + } while(!LoopPhis.empty()); + + for (; !WideIVs.empty(); WideIVs.pop_back()) { + WidenIV Widener(WideIVs.back(), LI, SE, DT, DeadInsts, HasGuards); + if (PHINode *WidePhi = Widener.createWideIV(Rewriter)) { + Changed = true; + LoopPhis.push_back(WidePhi); + } + } + } +} + +//===----------------------------------------------------------------------===// +// linearFunctionTestReplace and its kin. Rewrite the loop exit condition. 
+//===----------------------------------------------------------------------===// + +/// Return true if this loop's backedge taken count expression can be safely and +/// cheaply expanded into an instruction sequence that can be used by +/// linearFunctionTestReplace. +/// +/// TODO: This fails for pointer-type loop counters with greater than one byte +/// strides, consequently preventing LFTR from running. For the purpose of LFTR +/// we could skip this check in the case that the LFTR loop counter (chosen by +/// FindLoopCounter) is also pointer type. Instead, we could directly convert +/// the loop test to an inequality test by checking the target data's alignment +/// of element types (given that the initial pointer value originates from or is +/// used by ABI constrained operation, as opposed to inttoptr/ptrtoint). +/// However, we don't yet have a strong motivation for converting loop tests +/// into inequality tests. +static bool canExpandBackedgeTakenCount(Loop *L, ScalarEvolution *SE, + SCEVExpander &Rewriter) { + const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); + if (isa<SCEVCouldNotCompute>(BackedgeTakenCount) || + BackedgeTakenCount->isZero()) + return false; + + if (!L->getExitingBlock()) + return false; + + // Can't rewrite non-branch yet. + if (!isa<BranchInst>(L->getExitingBlock()->getTerminator())) + return false; + + if (Rewriter.isHighCostExpansion(BackedgeTakenCount, L)) + return false; + + return true; +} + +/// Return the loop header phi IFF IncV adds a loop invariant value to the phi. +static PHINode *getLoopPhiForCounter(Value *IncV, Loop *L, DominatorTree *DT) { + Instruction *IncI = dyn_cast<Instruction>(IncV); + if (!IncI) + return nullptr; + + switch (IncI->getOpcode()) { + case Instruction::Add: + case Instruction::Sub: + break; + case Instruction::GetElementPtr: + // An IV counter must preserve its type. + if (IncI->getNumOperands() == 2) + break; + LLVM_FALLTHROUGH; + default: + return nullptr; + } + + PHINode *Phi = dyn_cast<PHINode>(IncI->getOperand(0)); + if (Phi && Phi->getParent() == L->getHeader()) { + if (isLoopInvariant(IncI->getOperand(1), L, DT)) + return Phi; + return nullptr; + } + if (IncI->getOpcode() == Instruction::GetElementPtr) + return nullptr; + + // Allow add/sub to be commuted. + Phi = dyn_cast<PHINode>(IncI->getOperand(1)); + if (Phi && Phi->getParent() == L->getHeader()) { + if (isLoopInvariant(IncI->getOperand(0), L, DT)) + return Phi; + } + return nullptr; +} + +/// Return the compare guarding the loop latch, or NULL for unrecognized tests. +static ICmpInst *getLoopTest(Loop *L) { + assert(L->getExitingBlock() && "expected loop exit"); + + BasicBlock *LatchBlock = L->getLoopLatch(); + // Don't bother with LFTR if the loop is not properly simplified. + if (!LatchBlock) + return nullptr; + + BranchInst *BI = dyn_cast<BranchInst>(L->getExitingBlock()->getTerminator()); + assert(BI && "expected exit branch"); + + return dyn_cast<ICmpInst>(BI->getCondition()); +} + +/// linearFunctionTestReplace policy. Return true unless we can show that the +/// current exit test is already sufficiently canonical. +static bool needsLFTR(Loop *L, DominatorTree *DT) { + // Do LFTR to simplify the exit condition to an ICMP. 
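+  // (For intuition, an exit test that is already canonical looks roughly like
+  //   %exitcond = icmp ne i64 %iv.next, %limit
+  // with a loop-invariant %limit and %iv.next the increment of a simple
+  // counter phi; each early 'return true' below flags a deviation from that
+  // form.)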
+ ICmpInst *Cond = getLoopTest(L); + if (!Cond) + return true; + + // Do LFTR to simplify the exit ICMP to EQ/NE + ICmpInst::Predicate Pred = Cond->getPredicate(); + if (Pred != ICmpInst::ICMP_NE && Pred != ICmpInst::ICMP_EQ) + return true; + + // Look for a loop invariant RHS + Value *LHS = Cond->getOperand(0); + Value *RHS = Cond->getOperand(1); + if (!isLoopInvariant(RHS, L, DT)) { + if (!isLoopInvariant(LHS, L, DT)) + return true; + std::swap(LHS, RHS); + } + // Look for a simple IV counter LHS + PHINode *Phi = dyn_cast<PHINode>(LHS); + if (!Phi) + Phi = getLoopPhiForCounter(LHS, L, DT); + + if (!Phi) + return true; + + // Do LFTR if PHI node is defined in the loop, but is *not* a counter. + int Idx = Phi->getBasicBlockIndex(L->getLoopLatch()); + if (Idx < 0) + return true; + + // Do LFTR if the exit condition's IV is *not* a simple counter. + Value *IncV = Phi->getIncomingValue(Idx); + return Phi != getLoopPhiForCounter(IncV, L, DT); +} + +/// Recursive helper for hasConcreteDef(). Unfortunately, this currently boils +/// down to checking that all operands are constant and listing instructions +/// that may hide undef. +static bool hasConcreteDefImpl(Value *V, SmallPtrSetImpl<Value*> &Visited, + unsigned Depth) { + if (isa<Constant>(V)) + return !isa<UndefValue>(V); + + if (Depth >= 6) + return false; + + // Conservatively handle non-constant non-instructions. For example, Arguments + // may be undef. + Instruction *I = dyn_cast<Instruction>(V); + if (!I) + return false; + + // Load and return values may be undef. + if(I->mayReadFromMemory() || isa<CallInst>(I) || isa<InvokeInst>(I)) + return false; + + // Optimistically handle other instructions. + for (Value *Op : I->operands()) { + if (!Visited.insert(Op).second) + continue; + if (!hasConcreteDefImpl(Op, Visited, Depth+1)) + return false; + } + return true; +} + +/// Return true if the given value is concrete. We must prove that undef can +/// never reach it. +/// +/// TODO: If we decide that this is a good approach to checking for undef, we +/// may factor it into a common location. +static bool hasConcreteDef(Value *V) { + SmallPtrSet<Value*, 8> Visited; + Visited.insert(V); + return hasConcreteDefImpl(V, Visited, 0); +} + +/// Return true if this IV has any uses other than the (soon to be rewritten) +/// loop exit test. +static bool AlmostDeadIV(PHINode *Phi, BasicBlock *LatchBlock, Value *Cond) { + int LatchIdx = Phi->getBasicBlockIndex(LatchBlock); + Value *IncV = Phi->getIncomingValue(LatchIdx); + + for (User *U : Phi->users()) + if (U != Cond && U != IncV) return false; + + for (User *U : IncV->users()) + if (U != Cond && U != Phi) return false; + return true; +} + +/// Find an affine IV in canonical form. +/// +/// BECount may be an i8* pointer type. The pointer difference is already +/// valid count without scaling the address stride, so it remains a pointer +/// expression as far as SCEV is concerned. +/// +/// Currently only valid for LFTR. See the comments on hasConcreteDef below. +/// +/// FIXME: Accept -1 stride and set IVLimit = IVInit - BECount +/// +/// FIXME: Accept non-unit stride as long as SCEV can reduce BECount * Stride. +/// This is difficult in general for SCEV because of potential overflow. But we +/// could at least handle constant BECounts. 
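+///
+/// (Example of a counter this is looking for, in illustrative IR:
+///   %iv = phi i64 [ 0, %preheader ], [ %iv.next, %latch ]
+///   %iv.next = add i64 %iv, 1
+/// i.e. an affine recurrence {0,+,1} whose unit-stride increment feeds back
+/// into the phi.)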
+static PHINode *FindLoopCounter(Loop *L, const SCEV *BECount, + ScalarEvolution *SE, DominatorTree *DT) { + uint64_t BCWidth = SE->getTypeSizeInBits(BECount->getType()); + + Value *Cond = + cast<BranchInst>(L->getExitingBlock()->getTerminator())->getCondition(); + + // Loop over all of the PHI nodes, looking for a simple counter. + PHINode *BestPhi = nullptr; + const SCEV *BestInit = nullptr; + BasicBlock *LatchBlock = L->getLoopLatch(); + assert(LatchBlock && "needsLFTR should guarantee a loop latch"); + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + + for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) { + PHINode *Phi = cast<PHINode>(I); + if (!SE->isSCEVable(Phi->getType())) + continue; + + // Avoid comparing an integer IV against a pointer Limit. + if (BECount->getType()->isPointerTy() && !Phi->getType()->isPointerTy()) + continue; + + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Phi)); + if (!AR || AR->getLoop() != L || !AR->isAffine()) + continue; + + // AR may be a pointer type, while BECount is an integer type. + // AR may be wider than BECount. With eq/ne tests overflow is immaterial. + // AR may not be a narrower type, or we may never exit. + uint64_t PhiWidth = SE->getTypeSizeInBits(AR->getType()); + if (PhiWidth < BCWidth || !DL.isLegalInteger(PhiWidth)) + continue; + + const SCEV *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE)); + if (!Step || !Step->isOne()) + continue; + + int LatchIdx = Phi->getBasicBlockIndex(LatchBlock); + Value *IncV = Phi->getIncomingValue(LatchIdx); + if (getLoopPhiForCounter(IncV, L, DT) != Phi) + continue; + + // Avoid reusing a potentially undef value to compute other values that may + // have originally had a concrete definition. + if (!hasConcreteDef(Phi)) { + // We explicitly allow unknown phis as long as they are already used by + // the loop test. In this case we assume that performing LFTR could not + // increase the number of undef users. + if (ICmpInst *Cond = getLoopTest(L)) { + if (Phi != getLoopPhiForCounter(Cond->getOperand(0), L, DT) && + Phi != getLoopPhiForCounter(Cond->getOperand(1), L, DT)) { + continue; + } + } + } + const SCEV *Init = AR->getStart(); + + if (BestPhi && !AlmostDeadIV(BestPhi, LatchBlock, Cond)) { + // Don't force a live loop counter if another IV can be used. + if (AlmostDeadIV(Phi, LatchBlock, Cond)) + continue; + + // Prefer to count-from-zero. This is a more "canonical" counter form. It + // also prefers integer to pointer IVs. + if (BestInit->isZero() != Init->isZero()) { + if (BestInit->isZero()) + continue; + } + // If two IVs both count from zero or both count from nonzero then the + // narrower is likely a dead phi that has been widened. Use the wider phi + // to allow the other to be eliminated. + else if (PhiWidth <= SE->getTypeSizeInBits(BestPhi->getType())) + continue; + } + BestPhi = Phi; + BestInit = Init; + } + return BestPhi; +} + +/// Help linearFunctionTestReplace by generating a value that holds the RHS of +/// the new loop test. +static Value *genLoopLimit(PHINode *IndVar, const SCEV *IVCount, Loop *L, + SCEVExpander &Rewriter, ScalarEvolution *SE) { + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(IndVar)); + assert(AR && AR->getLoop() == L && AR->isAffine() && "bad loop counter"); + const SCEV *IVInit = AR->getStart(); + + // IVInit may be a pointer while IVCount is an integer when FindLoopCounter + // finds a valid pointer IV. Sign extend BECount in order to materialize a + // GEP. 
Avoid running SCEVExpander on a new pointer value, instead reusing + // the existing GEPs whenever possible. + if (IndVar->getType()->isPointerTy() && !IVCount->getType()->isPointerTy()) { + // IVOffset will be the new GEP offset that is interpreted by GEP as a + // signed value. IVCount on the other hand represents the loop trip count, + // which is an unsigned value. FindLoopCounter only allows induction + // variables that have a positive unit stride of one. This means we don't + // have to handle the case of negative offsets (yet) and just need to zero + // extend IVCount. + Type *OfsTy = SE->getEffectiveSCEVType(IVInit->getType()); + const SCEV *IVOffset = SE->getTruncateOrZeroExtend(IVCount, OfsTy); + + // Expand the code for the iteration count. + assert(SE->isLoopInvariant(IVOffset, L) && + "Computed iteration count is not loop invariant!"); + BranchInst *BI = cast<BranchInst>(L->getExitingBlock()->getTerminator()); + Value *GEPOffset = Rewriter.expandCodeFor(IVOffset, OfsTy, BI); + + Value *GEPBase = IndVar->getIncomingValueForBlock(L->getLoopPreheader()); + assert(AR->getStart() == SE->getSCEV(GEPBase) && "bad loop counter"); + // We could handle pointer IVs other than i8*, but we need to compensate for + // gep index scaling. See canExpandBackedgeTakenCount comments. + assert(SE->getSizeOfExpr(IntegerType::getInt64Ty(IndVar->getContext()), + cast<PointerType>(GEPBase->getType()) + ->getElementType())->isOne() && + "unit stride pointer IV must be i8*"); + + IRBuilder<> Builder(L->getLoopPreheader()->getTerminator()); + return Builder.CreateGEP(nullptr, GEPBase, GEPOffset, "lftr.limit"); + } else { + // In any other case, convert both IVInit and IVCount to integers before + // comparing. This may result in SCEV expansion of pointers, but in practice + // SCEV will fold the pointer arithmetic away as such: + // BECount = (IVEnd - IVInit - 1) => IVLimit = IVInit (postinc). + // + // Valid Cases: (1) both integers is most common; (2) both may be pointers + // for simple memset-style loops. + // + // IVInit integer and IVCount pointer would only occur if a canonical IV + // were generated on top of case #2, which is not expected. + + const SCEV *IVLimit = nullptr; + // For unit stride, IVCount = Start + BECount with 2's complement overflow. + // For non-zero Start, compute IVCount here. + if (AR->getStart()->isZero()) + IVLimit = IVCount; + else { + assert(AR->getStepRecurrence(*SE)->isOne() && "only handles unit stride"); + const SCEV *IVInit = AR->getStart(); + + // For integer IVs, truncate the IV before computing IVInit + BECount. + if (SE->getTypeSizeInBits(IVInit->getType()) + > SE->getTypeSizeInBits(IVCount->getType())) + IVInit = SE->getTruncateExpr(IVInit, IVCount->getType()); + + IVLimit = SE->getAddExpr(IVInit, IVCount); + } + // Expand the code for the iteration count. + BranchInst *BI = cast<BranchInst>(L->getExitingBlock()->getTerminator()); + IRBuilder<> Builder(BI); + assert(SE->isLoopInvariant(IVLimit, L) && + "Computed iteration count is not loop invariant!"); + // Ensure that we generate the same type as IndVar, or a smaller integer + // type. In the presence of null pointer values, we have an integer type + // SCEV expression (IVInit) for a pointer type IV value (IndVar). + Type *LimitTy = IVCount->getType()->isPointerTy() ? 
+ IndVar->getType() : IVCount->getType(); + return Rewriter.expandCodeFor(IVLimit, LimitTy, BI); + } +} + +/// This method rewrites the exit condition of the loop to be a canonical != +/// comparison against the incremented loop induction variable. This pass is +/// able to rewrite the exit tests of any loop where the SCEV analysis can +/// determine a loop-invariant trip count of the loop, which is actually a much +/// broader range than just linear tests. +Value *IndVarSimplify:: +linearFunctionTestReplace(Loop *L, + const SCEV *BackedgeTakenCount, + PHINode *IndVar, + SCEVExpander &Rewriter) { + assert(canExpandBackedgeTakenCount(L, SE, Rewriter) && "precondition"); + + // Initialize CmpIndVar and IVCount to their preincremented values. + Value *CmpIndVar = IndVar; + const SCEV *IVCount = BackedgeTakenCount; + + assert(L->getLoopLatch() && "Loop no longer in simplified form?"); + + // If the exiting block is the same as the backedge block, we prefer to + // compare against the post-incremented value, otherwise we must compare + // against the preincremented value. + if (L->getExitingBlock() == L->getLoopLatch()) { + // Add one to the "backedge-taken" count to get the trip count. + // This addition may overflow, which is valid as long as the comparison is + // truncated to BackedgeTakenCount->getType(). + IVCount = SE->getAddExpr(BackedgeTakenCount, + SE->getOne(BackedgeTakenCount->getType())); + // The BackedgeTaken expression contains the number of times that the + // backedge branches to the loop header. This is one less than the + // number of times the loop executes, so use the incremented indvar. + CmpIndVar = IndVar->getIncomingValueForBlock(L->getExitingBlock()); + } + + Value *ExitCnt = genLoopLimit(IndVar, IVCount, L, Rewriter, SE); + assert(ExitCnt->getType()->isPointerTy() == + IndVar->getType()->isPointerTy() && + "genLoopLimit missed a cast"); + + // Insert a new icmp_ne or icmp_eq instruction before the branch. + BranchInst *BI = cast<BranchInst>(L->getExitingBlock()->getTerminator()); + ICmpInst::Predicate P; + if (L->contains(BI->getSuccessor(0))) + P = ICmpInst::ICMP_NE; + else + P = ICmpInst::ICMP_EQ; + + DEBUG(dbgs() << "INDVARS: Rewriting loop exit condition to:\n" + << " LHS:" << *CmpIndVar << '\n' + << " op:\t" + << (P == ICmpInst::ICMP_NE ? "!=" : "==") << "\n" + << " RHS:\t" << *ExitCnt << "\n" + << " IVCount:\t" << *IVCount << "\n"); + + IRBuilder<> Builder(BI); + + // The new loop exit condition should reuse the debug location of the + // original loop exit condition. + if (auto *Cond = dyn_cast<Instruction>(BI->getCondition())) + Builder.SetCurrentDebugLocation(Cond->getDebugLoc()); + + // LFTR can ignore IV overflow and truncate to the width of + // BECount. This avoids materializing the add(zext(add)) expression. + unsigned CmpIndVarSize = SE->getTypeSizeInBits(CmpIndVar->getType()); + unsigned ExitCntSize = SE->getTypeSizeInBits(ExitCnt->getType()); + if (CmpIndVarSize > ExitCntSize) { + const SCEVAddRecExpr *AR = cast<SCEVAddRecExpr>(SE->getSCEV(IndVar)); + const SCEV *ARStart = AR->getStart(); + const SCEV *ARStep = AR->getStepRecurrence(*SE); + // For constant IVCount, avoid truncation. + if (isa<SCEVConstant>(ARStart) && isa<SCEVConstant>(IVCount)) { + const APInt &Start = cast<SCEVConstant>(ARStart)->getAPInt(); + APInt Count = cast<SCEVConstant>(IVCount)->getAPInt(); + // Note that the post-inc value of BackedgeTakenCount may have overflowed + // above such that IVCount is now zero. 
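+      // (Worked example with an i8 count: BackedgeTakenCount = 255, so the
+      // post-increment IVCount = 255 + 1 wraps to 0; the code below recovers
+      // the true count of 256 in the wider CmpIndVar type.)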
+ if (IVCount != BackedgeTakenCount && Count == 0) { + Count = APInt::getMaxValue(Count.getBitWidth()).zext(CmpIndVarSize); + ++Count; + } + else + Count = Count.zext(CmpIndVarSize); + APInt NewLimit; + if (cast<SCEVConstant>(ARStep)->getValue()->isNegative()) + NewLimit = Start - Count; + else + NewLimit = Start + Count; + ExitCnt = ConstantInt::get(CmpIndVar->getType(), NewLimit); + + DEBUG(dbgs() << " Widen RHS:\t" << *ExitCnt << "\n"); + } else { + // We try to extend trip count first. If that doesn't work we truncate IV. + // Zext(trunc(IV)) == IV implies equivalence of the following two: + // Trunc(IV) == ExitCnt and IV == zext(ExitCnt). Similarly for sext. If + // one of the two holds, extend the trip count, otherwise we truncate IV. + bool Extended = false; + const SCEV *IV = SE->getSCEV(CmpIndVar); + const SCEV *ZExtTrunc = + SE->getZeroExtendExpr(SE->getTruncateExpr(SE->getSCEV(CmpIndVar), + ExitCnt->getType()), + CmpIndVar->getType()); + + if (ZExtTrunc == IV) { + Extended = true; + ExitCnt = Builder.CreateZExt(ExitCnt, IndVar->getType(), + "wide.trip.count"); + } else { + const SCEV *SExtTrunc = + SE->getSignExtendExpr(SE->getTruncateExpr(SE->getSCEV(CmpIndVar), + ExitCnt->getType()), + CmpIndVar->getType()); + if (SExtTrunc == IV) { + Extended = true; + ExitCnt = Builder.CreateSExt(ExitCnt, IndVar->getType(), + "wide.trip.count"); + } + } + + if (!Extended) + CmpIndVar = Builder.CreateTrunc(CmpIndVar, ExitCnt->getType(), + "lftr.wideiv"); + } + } + Value *Cond = Builder.CreateICmp(P, CmpIndVar, ExitCnt, "exitcond"); + Value *OrigCond = BI->getCondition(); + // It's tempting to use replaceAllUsesWith here to fully replace the old + // comparison, but that's not immediately safe, since users of the old + // comparison may not be dominated by the new comparison. Instead, just + // update the branch to use the new comparison; in the common case this + // will make old comparison dead. + BI->setCondition(Cond); + DeadInsts.push_back(OrigCond); + + ++NumLFTR; + Changed = true; + return Cond; +} + +//===----------------------------------------------------------------------===// +// sinkUnusedInvariants. A late subpass to cleanup loop preheaders. +//===----------------------------------------------------------------------===// + +/// If there's a single exit block, sink any loop-invariant values that +/// were defined in the preheader but not used inside the loop into the +/// exit block to reduce register pressure in the loop. +void IndVarSimplify::sinkUnusedInvariants(Loop *L) { + BasicBlock *ExitBlock = L->getExitBlock(); + if (!ExitBlock) return; + + BasicBlock *Preheader = L->getLoopPreheader(); + if (!Preheader) return; + + BasicBlock::iterator InsertPt = ExitBlock->getFirstInsertionPt(); + BasicBlock::iterator I(Preheader->getTerminator()); + while (I != Preheader->begin()) { + --I; + // New instructions were inserted at the end of the preheader. + if (isa<PHINode>(I)) + break; + + // Don't move instructions which might have side effects, since the side + // effects need to complete before instructions inside the loop. Also don't + // move instructions which might read memory, since the loop may modify + // memory. Note that it's okay if the instruction might have undefined + // behavior: LoopSimplify guarantees that the preheader dominates the exit + // block. + if (I->mayHaveSideEffects() || I->mayReadFromMemory()) + continue; + + // Skip debug info intrinsics. + if (isa<DbgInfoIntrinsic>(I)) + continue; + + // Skip eh pad instructions. 
+ if (I->isEHPad()) + continue; + + // Don't sink alloca: we never want to sink static alloca's out of the + // entry block, and correctly sinking dynamic alloca's requires + // checks for stacksave/stackrestore intrinsics. + // FIXME: Refactor this check somehow? + if (isa<AllocaInst>(I)) + continue; + + // Determine if there is a use in or before the loop (direct or + // otherwise). + bool UsedInLoop = false; + for (Use &U : I->uses()) { + Instruction *User = cast<Instruction>(U.getUser()); + BasicBlock *UseBB = User->getParent(); + if (PHINode *P = dyn_cast<PHINode>(User)) { + unsigned i = + PHINode::getIncomingValueNumForOperand(U.getOperandNo()); + UseBB = P->getIncomingBlock(i); + } + if (UseBB == Preheader || L->contains(UseBB)) { + UsedInLoop = true; + break; + } + } + + // If there is, the def must remain in the preheader. + if (UsedInLoop) + continue; + + // Otherwise, sink it to the exit block. + Instruction *ToMove = &*I; + bool Done = false; + + if (I != Preheader->begin()) { + // Skip debug info intrinsics. + do { + --I; + } while (isa<DbgInfoIntrinsic>(I) && I != Preheader->begin()); + + if (isa<DbgInfoIntrinsic>(I) && I == Preheader->begin()) + Done = true; + } else { + Done = true; + } + + ToMove->moveBefore(*ExitBlock, InsertPt); + if (Done) break; + InsertPt = ToMove->getIterator(); + } +} + +//===----------------------------------------------------------------------===// +// IndVarSimplify driver. Manage several subpasses of IV simplification. +//===----------------------------------------------------------------------===// + +bool IndVarSimplify::run(Loop *L) { + // We need (and expect!) the incoming loop to be in LCSSA. + assert(L->isRecursivelyLCSSAForm(*DT, *LI) && + "LCSSA required to run indvars!"); + + // If LoopSimplify form is not available, stay out of trouble. Some notes: + // - LSR currently only supports LoopSimplify-form loops. Indvars' + // canonicalization can be a pessimization without LSR to "clean up" + // afterwards. + // - We depend on having a preheader; in particular, + // Loop::getCanonicalInductionVariable only supports loops with preheaders, + // and we're in trouble if we can't find the induction variable even when + // we've manually inserted one. + // - LFTR relies on having a single backedge. + if (!L->isLoopSimplifyForm()) + return false; + + // If there are any floating-point recurrences, attempt to + // transform them to use integer recurrences. + rewriteNonIntegerIVs(L); + + const SCEV *BackedgeTakenCount = SE->getBackedgeTakenCount(L); + + // Create a rewriter object which we'll use to transform the code with. + SCEVExpander Rewriter(*SE, DL, "indvars"); +#ifndef NDEBUG + Rewriter.setDebugType(DEBUG_TYPE); +#endif + + // Eliminate redundant IV users. + // + // Simplification works best when run before other consumers of SCEV. We + // attempt to avoid evaluating SCEVs for sign/zero extend operations until + // other expressions involving loop IVs have been evaluated. This helps SCEV + // set no-wrap flags before normalizing sign/zero extension. + Rewriter.disableCanonicalMode(); + simplifyAndExtend(L, Rewriter, LI); + + // Check to see if this loop has a computable loop-invariant execution count. + // If so, this means that we can compute the final value of any expressions + // that are recurrent in the loop, and substitute the exit values from the + // loop into any instructions outside of the loop that use the final values of + // the current expressions. 
+ // + if (ReplaceExitValue != NeverRepl && + !isa<SCEVCouldNotCompute>(BackedgeTakenCount)) + rewriteLoopExitValues(L, Rewriter); + + // Eliminate redundant IV cycles. + NumElimIV += Rewriter.replaceCongruentIVs(L, DT, DeadInsts); + + // If we have a trip count expression, rewrite the loop's exit condition + // using it. We can currently only handle loops with a single exit. + if (!DisableLFTR && canExpandBackedgeTakenCount(L, SE, Rewriter) && + needsLFTR(L, DT)) { + PHINode *IndVar = FindLoopCounter(L, BackedgeTakenCount, SE, DT); + if (IndVar) { + // Check preconditions for proper SCEVExpander operation. SCEV does not + // express SCEVExpander's dependencies, such as LoopSimplify. Instead any + // pass that uses the SCEVExpander must do it. This does not work well for + // loop passes because SCEVExpander makes assumptions about all loops, + // while LoopPassManager only forces the current loop to be simplified. + // + // FIXME: SCEV expansion has no way to bail out, so the caller must + // explicitly check any assumptions made by SCEV. Brittle. + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(BackedgeTakenCount); + if (!AR || AR->getLoop()->getLoopPreheader()) + (void)linearFunctionTestReplace(L, BackedgeTakenCount, IndVar, + Rewriter); + } + } + // Clear the rewriter cache, because values that are in the rewriter's cache + // can be deleted in the loop below, causing the AssertingVH in the cache to + // trigger. + Rewriter.clear(); + + // Now that we're done iterating through lists, clean up any instructions + // which are now dead. + while (!DeadInsts.empty()) + if (Instruction *Inst = + dyn_cast_or_null<Instruction>(DeadInsts.pop_back_val())) + RecursivelyDeleteTriviallyDeadInstructions(Inst, TLI); + + // The Rewriter may not be used from this point on. + + // Loop-invariant instructions in the preheader that aren't used in the + // loop may be sunk below the loop to reduce register pressure. + sinkUnusedInvariants(L); + + // rewriteFirstIterationLoopExitValues does not rely on the computation of + // trip count and therefore can further simplify exit values in addition to + // rewriteLoopExitValues. + rewriteFirstIterationLoopExitValues(L); + + // Clean up dead instructions. + Changed |= DeleteDeadPHIs(L->getHeader(), TLI); + + // Check a post-condition. + assert(L->isRecursivelyLCSSAForm(*DT, *LI) && + "Indvars did not preserve LCSSA!"); + + // Verify that LFTR, and any other change have not interfered with SCEV's + // ability to compute trip count. 
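+  // (The debug-only check below forgets the loop, recomputes the
+  // backedge-taken count from scratch, and compares it, modulo a width
+  // difference, against the original count.)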
+#ifndef NDEBUG + if (VerifyIndvars && !isa<SCEVCouldNotCompute>(BackedgeTakenCount)) { + SE->forgetLoop(L); + const SCEV *NewBECount = SE->getBackedgeTakenCount(L); + if (SE->getTypeSizeInBits(BackedgeTakenCount->getType()) < + SE->getTypeSizeInBits(NewBECount->getType())) + NewBECount = SE->getTruncateOrNoop(NewBECount, + BackedgeTakenCount->getType()); + else + BackedgeTakenCount = SE->getTruncateOrNoop(BackedgeTakenCount, + NewBECount->getType()); + assert(BackedgeTakenCount == NewBECount && "indvars must preserve SCEV"); + } +#endif + + return Changed; +} + +PreservedAnalyses IndVarSimplifyPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + Function *F = L.getHeader()->getParent(); + const DataLayout &DL = F->getParent()->getDataLayout(); + + IndVarSimplify IVS(&AR.LI, &AR.SE, &AR.DT, DL, &AR.TLI, &AR.TTI); + if (!IVS.run(&L)) + return PreservedAnalyses::all(); + + auto PA = getLoopPassPreservedAnalyses(); + PA.preserveSet<CFGAnalyses>(); + return PA; +} + +namespace { + +struct IndVarSimplifyLegacyPass : public LoopPass { + static char ID; // Pass identification, replacement for typeid + + IndVarSimplifyLegacyPass() : LoopPass(ID) { + initializeIndVarSimplifyLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + + auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); + auto *TLI = TLIP ? &TLIP->getTLI() : nullptr; + auto *TTIP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>(); + auto *TTI = TTIP ? &TTIP->getTTI(*L->getHeader()->getParent()) : nullptr; + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + + IndVarSimplify IVS(LI, SE, DT, DL, TLI, TTI); + return IVS.run(L); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + getLoopAnalysisUsage(AU); + } +}; + +} // end anonymous namespace + +char IndVarSimplifyLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(IndVarSimplifyLegacyPass, "indvars", + "Induction Variable Simplification", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_END(IndVarSimplifyLegacyPass, "indvars", + "Induction Variable Simplification", false, false) + +Pass *llvm::createIndVarSimplifyPass() { + return new IndVarSimplifyLegacyPass(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp b/contrib/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp new file mode 100644 index 000000000000..cf98088111be --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/InductiveRangeCheckElimination.cpp @@ -0,0 +1,1856 @@ +//===- InductiveRangeCheckElimination.cpp - -------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// The InductiveRangeCheckElimination pass splits a loop's iteration space into +// three disjoint ranges. It does that in a way such that the loop running in +// the middle loop provably does not need range checks. 
As an example, it will
+// convert
+//
+// len = < known positive >
+// for (i = 0; i < n; i++) {
+// if (0 <= i && i < len) {
+// do_something();
+// } else {
+// throw_out_of_bounds();
+// }
+// }
+//
+// to
+//
+// len = < known positive >
+// limit = smin(n, len)
+// // no first segment
+// for (i = 0; i < limit; i++) {
+// if (0 <= i && i < len) { // this check is fully redundant
+// do_something();
+// } else {
+// throw_out_of_bounds();
+// }
+// }
+// for (i = limit; i < n; i++) {
+// if (0 <= i && i < len) {
+// do_something();
+// } else {
+// throw_out_of_bounds();
+// }
+// }
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/BranchProbabilityInfo.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/BranchProbability.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/LoopSimplify.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/ValueMapper.h"
+#include <algorithm>
+#include <cassert>
+#include <iterator>
+#include <limits>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+static cl::opt<unsigned> LoopSizeCutoff("irce-loop-size-cutoff", cl::Hidden,
+ cl::init(64));
+
+static cl::opt<bool> PrintChangedLoops("irce-print-changed-loops", cl::Hidden,
+ cl::init(false));
+
+static cl::opt<bool> PrintRangeChecks("irce-print-range-checks", cl::Hidden,
+ cl::init(false));
+
+static cl::opt<int> MaxExitProbReciprocal("irce-max-exit-prob-reciprocal",
+ cl::Hidden, cl::init(10));
+
+static cl::opt<bool> SkipProfitabilityChecks("irce-skip-profitability-checks",
+ cl::Hidden, cl::init(false));
+
+static cl::opt<bool> AllowUnsignedLatchCondition("irce-allow-unsigned-latch",
+ cl::Hidden, cl::init(true));
+
+static const char *ClonedLoopTag = "irce.loop.clone";
+
+#define DEBUG_TYPE "irce"
+
+namespace {
+
+/// An inductive range check is a conditional branch in a loop with
+///
+/// 1. a very cold successor (i.e. the branch jumps to that successor very
+/// rarely)
+///
+/// and
+///
+/// 2. a condition that is provably true for some contiguous range of values
+/// taken by the containing loop's induction variable.
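+///
+/// For example (an illustrative sketch in the spirit of the file header, with
+/// hypothetical names; not taken from any particular workload), a bounds
+/// check such as
+///
+/// if (0 <= i && i < len) { // len loop-invariant, known non-negative
+/// do_something();
+/// } else {
+/// throw_out_of_bounds(); // very rarely taken
+/// }
+///
+/// inside a loop over `i' satisfies both requirements: the else-successor is
+/// cold, and the condition holds for every `i' in the contiguous range
+/// [0, len).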
+///
+class InductiveRangeCheck {
+ // Classifies a range check
+ enum RangeCheckKind : unsigned {
+ // Range check of the form "0 <= I".
+ RANGE_CHECK_LOWER = 1,
+
+ // Range check of the form "I < L" where L is known positive.
+ RANGE_CHECK_UPPER = 2,
+
+ // The logical and of the RANGE_CHECK_LOWER and RANGE_CHECK_UPPER
+ // conditions.
+ RANGE_CHECK_BOTH = RANGE_CHECK_LOWER | RANGE_CHECK_UPPER,
+
+ // Unrecognized range check condition.
+ RANGE_CHECK_UNKNOWN = (unsigned)-1
+ };
+
+ static StringRef rangeCheckKindToStr(RangeCheckKind);
+
+ const SCEV *Begin = nullptr;
+ const SCEV *Step = nullptr;
+ const SCEV *End = nullptr;
+ Use *CheckUse = nullptr;
+ RangeCheckKind Kind = RANGE_CHECK_UNKNOWN;
+ bool IsSigned = true;
+
+ static RangeCheckKind parseRangeCheckICmp(Loop *L, ICmpInst *ICI,
+ ScalarEvolution &SE, Value *&Index,
+ Value *&Length, bool &IsSigned);
+
+ static void
+ extractRangeChecksFromCond(Loop *L, ScalarEvolution &SE, Use &ConditionUse,
+ SmallVectorImpl<InductiveRangeCheck> &Checks,
+ SmallPtrSetImpl<Value *> &Visited);
+
+public:
+ const SCEV *getBegin() const { return Begin; }
+ const SCEV *getStep() const { return Step; }
+ const SCEV *getEnd() const { return End; }
+ bool isSigned() const { return IsSigned; }
+
+ void print(raw_ostream &OS) const {
+ OS << "InductiveRangeCheck:\n";
+ OS << " Kind: " << rangeCheckKindToStr(Kind) << "\n";
+ OS << " Begin: ";
+ Begin->print(OS);
+ OS << " Step: ";
+ Step->print(OS);
+ OS << " End: ";
+ if (End)
+ End->print(OS);
+ else
+ OS << "(null)";
+ OS << "\n CheckUse: ";
+ getCheckUse()->getUser()->print(OS);
+ OS << " Operand: " << getCheckUse()->getOperandNo() << "\n";
+ }
+
+ LLVM_DUMP_METHOD
+ void dump() {
+ print(dbgs());
+ }
+
+ Use *getCheckUse() const { return CheckUse; }
+
+ /// Represents a signed integer range [Range.getBegin(), Range.getEnd()). If
+ /// R.getEnd() sle R.getBegin(), then R denotes the empty range.
+ class Range {
+ const SCEV *Begin;
+ const SCEV *End;
+
+ public:
+ Range(const SCEV *Begin, const SCEV *End) : Begin(Begin), End(End) {
+ assert(Begin->getType() == End->getType() && "ill-typed range!");
+ }
+
+ Type *getType() const { return Begin->getType(); }
+ const SCEV *getBegin() const { return Begin; }
+ const SCEV *getEnd() const { return End; }
+ bool isEmpty(ScalarEvolution &SE, bool IsSigned) const {
+ if (Begin == End)
+ return true;
+ if (IsSigned)
+ return SE.isKnownPredicate(ICmpInst::ICMP_SGE, Begin, End);
+ else
+ return SE.isKnownPredicate(ICmpInst::ICMP_UGE, Begin, End);
+ }
+ };
+
+ /// This is the value the condition of the branch needs to evaluate to for the
+ /// branch to take the hot successor (see (1) above).
+ bool getPassingDirection() { return true; }
+
+ /// Computes a range for the induction variable (IndVar) in which the range
+ /// check is redundant and can be constant-folded away. The induction
+ /// variable is not required to be the canonical {0,+,1} induction variable.
+ Optional<Range> computeSafeIterationSpace(ScalarEvolution &SE,
+ const SCEVAddRecExpr *IndVar,
+ bool IsLatchSigned) const;
+
+ /// Parse out a set of inductive range checks from \p BI and append them to \p
+ /// Checks.
+ ///
+ /// NB! There may be conditions feeding into \p BI that aren't inductive range
+ /// checks, and hence don't end up in \p Checks.
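+ ///
+ /// For example (illustrative, hypothetical value names): given
+ /// `br i1 %c, label %hot, label %cold' where `%c = and i1 %lower, %upper',
+ /// both conjuncts are visited and only those recognized as inductive range
+ /// checks are appended to \p Checks.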
+ static void + extractRangeChecksFromBranch(BranchInst *BI, Loop *L, ScalarEvolution &SE, + BranchProbabilityInfo &BPI, + SmallVectorImpl<InductiveRangeCheck> &Checks); +}; + +class InductiveRangeCheckElimination : public LoopPass { +public: + static char ID; + + InductiveRangeCheckElimination() : LoopPass(ID) { + initializeInductiveRangeCheckEliminationPass( + *PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<BranchProbabilityInfoWrapperPass>(); + getLoopAnalysisUsage(AU); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override; +}; + +} // end anonymous namespace + +char InductiveRangeCheckElimination::ID = 0; + +INITIALIZE_PASS_BEGIN(InductiveRangeCheckElimination, "irce", + "Inductive range check elimination", false, false) +INITIALIZE_PASS_DEPENDENCY(BranchProbabilityInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_END(InductiveRangeCheckElimination, "irce", + "Inductive range check elimination", false, false) + +StringRef InductiveRangeCheck::rangeCheckKindToStr( + InductiveRangeCheck::RangeCheckKind RCK) { + switch (RCK) { + case InductiveRangeCheck::RANGE_CHECK_UNKNOWN: + return "RANGE_CHECK_UNKNOWN"; + + case InductiveRangeCheck::RANGE_CHECK_UPPER: + return "RANGE_CHECK_UPPER"; + + case InductiveRangeCheck::RANGE_CHECK_LOWER: + return "RANGE_CHECK_LOWER"; + + case InductiveRangeCheck::RANGE_CHECK_BOTH: + return "RANGE_CHECK_BOTH"; + } + + llvm_unreachable("unknown range check type!"); +} + +/// Parse a single ICmp instruction, `ICI`, into a range check. If `ICI` cannot +/// be interpreted as a range check, return `RANGE_CHECK_UNKNOWN` and set +/// `Index` and `Length` to `nullptr`. Otherwise set `Index` to the value being +/// range checked, and set `Length` to the upper limit `Index` is being range +/// checked with if (and only if) the range check type is stronger or equal to +/// RANGE_CHECK_UPPER. 
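+///
+/// A few illustrative parses (hypothetical values; `%len' is assumed to be
+/// loop-invariant and known non-negative):
+///
+/// icmp sge i32 %i, 0 --> RANGE_CHECK_LOWER, Index = %i
+/// icmp slt i32 %i, %len --> RANGE_CHECK_UPPER, Index = %i, Length = %len
+/// icmp ult i32 %i, %len --> RANGE_CHECK_BOTH, Index = %i, Length = %len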
+InductiveRangeCheck::RangeCheckKind +InductiveRangeCheck::parseRangeCheckICmp(Loop *L, ICmpInst *ICI, + ScalarEvolution &SE, Value *&Index, + Value *&Length, bool &IsSigned) { + auto IsNonNegativeAndNotLoopVarying = [&SE, L](Value *V) { + const SCEV *S = SE.getSCEV(V); + if (isa<SCEVCouldNotCompute>(S)) + return false; + + return SE.getLoopDisposition(S, L) == ScalarEvolution::LoopInvariant && + SE.isKnownNonNegative(S); + }; + + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *LHS = ICI->getOperand(0); + Value *RHS = ICI->getOperand(1); + + switch (Pred) { + default: + return RANGE_CHECK_UNKNOWN; + + case ICmpInst::ICMP_SLE: + std::swap(LHS, RHS); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_SGE: + IsSigned = true; + if (match(RHS, m_ConstantInt<0>())) { + Index = LHS; + return RANGE_CHECK_LOWER; + } + return RANGE_CHECK_UNKNOWN; + + case ICmpInst::ICMP_SLT: + std::swap(LHS, RHS); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_SGT: + IsSigned = true; + if (match(RHS, m_ConstantInt<-1>())) { + Index = LHS; + return RANGE_CHECK_LOWER; + } + + if (IsNonNegativeAndNotLoopVarying(LHS)) { + Index = RHS; + Length = LHS; + return RANGE_CHECK_UPPER; + } + return RANGE_CHECK_UNKNOWN; + + case ICmpInst::ICMP_ULT: + std::swap(LHS, RHS); + LLVM_FALLTHROUGH; + case ICmpInst::ICMP_UGT: + IsSigned = false; + if (IsNonNegativeAndNotLoopVarying(LHS)) { + Index = RHS; + Length = LHS; + return RANGE_CHECK_BOTH; + } + return RANGE_CHECK_UNKNOWN; + } + + llvm_unreachable("default clause returns!"); +} + +void InductiveRangeCheck::extractRangeChecksFromCond( + Loop *L, ScalarEvolution &SE, Use &ConditionUse, + SmallVectorImpl<InductiveRangeCheck> &Checks, + SmallPtrSetImpl<Value *> &Visited) { + Value *Condition = ConditionUse.get(); + if (!Visited.insert(Condition).second) + return; + + // TODO: Do the same for OR, XOR, NOT etc? + if (match(Condition, m_And(m_Value(), m_Value()))) { + extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(0), + Checks, Visited); + extractRangeChecksFromCond(L, SE, cast<User>(Condition)->getOperandUse(1), + Checks, Visited); + return; + } + + ICmpInst *ICI = dyn_cast<ICmpInst>(Condition); + if (!ICI) + return; + + Value *Length = nullptr, *Index; + bool IsSigned; + auto RCKind = parseRangeCheckICmp(L, ICI, SE, Index, Length, IsSigned); + if (RCKind == InductiveRangeCheck::RANGE_CHECK_UNKNOWN) + return; + + const auto *IndexAddRec = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Index)); + bool IsAffineIndex = + IndexAddRec && (IndexAddRec->getLoop() == L) && IndexAddRec->isAffine(); + + if (!IsAffineIndex) + return; + + InductiveRangeCheck IRC; + IRC.End = Length ? SE.getSCEV(Length) : nullptr; + IRC.Begin = IndexAddRec->getStart(); + IRC.Step = IndexAddRec->getStepRecurrence(SE); + IRC.CheckUse = &ConditionUse; + IRC.Kind = RCKind; + IRC.IsSigned = IsSigned; + Checks.push_back(IRC); +} + +void InductiveRangeCheck::extractRangeChecksFromBranch( + BranchInst *BI, Loop *L, ScalarEvolution &SE, BranchProbabilityInfo &BPI, + SmallVectorImpl<InductiveRangeCheck> &Checks) { + if (BI->isUnconditional() || BI->getParent() == L->getLoopLatch()) + return; + + BranchProbability LikelyTaken(15, 16); + + if (!SkipProfitabilityChecks && + BPI.getEdgeProbability(BI->getParent(), (unsigned)0) < LikelyTaken) + return; + + SmallPtrSet<Value *, 8> Visited; + InductiveRangeCheck::extractRangeChecksFromCond(L, SE, BI->getOperandUse(0), + Checks, Visited); +} + +// Add metadata to the loop L to disable loop optimizations. Callers need to +// confirm that optimizing loop L is not beneficial. 
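+//
+// The metadata built below is equivalent to the following textual IR (an
+// illustrative sketch; the actual node numbering is arbitrary):
+//
+// !0 = distinct !{!0, !1, !2, !3, !4}
+// !1 = !{!"llvm.loop.unroll.disable"}
+// !2 = !{!"llvm.loop.vectorize.enable", i1 false}
+// !3 = !{!"llvm.loop.licm_versioning.disable"}
+// !4 = !{!"llvm.loop.distribute.enable", i1 false}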
+static void DisableAllLoopOptsOnLoop(Loop &L) {
+ // We do not care about any existing loopID related metadata for L, since we
+ // are setting all loop metadata to false.
+ LLVMContext &Context = L.getHeader()->getContext();
+ // Reserve first location for self reference to the LoopID metadata node.
+ MDNode *Dummy = MDNode::get(Context, {});
+ MDNode *DisableUnroll = MDNode::get(
+ Context, {MDString::get(Context, "llvm.loop.unroll.disable")});
+ Metadata *FalseVal =
+ ConstantAsMetadata::get(ConstantInt::get(Type::getInt1Ty(Context), 0));
+ MDNode *DisableVectorize = MDNode::get(
+ Context,
+ {MDString::get(Context, "llvm.loop.vectorize.enable"), FalseVal});
+ MDNode *DisableLICMVersioning = MDNode::get(
+ Context, {MDString::get(Context, "llvm.loop.licm_versioning.disable")});
+ MDNode *DisableDistribution = MDNode::get(
+ Context,
+ {MDString::get(Context, "llvm.loop.distribute.enable"), FalseVal});
+ MDNode *NewLoopID =
+ MDNode::get(Context, {Dummy, DisableUnroll, DisableVectorize,
+ DisableLICMVersioning, DisableDistribution});
+ // Set operand 0 to refer to the loop id itself.
+ NewLoopID->replaceOperandWith(0, NewLoopID);
+ L.setLoopID(NewLoopID);
+}
+
+namespace {
+
+// Keeps track of the structure of a loop. This is similar to llvm::Loop,
+// except that it is more lightweight and can track the state of a loop through
+// changing and potentially invalid IR. This structure also formalizes the
+// kinds of loops we can deal with -- ones that have a single latch that is also
+// an exiting block *and* have a canonical induction variable.
+struct LoopStructure {
+ const char *Tag = "";
+
+ BasicBlock *Header = nullptr;
+ BasicBlock *Latch = nullptr;
+
+ // `Latch's terminator instruction is `LatchBr', and its `LatchBrExitIdx'th
+ // successor is `LatchExit', the exit block of the loop.
+ BranchInst *LatchBr = nullptr;
+ BasicBlock *LatchExit = nullptr;
+ unsigned LatchBrExitIdx = std::numeric_limits<unsigned>::max();
+
+ // The loop represented by this instance of LoopStructure is semantically
+ // equivalent to:
+ //
+ // intN_ty inc = IndVarIncreasing ? 1 : -1;
+ // pred_ty predicate = IndVarIncreasing ? ICMP_SLT : ICMP_SGT;
+ //
+ // for (intN_ty iv = IndVarStart; predicate(iv, LoopExitAt); iv = IndVarBase)
+ // ... body ...
+
+ Value *IndVarBase = nullptr;
+ Value *IndVarStart = nullptr;
+ Value *IndVarStep = nullptr;
+ Value *LoopExitAt = nullptr;
+ bool IndVarIncreasing = false;
+ bool IsSignedPredicate = true;
+
+ LoopStructure() = default;
+
+ template <typename M> LoopStructure map(M Map) const {
+ LoopStructure Result;
+ Result.Tag = Tag;
+ Result.Header = cast<BasicBlock>(Map(Header));
+ Result.Latch = cast<BasicBlock>(Map(Latch));
+ Result.LatchBr = cast<BranchInst>(Map(LatchBr));
+ Result.LatchExit = cast<BasicBlock>(Map(LatchExit));
+ Result.LatchBrExitIdx = LatchBrExitIdx;
+ Result.IndVarBase = Map(IndVarBase);
+ Result.IndVarStart = Map(IndVarStart);
+ Result.IndVarStep = Map(IndVarStep);
+ Result.LoopExitAt = Map(LoopExitAt);
+ Result.IndVarIncreasing = IndVarIncreasing;
+ Result.IsSignedPredicate = IsSignedPredicate;
+ return Result;
+ }
+
+ static Optional<LoopStructure> parseLoopStructure(ScalarEvolution &,
+ BranchProbabilityInfo &BPI,
+ Loop &,
+ const char *&);
+};
+
+/// This class is used to constrain loops to run within a given iteration space.
+/// The algorithm this class implements is given a Loop and a range [Begin,
+/// End).
The algorithm then tries to break out a "main loop" out of the loop +/// it is given in a way that the "main loop" runs with the induction variable +/// in a subset of [Begin, End). The algorithm emits appropriate pre and post +/// loops to run any remaining iterations. The pre loop runs any iterations in +/// which the induction variable is < Begin, and the post loop runs any +/// iterations in which the induction variable is >= End. +class LoopConstrainer { + // The representation of a clone of the original loop we started out with. + struct ClonedLoop { + // The cloned blocks + std::vector<BasicBlock *> Blocks; + + // `Map` maps values in the clonee into values in the cloned version + ValueToValueMapTy Map; + + // An instance of `LoopStructure` for the cloned loop + LoopStructure Structure; + }; + + // Result of rewriting the range of a loop. See changeIterationSpaceEnd for + // more details on what these fields mean. + struct RewrittenRangeInfo { + BasicBlock *PseudoExit = nullptr; + BasicBlock *ExitSelector = nullptr; + std::vector<PHINode *> PHIValuesAtPseudoExit; + PHINode *IndVarEnd = nullptr; + + RewrittenRangeInfo() = default; + }; + + // Calculated subranges we restrict the iteration space of the main loop to. + // See the implementation of `calculateSubRanges' for more details on how + // these fields are computed. `LowLimit` is None if there is no restriction + // on low end of the restricted iteration space of the main loop. `HighLimit` + // is None if there is no restriction on high end of the restricted iteration + // space of the main loop. + + struct SubRanges { + Optional<const SCEV *> LowLimit; + Optional<const SCEV *> HighLimit; + }; + + // A utility function that does a `replaceUsesOfWith' on the incoming block + // set of a `PHINode' -- replaces instances of `Block' in the `PHINode's + // incoming block list with `ReplaceBy'. + static void replacePHIBlock(PHINode *PN, BasicBlock *Block, + BasicBlock *ReplaceBy); + + // Compute a safe set of limits for the main loop to run in -- effectively the + // intersection of `Range' and the iteration space of the original loop. + // Return None if unable to compute the set of subranges. + Optional<SubRanges> calculateSubRanges(bool IsSignedPredicate) const; + + // Clone `OriginalLoop' and return the result in CLResult. The IR after + // running `cloneLoop' is well formed except for the PHI nodes in CLResult -- + // the PHI nodes say that there is an incoming edge from `OriginalPreheader` + // but there is no such edge. + void cloneLoop(ClonedLoop &CLResult, const char *Tag) const; + + // Create the appropriate loop structure needed to describe a cloned copy of + // `Original`. The clone is described by `VM`. + Loop *createClonedLoopStructure(Loop *Original, Loop *Parent, + ValueToValueMapTy &VM); + + // Rewrite the iteration space of the loop denoted by (LS, Preheader). The + // iteration space of the rewritten loop ends at ExitLoopAt. The start of the + // iteration space is not changed. `ExitLoopAt' is assumed to be slt + // `OriginalHeaderCount'. + // + // If there are iterations left to execute, control is made to jump to + // `ContinuationBlock', otherwise they take the normal loop exit. The + // returned `RewrittenRangeInfo' object is populated as follows: + // + // .PseudoExit is a basic block that unconditionally branches to + // `ContinuationBlock'. + // + // .ExitSelector is a basic block that decides, on exit from the loop, + // whether to branch to the "true" exit or to `PseudoExit'. 
+ // + // .PHIValuesAtPseudoExit are PHINodes in `PseudoExit' that compute the value + // for each PHINode in the loop header on taking the pseudo exit. + // + // After changeIterationSpaceEnd, `Preheader' is no longer a legitimate + // preheader because it is made to branch to the loop header only + // conditionally. + RewrittenRangeInfo + changeIterationSpaceEnd(const LoopStructure &LS, BasicBlock *Preheader, + Value *ExitLoopAt, + BasicBlock *ContinuationBlock) const; + + // The loop denoted by `LS' has `OldPreheader' as its preheader. This + // function creates a new preheader for `LS' and returns it. + BasicBlock *createPreheader(const LoopStructure &LS, BasicBlock *OldPreheader, + const char *Tag) const; + + // `ContinuationBlockAndPreheader' was the continuation block for some call to + // `changeIterationSpaceEnd' and is the preheader to the loop denoted by `LS'. + // This function rewrites the PHI nodes in `LS.Header' to start with the + // correct value. + void rewriteIncomingValuesForPHIs( + LoopStructure &LS, BasicBlock *ContinuationBlockAndPreheader, + const LoopConstrainer::RewrittenRangeInfo &RRI) const; + + // Even though we do not preserve any passes at this time, we at least need to + // keep the parent loop structure consistent. The `LPPassManager' seems to + // verify this after running a loop pass. This function adds the list of + // blocks denoted by BBs to this loops parent loop if required. + void addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs); + + // Some global state. + Function &F; + LLVMContext &Ctx; + ScalarEvolution &SE; + DominatorTree &DT; + LPPassManager &LPM; + LoopInfo &LI; + + // Information about the original loop we started out with. + Loop &OriginalLoop; + + const SCEV *LatchTakenCount = nullptr; + BasicBlock *OriginalPreheader = nullptr; + + // The preheader of the main loop. This may or may not be different from + // `OriginalPreheader'. + BasicBlock *MainLoopPreheader = nullptr; + + // The range we need to run the main loop in. + InductiveRangeCheck::Range Range; + + // The structure of the main loop (see comment at the beginning of this class + // for a definition) + LoopStructure MainLoopStructure; + +public: + LoopConstrainer(Loop &L, LoopInfo &LI, LPPassManager &LPM, + const LoopStructure &LS, ScalarEvolution &SE, + DominatorTree &DT, InductiveRangeCheck::Range R) + : F(*L.getHeader()->getParent()), Ctx(L.getHeader()->getContext()), + SE(SE), DT(DT), LPM(LPM), LI(LI), OriginalLoop(L), Range(R), + MainLoopStructure(LS) {} + + // Entry point for the algorithm. Returns true on success. + bool run(); +}; + +} // end anonymous namespace + +void LoopConstrainer::replacePHIBlock(PHINode *PN, BasicBlock *Block, + BasicBlock *ReplaceBy) { + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) + if (PN->getIncomingBlock(i) == Block) + PN->setIncomingBlock(i, ReplaceBy); +} + +static bool CanBeMax(ScalarEvolution &SE, const SCEV *S, bool Signed) { + APInt Max = Signed ? + APInt::getSignedMaxValue(cast<IntegerType>(S->getType())->getBitWidth()) : + APInt::getMaxValue(cast<IntegerType>(S->getType())->getBitWidth()); + return SE.getSignedRange(S).contains(Max) && + SE.getUnsignedRange(S).contains(Max); +} + +static bool SumCanReachMax(ScalarEvolution &SE, const SCEV *S1, const SCEV *S2, + bool Signed) { + // S1 < INT_MAX - S2 ===> S1 + S2 < INT_MAX. + assert(SE.isKnownNonNegative(S2) && + "We expected the 2nd arg to be non-negative!"); + const SCEV *Max = SE.getConstant( + Signed ? 
APInt::getSignedMaxValue( + cast<IntegerType>(S1->getType())->getBitWidth()) + : APInt::getMaxValue( + cast<IntegerType>(S1->getType())->getBitWidth())); + const SCEV *CapForS1 = SE.getMinusSCEV(Max, S2); + return !SE.isKnownPredicate(Signed ? ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT, + S1, CapForS1); +} + +static bool CanBeMin(ScalarEvolution &SE, const SCEV *S, bool Signed) { + APInt Min = Signed ? + APInt::getSignedMinValue(cast<IntegerType>(S->getType())->getBitWidth()) : + APInt::getMinValue(cast<IntegerType>(S->getType())->getBitWidth()); + return SE.getSignedRange(S).contains(Min) && + SE.getUnsignedRange(S).contains(Min); +} + +static bool SumCanReachMin(ScalarEvolution &SE, const SCEV *S1, const SCEV *S2, + bool Signed) { + // S1 > INT_MIN - S2 ===> S1 + S2 > INT_MIN. + assert(SE.isKnownNonPositive(S2) && + "We expected the 2nd arg to be non-positive!"); + const SCEV *Max = SE.getConstant( + Signed ? APInt::getSignedMinValue( + cast<IntegerType>(S1->getType())->getBitWidth()) + : APInt::getMinValue( + cast<IntegerType>(S1->getType())->getBitWidth())); + const SCEV *CapForS1 = SE.getMinusSCEV(Max, S2); + return !SE.isKnownPredicate(Signed ? ICmpInst::ICMP_SGT : ICmpInst::ICMP_UGT, + S1, CapForS1); +} + +Optional<LoopStructure> +LoopStructure::parseLoopStructure(ScalarEvolution &SE, + BranchProbabilityInfo &BPI, + Loop &L, const char *&FailureReason) { + if (!L.isLoopSimplifyForm()) { + FailureReason = "loop not in LoopSimplify form"; + return None; + } + + BasicBlock *Latch = L.getLoopLatch(); + assert(Latch && "Simplified loops only have one latch!"); + + if (Latch->getTerminator()->getMetadata(ClonedLoopTag)) { + FailureReason = "loop has already been cloned"; + return None; + } + + if (!L.isLoopExiting(Latch)) { + FailureReason = "no loop latch"; + return None; + } + + BasicBlock *Header = L.getHeader(); + BasicBlock *Preheader = L.getLoopPreheader(); + if (!Preheader) { + FailureReason = "no preheader"; + return None; + } + + BranchInst *LatchBr = dyn_cast<BranchInst>(Latch->getTerminator()); + if (!LatchBr || LatchBr->isUnconditional()) { + FailureReason = "latch terminator not conditional branch"; + return None; + } + + unsigned LatchBrExitIdx = LatchBr->getSuccessor(0) == Header ? 1 : 0; + + BranchProbability ExitProbability = + BPI.getEdgeProbability(LatchBr->getParent(), LatchBrExitIdx); + + if (!SkipProfitabilityChecks && + ExitProbability > BranchProbability(1, MaxExitProbReciprocal)) { + FailureReason = "short running loop, not profitable"; + return None; + } + + ICmpInst *ICI = dyn_cast<ICmpInst>(LatchBr->getCondition()); + if (!ICI || !isa<IntegerType>(ICI->getOperand(0)->getType())) { + FailureReason = "latch terminator branch not conditional on integral icmp"; + return None; + } + + const SCEV *LatchCount = SE.getExitCount(&L, Latch); + if (isa<SCEVCouldNotCompute>(LatchCount)) { + FailureReason = "could not compute latch count"; + return None; + } + + ICmpInst::Predicate Pred = ICI->getPredicate(); + Value *LeftValue = ICI->getOperand(0); + const SCEV *LeftSCEV = SE.getSCEV(LeftValue); + IntegerType *IndVarTy = cast<IntegerType>(LeftValue->getType()); + + Value *RightValue = ICI->getOperand(1); + const SCEV *RightSCEV = SE.getSCEV(RightValue); + + // We canonicalize `ICI` such that `LeftSCEV` is an add recurrence. 
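+ // For instance (illustrative, hypothetical values): a latch condition
+ // `icmp sgt i32 %len, %iv.next' becomes `icmp slt i32 %iv.next, %len' by
+ // swapping the operands and the predicate, so that the add recurrence ends
+ // up on the left-hand side.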
+ if (!isa<SCEVAddRecExpr>(LeftSCEV)) { + if (isa<SCEVAddRecExpr>(RightSCEV)) { + std::swap(LeftSCEV, RightSCEV); + std::swap(LeftValue, RightValue); + Pred = ICmpInst::getSwappedPredicate(Pred); + } else { + FailureReason = "no add recurrences in the icmp"; + return None; + } + } + + auto HasNoSignedWrap = [&](const SCEVAddRecExpr *AR) { + if (AR->getNoWrapFlags(SCEV::FlagNSW)) + return true; + + IntegerType *Ty = cast<IntegerType>(AR->getType()); + IntegerType *WideTy = + IntegerType::get(Ty->getContext(), Ty->getBitWidth() * 2); + + const SCEVAddRecExpr *ExtendAfterOp = + dyn_cast<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy)); + if (ExtendAfterOp) { + const SCEV *ExtendedStart = SE.getSignExtendExpr(AR->getStart(), WideTy); + const SCEV *ExtendedStep = + SE.getSignExtendExpr(AR->getStepRecurrence(SE), WideTy); + + bool NoSignedWrap = ExtendAfterOp->getStart() == ExtendedStart && + ExtendAfterOp->getStepRecurrence(SE) == ExtendedStep; + + if (NoSignedWrap) + return true; + } + + // We may have proved this when computing the sign extension above. + return AR->getNoWrapFlags(SCEV::FlagNSW) != SCEV::FlagAnyWrap; + }; + + // Here we check whether the suggested AddRec is an induction variable that + // can be handled (i.e. with known constant step), and if yes, calculate its + // step and identify whether it is increasing or decreasing. + auto IsInductionVar = [&](const SCEVAddRecExpr *AR, bool &IsIncreasing, + ConstantInt *&StepCI) { + if (!AR->isAffine()) + return false; + + // Currently we only work with induction variables that have been proved to + // not wrap. This restriction can potentially be lifted in the future. + + if (!HasNoSignedWrap(AR)) + return false; + + if (const SCEVConstant *StepExpr = + dyn_cast<SCEVConstant>(AR->getStepRecurrence(SE))) { + StepCI = StepExpr->getValue(); + assert(!StepCI->isZero() && "Zero step?"); + IsIncreasing = !StepCI->isNegative(); + return true; + } + + return false; + }; + + // `ICI` is interpreted as taking the backedge if the *next* value of the + // induction variable satisfies some constraint. + + const SCEVAddRecExpr *IndVarBase = cast<SCEVAddRecExpr>(LeftSCEV); + bool IsIncreasing = false; + bool IsSignedPredicate = true; + ConstantInt *StepCI; + if (!IsInductionVar(IndVarBase, IsIncreasing, StepCI)) { + FailureReason = "LHS in icmp not induction variable"; + return None; + } + + const SCEV *StartNext = IndVarBase->getStart(); + const SCEV *Addend = SE.getNegativeSCEV(IndVarBase->getStepRecurrence(SE)); + const SCEV *IndVarStart = SE.getAddExpr(StartNext, Addend); + const SCEV *Step = SE.getSCEV(StepCI); + + ConstantInt *One = ConstantInt::get(IndVarTy, 1); + if (IsIncreasing) { + bool DecreasedRightValueByOne = false; + if (StepCI->isOne()) { + // Try to turn eq/ne predicates to those we can work with. + if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) + // while (++i != len) { while (++i < len) { + // ... ---> ... + // } } + // If both parts are known non-negative, it is profitable to use + // unsigned comparison in increasing loop. This allows us to make the + // comparison check against "RightSCEV + 1" more optimistic. + if (SE.isKnownNonNegative(IndVarStart) && + SE.isKnownNonNegative(RightSCEV)) + Pred = ICmpInst::ICMP_ULT; + else + Pred = ICmpInst::ICMP_SLT; + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && + !CanBeMin(SE, RightSCEV, /* IsSignedPredicate */ true)) { + // while (true) { while (true) { + // if (++i == len) ---> if (++i > len - 1) + // break; break; + // ... ... 
+ // } } + // TODO: Insert ICMP_UGT if both are non-negative? + Pred = ICmpInst::ICMP_SGT; + RightSCEV = SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())); + DecreasedRightValueByOne = true; + } + } + + bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT); + bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT); + bool FoundExpectedPred = + (LTPred && LatchBrExitIdx == 1) || (GTPred && LatchBrExitIdx == 0); + + if (!FoundExpectedPred) { + FailureReason = "expected icmp slt semantically, found something else"; + return None; + } + + IsSignedPredicate = + Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT; + + if (!IsSignedPredicate && !AllowUnsignedLatchCondition) { + FailureReason = "unsigned latch conditions are explicitly prohibited"; + return None; + } + + // The predicate that we need to check that the induction variable lies + // within bounds. + ICmpInst::Predicate BoundPred = + IsSignedPredicate ? CmpInst::ICMP_SLT : CmpInst::ICMP_ULT; + + if (LatchBrExitIdx == 0) { + const SCEV *StepMinusOne = SE.getMinusSCEV(Step, + SE.getOne(Step->getType())); + if (SumCanReachMax(SE, RightSCEV, StepMinusOne, IsSignedPredicate)) { + // TODO: this restriction is easily removable -- we just have to + // remember that the icmp was an slt and not an sle. + FailureReason = "limit may overflow when coercing le to lt"; + return None; + } + + if (!SE.isLoopEntryGuardedByCond( + &L, BoundPred, IndVarStart, + SE.getAddExpr(RightSCEV, Step))) { + FailureReason = "Induction variable start not bounded by upper limit"; + return None; + } + + // We need to increase the right value unless we have already decreased + // it virtually when we replaced EQ with SGT. + if (!DecreasedRightValueByOne) { + IRBuilder<> B(Preheader->getTerminator()); + RightValue = B.CreateAdd(RightValue, One); + } + } else { + if (!SE.isLoopEntryGuardedByCond(&L, BoundPred, IndVarStart, RightSCEV)) { + FailureReason = "Induction variable start not bounded by upper limit"; + return None; + } + assert(!DecreasedRightValueByOne && + "Right value can be decreased only for LatchBrExitIdx == 0!"); + } + } else { + bool IncreasedRightValueByOne = false; + if (StepCI->isMinusOne()) { + // Try to turn eq/ne predicates to those we can work with. + if (Pred == ICmpInst::ICMP_NE && LatchBrExitIdx == 1) + // while (--i != len) { while (--i > len) { + // ... ---> ... + // } } + // We intentionally don't turn the predicate into UGT even if we know + // that both operands are non-negative, because it will only pessimize + // our check against "RightSCEV - 1". + Pred = ICmpInst::ICMP_SGT; + else if (Pred == ICmpInst::ICMP_EQ && LatchBrExitIdx == 0 && + !CanBeMax(SE, RightSCEV, /* IsSignedPredicate */ true)) { + // while (true) { while (true) { + // if (--i == len) ---> if (--i < len + 1) + // break; break; + // ... ... + // } } + // TODO: Insert ICMP_ULT if both are non-negative? 
+ Pred = ICmpInst::ICMP_SLT; + RightSCEV = SE.getAddExpr(RightSCEV, SE.getOne(RightSCEV->getType())); + IncreasedRightValueByOne = true; + } + } + + bool LTPred = (Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_ULT); + bool GTPred = (Pred == ICmpInst::ICMP_SGT || Pred == ICmpInst::ICMP_UGT); + + bool FoundExpectedPred = + (GTPred && LatchBrExitIdx == 1) || (LTPred && LatchBrExitIdx == 0); + + if (!FoundExpectedPred) { + FailureReason = "expected icmp sgt semantically, found something else"; + return None; + } + + IsSignedPredicate = + Pred == ICmpInst::ICMP_SLT || Pred == ICmpInst::ICMP_SGT; + + if (!IsSignedPredicate && !AllowUnsignedLatchCondition) { + FailureReason = "unsigned latch conditions are explicitly prohibited"; + return None; + } + + // The predicate that we need to check that the induction variable lies + // within bounds. + ICmpInst::Predicate BoundPred = + IsSignedPredicate ? CmpInst::ICMP_SGT : CmpInst::ICMP_UGT; + + if (LatchBrExitIdx == 0) { + const SCEV *StepPlusOne = SE.getAddExpr(Step, SE.getOne(Step->getType())); + if (SumCanReachMin(SE, RightSCEV, StepPlusOne, IsSignedPredicate)) { + // TODO: this restriction is easily removable -- we just have to + // remember that the icmp was an sgt and not an sge. + FailureReason = "limit may overflow when coercing ge to gt"; + return None; + } + + if (!SE.isLoopEntryGuardedByCond( + &L, BoundPred, IndVarStart, + SE.getMinusSCEV(RightSCEV, SE.getOne(RightSCEV->getType())))) { + FailureReason = "Induction variable start not bounded by lower limit"; + return None; + } + + // We need to decrease the right value unless we have already increased + // it virtually when we replaced EQ with SLT. + if (!IncreasedRightValueByOne) { + IRBuilder<> B(Preheader->getTerminator()); + RightValue = B.CreateSub(RightValue, One); + } + } else { + if (!SE.isLoopEntryGuardedByCond(&L, BoundPred, IndVarStart, RightSCEV)) { + FailureReason = "Induction variable start not bounded by lower limit"; + return None; + } + assert(!IncreasedRightValueByOne && + "Right value can be increased only for LatchBrExitIdx == 0!"); + } + } + BasicBlock *LatchExit = LatchBr->getSuccessor(LatchBrExitIdx); + + assert(SE.getLoopDisposition(LatchCount, &L) == + ScalarEvolution::LoopInvariant && + "loop variant exit count doesn't make sense!"); + + assert(!L.contains(LatchExit) && "expected an exit block!"); + const DataLayout &DL = Preheader->getModule()->getDataLayout(); + Value *IndVarStartV = + SCEVExpander(SE, DL, "irce") + .expandCodeFor(IndVarStart, IndVarTy, Preheader->getTerminator()); + IndVarStartV->setName("indvar.start"); + + LoopStructure Result; + + Result.Tag = "main"; + Result.Header = Header; + Result.Latch = Latch; + Result.LatchBr = LatchBr; + Result.LatchExit = LatchExit; + Result.LatchBrExitIdx = LatchBrExitIdx; + Result.IndVarStart = IndVarStartV; + Result.IndVarStep = StepCI; + Result.IndVarBase = LeftValue; + Result.IndVarIncreasing = IsIncreasing; + Result.LoopExitAt = RightValue; + Result.IsSignedPredicate = IsSignedPredicate; + + FailureReason = nullptr; + + return Result; +} + +Optional<LoopConstrainer::SubRanges> +LoopConstrainer::calculateSubRanges(bool IsSignedPredicate) const { + IntegerType *Ty = cast<IntegerType>(LatchTakenCount->getType()); + + if (Range.getType() != Ty) + return None; + + LoopConstrainer::SubRanges Result; + + // I think we can be more aggressive here and make this nuw / nsw if the + // addition that feeds into the icmp for the latch's terminating branch is nuw + // / nsw. 
In any case, a wrapping 2's complement addition is safe.
+ const SCEV *Start = SE.getSCEV(MainLoopStructure.IndVarStart);
+ const SCEV *End = SE.getSCEV(MainLoopStructure.LoopExitAt);
+
+ bool Increasing = MainLoopStructure.IndVarIncreasing;
+
+ // We compute `Smallest` and `Greatest` such that [Smallest, Greatest), or
+ // [Smallest, GreatestSeen] is the range of values the induction variable
+ // takes.
+
+ const SCEV *Smallest = nullptr, *Greatest = nullptr, *GreatestSeen = nullptr;
+
+ const SCEV *One = SE.getOne(Ty);
+ if (Increasing) {
+ Smallest = Start;
+ Greatest = End;
+ // No overflow, because the range [Smallest, GreatestSeen] is not empty.
+ GreatestSeen = SE.getMinusSCEV(End, One);
+ } else {
+ // These two computations may sign-overflow. Here is why that is okay:
+ //
+ // We know that the induction variable does not sign-overflow on any
+ // iteration except the last one, and it starts at `Start` and ends at
+ // `End`, decrementing by one every time.
+ //
+ // * if `Smallest` sign-overflows we know `End` is `INT_SMAX`. Since the
+ // induction variable is decreasing we know that the smallest value
+ // the loop body is actually executed with is `INT_SMIN` == `Smallest`.
+ //
+ // * if `Greatest` sign-overflows, we know it can only be `INT_SMIN`. In
+ // that case, `Clamp` will always return `Smallest` and
+ // [`Result.LowLimit`, `Result.HighLimit`) = [`Smallest`, `Smallest`)
+ // will be an empty range. Returning an empty range is always safe.
+
+ Smallest = SE.getAddExpr(End, One);
+ Greatest = SE.getAddExpr(Start, One);
+ GreatestSeen = Start;
+ }
+
+ auto Clamp = [this, Smallest, Greatest, IsSignedPredicate](const SCEV *S) {
+ return IsSignedPredicate
+ ? SE.getSMaxExpr(Smallest, SE.getSMinExpr(Greatest, S))
+ : SE.getUMaxExpr(Smallest, SE.getUMinExpr(Greatest, S));
+ };
+
+ // In some cases we can prove that we don't need a pre or post loop.
+ ICmpInst::Predicate PredLE =
+ IsSignedPredicate ? ICmpInst::ICMP_SLE : ICmpInst::ICMP_ULE;
+ ICmpInst::Predicate PredLT =
+ IsSignedPredicate ?
ICmpInst::ICMP_SLT : ICmpInst::ICMP_ULT; + + bool ProvablyNoPreloop = + SE.isKnownPredicate(PredLE, Range.getBegin(), Smallest); + if (!ProvablyNoPreloop) + Result.LowLimit = Clamp(Range.getBegin()); + + bool ProvablyNoPostLoop = + SE.isKnownPredicate(PredLT, GreatestSeen, Range.getEnd()); + if (!ProvablyNoPostLoop) + Result.HighLimit = Clamp(Range.getEnd()); + + return Result; +} + +void LoopConstrainer::cloneLoop(LoopConstrainer::ClonedLoop &Result, + const char *Tag) const { + for (BasicBlock *BB : OriginalLoop.getBlocks()) { + BasicBlock *Clone = CloneBasicBlock(BB, Result.Map, Twine(".") + Tag, &F); + Result.Blocks.push_back(Clone); + Result.Map[BB] = Clone; + } + + auto GetClonedValue = [&Result](Value *V) { + assert(V && "null values not in domain!"); + auto It = Result.Map.find(V); + if (It == Result.Map.end()) + return V; + return static_cast<Value *>(It->second); + }; + + auto *ClonedLatch = + cast<BasicBlock>(GetClonedValue(OriginalLoop.getLoopLatch())); + ClonedLatch->getTerminator()->setMetadata(ClonedLoopTag, + MDNode::get(Ctx, {})); + + Result.Structure = MainLoopStructure.map(GetClonedValue); + Result.Structure.Tag = Tag; + + for (unsigned i = 0, e = Result.Blocks.size(); i != e; ++i) { + BasicBlock *ClonedBB = Result.Blocks[i]; + BasicBlock *OriginalBB = OriginalLoop.getBlocks()[i]; + + assert(Result.Map[OriginalBB] == ClonedBB && "invariant!"); + + for (Instruction &I : *ClonedBB) + RemapInstruction(&I, Result.Map, + RF_NoModuleLevelChanges | RF_IgnoreMissingLocals); + + // Exit blocks will now have one more predecessor and their PHI nodes need + // to be edited to reflect that. No phi nodes need to be introduced because + // the loop is in LCSSA. + + for (auto *SBB : successors(OriginalBB)) { + if (OriginalLoop.contains(SBB)) + continue; // not an exit block + + for (PHINode &PN : SBB->phis()) { + Value *OldIncoming = PN.getIncomingValueForBlock(OriginalBB); + PN.addIncoming(GetClonedValue(OldIncoming), ClonedBB); + } + } + } +} + +LoopConstrainer::RewrittenRangeInfo LoopConstrainer::changeIterationSpaceEnd( + const LoopStructure &LS, BasicBlock *Preheader, Value *ExitSubloopAt, + BasicBlock *ContinuationBlock) const { + // We start with a loop with a single latch: + // + // +--------------------+ + // | | + // | preheader | + // | | + // +--------+-----------+ + // | ----------------\ + // | / | + // +--------v----v------+ | + // | | | + // | header | | + // | | | + // +--------------------+ | + // | + // ..... | + // | + // +--------------------+ | + // | | | + // | latch >----------/ + // | | + // +-------v------------+ + // | + // | + // | +--------------------+ + // | | | + // +---> original exit | + // | | + // +--------------------+ + // + // We change the control flow to look like + // + // + // +--------------------+ + // | | + // | preheader >-------------------------+ + // | | | + // +--------v-----------+ | + // | /-------------+ | + // | / | | + // +--------v--v--------+ | | + // | | | | + // | header | | +--------+ | + // | | | | | | + // +--------------------+ | | +-----v-----v-----------+ + // | | | | + // | | | .pseudo.exit | + // | | | | + // | | +-----------v-----------+ + // | | | + // ..... 
| | | + // | | +--------v-------------+ + // +--------------------+ | | | | + // | | | | | ContinuationBlock | + // | latch >------+ | | | + // | | | +----------------------+ + // +---------v----------+ | + // | | + // | | + // | +---------------^-----+ + // | | | + // +-----> .exit.selector | + // | | + // +----------v----------+ + // | + // +--------------------+ | + // | | | + // | original exit <----+ + // | | + // +--------------------+ + + RewrittenRangeInfo RRI; + + BasicBlock *BBInsertLocation = LS.Latch->getNextNode(); + RRI.ExitSelector = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".exit.selector", + &F, BBInsertLocation); + RRI.PseudoExit = BasicBlock::Create(Ctx, Twine(LS.Tag) + ".pseudo.exit", &F, + BBInsertLocation); + + BranchInst *PreheaderJump = cast<BranchInst>(Preheader->getTerminator()); + bool Increasing = LS.IndVarIncreasing; + bool IsSignedPredicate = LS.IsSignedPredicate; + + IRBuilder<> B(PreheaderJump); + + // EnterLoopCond - is it okay to start executing this `LS'? + Value *EnterLoopCond = nullptr; + if (Increasing) + EnterLoopCond = IsSignedPredicate + ? B.CreateICmpSLT(LS.IndVarStart, ExitSubloopAt) + : B.CreateICmpULT(LS.IndVarStart, ExitSubloopAt); + else + EnterLoopCond = IsSignedPredicate + ? B.CreateICmpSGT(LS.IndVarStart, ExitSubloopAt) + : B.CreateICmpUGT(LS.IndVarStart, ExitSubloopAt); + + B.CreateCondBr(EnterLoopCond, LS.Header, RRI.PseudoExit); + PreheaderJump->eraseFromParent(); + + LS.LatchBr->setSuccessor(LS.LatchBrExitIdx, RRI.ExitSelector); + B.SetInsertPoint(LS.LatchBr); + Value *TakeBackedgeLoopCond = nullptr; + if (Increasing) + TakeBackedgeLoopCond = IsSignedPredicate + ? B.CreateICmpSLT(LS.IndVarBase, ExitSubloopAt) + : B.CreateICmpULT(LS.IndVarBase, ExitSubloopAt); + else + TakeBackedgeLoopCond = IsSignedPredicate + ? B.CreateICmpSGT(LS.IndVarBase, ExitSubloopAt) + : B.CreateICmpUGT(LS.IndVarBase, ExitSubloopAt); + Value *CondForBranch = LS.LatchBrExitIdx == 1 + ? TakeBackedgeLoopCond + : B.CreateNot(TakeBackedgeLoopCond); + + LS.LatchBr->setCondition(CondForBranch); + + B.SetInsertPoint(RRI.ExitSelector); + + // IterationsLeft - are there any more iterations left, given the original + // upper bound on the induction variable? If not, we branch to the "real" + // exit. + Value *IterationsLeft = nullptr; + if (Increasing) + IterationsLeft = IsSignedPredicate + ? B.CreateICmpSLT(LS.IndVarBase, LS.LoopExitAt) + : B.CreateICmpULT(LS.IndVarBase, LS.LoopExitAt); + else + IterationsLeft = IsSignedPredicate + ? B.CreateICmpSGT(LS.IndVarBase, LS.LoopExitAt) + : B.CreateICmpUGT(LS.IndVarBase, LS.LoopExitAt); + B.CreateCondBr(IterationsLeft, RRI.PseudoExit, LS.LatchExit); + + BranchInst *BranchToContinuation = + BranchInst::Create(ContinuationBlock, RRI.PseudoExit); + + // We emit PHI nodes into `RRI.PseudoExit' that compute the "latest" value of + // each of the PHI nodes in the loop header. This feeds into the initial + // value of the same PHI nodes if/when we continue execution. 
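+ //
+ // For example (an illustrative sketch with hypothetical names), a header
+ // PHI
+ //
+ // %v = phi [ %init, %preheader ], [ %v.next, %latch ]
+ //
+ // gets a pseudo-exit copy of the form
+ //
+ // %v.copy = phi [ %init, %preheader ], [ %v.next, %exit.selector ]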
+ for (PHINode &PN : LS.Header->phis()) { + PHINode *NewPHI = PHINode::Create(PN.getType(), 2, PN.getName() + ".copy", + BranchToContinuation); + + NewPHI->addIncoming(PN.getIncomingValueForBlock(Preheader), Preheader); + NewPHI->addIncoming(PN.getIncomingValueForBlock(LS.Latch), + RRI.ExitSelector); + RRI.PHIValuesAtPseudoExit.push_back(NewPHI); + } + + RRI.IndVarEnd = PHINode::Create(LS.IndVarBase->getType(), 2, "indvar.end", + BranchToContinuation); + RRI.IndVarEnd->addIncoming(LS.IndVarStart, Preheader); + RRI.IndVarEnd->addIncoming(LS.IndVarBase, RRI.ExitSelector); + + // The latch exit now has a branch from `RRI.ExitSelector' instead of + // `LS.Latch'. The PHI nodes need to be updated to reflect that. + for (PHINode &PN : LS.LatchExit->phis()) + replacePHIBlock(&PN, LS.Latch, RRI.ExitSelector); + + return RRI; +} + +void LoopConstrainer::rewriteIncomingValuesForPHIs( + LoopStructure &LS, BasicBlock *ContinuationBlock, + const LoopConstrainer::RewrittenRangeInfo &RRI) const { + unsigned PHIIndex = 0; + for (PHINode &PN : LS.Header->phis()) + for (unsigned i = 0, e = PN.getNumIncomingValues(); i < e; ++i) + if (PN.getIncomingBlock(i) == ContinuationBlock) + PN.setIncomingValue(i, RRI.PHIValuesAtPseudoExit[PHIIndex++]); + + LS.IndVarStart = RRI.IndVarEnd; +} + +BasicBlock *LoopConstrainer::createPreheader(const LoopStructure &LS, + BasicBlock *OldPreheader, + const char *Tag) const { + BasicBlock *Preheader = BasicBlock::Create(Ctx, Tag, &F, LS.Header); + BranchInst::Create(LS.Header, Preheader); + + for (PHINode &PN : LS.Header->phis()) + for (unsigned i = 0, e = PN.getNumIncomingValues(); i < e; ++i) + replacePHIBlock(&PN, OldPreheader, Preheader); + + return Preheader; +} + +void LoopConstrainer::addToParentLoopIfNeeded(ArrayRef<BasicBlock *> BBs) { + Loop *ParentLoop = OriginalLoop.getParentLoop(); + if (!ParentLoop) + return; + + for (BasicBlock *BB : BBs) + ParentLoop->addBasicBlockToLoop(BB, LI); +} + +Loop *LoopConstrainer::createClonedLoopStructure(Loop *Original, Loop *Parent, + ValueToValueMapTy &VM) { + Loop &New = *LI.AllocateLoop(); + if (Parent) + Parent->addChildLoop(&New); + else + LI.addTopLevelLoop(&New); + LPM.addLoop(New); + + // Add all of the blocks in Original to the new loop. + for (auto *BB : Original->blocks()) + if (LI.getLoopFor(BB) == Original) + New.addBasicBlockToLoop(cast<BasicBlock>(VM[BB]), LI); + + // Add all of the subloops to the new loop. 
+ for (Loop *SubLoop : *Original) + createClonedLoopStructure(SubLoop, &New, VM); + + return &New; +} + +bool LoopConstrainer::run() { + BasicBlock *Preheader = nullptr; + LatchTakenCount = SE.getExitCount(&OriginalLoop, MainLoopStructure.Latch); + Preheader = OriginalLoop.getLoopPreheader(); + assert(!isa<SCEVCouldNotCompute>(LatchTakenCount) && Preheader != nullptr && + "preconditions!"); + + OriginalPreheader = Preheader; + MainLoopPreheader = Preheader; + + bool IsSignedPredicate = MainLoopStructure.IsSignedPredicate; + Optional<SubRanges> MaybeSR = calculateSubRanges(IsSignedPredicate); + if (!MaybeSR.hasValue()) { + DEBUG(dbgs() << "irce: could not compute subranges\n"); + return false; + } + + SubRanges SR = MaybeSR.getValue(); + bool Increasing = MainLoopStructure.IndVarIncreasing; + IntegerType *IVTy = + cast<IntegerType>(MainLoopStructure.IndVarBase->getType()); + + SCEVExpander Expander(SE, F.getParent()->getDataLayout(), "irce"); + Instruction *InsertPt = OriginalPreheader->getTerminator(); + + // It would have been better to make `PreLoop' and `PostLoop' + // `Optional<ClonedLoop>'s, but `ValueToValueMapTy' does not have a copy + // constructor. + ClonedLoop PreLoop, PostLoop; + bool NeedsPreLoop = + Increasing ? SR.LowLimit.hasValue() : SR.HighLimit.hasValue(); + bool NeedsPostLoop = + Increasing ? SR.HighLimit.hasValue() : SR.LowLimit.hasValue(); + + Value *ExitPreLoopAt = nullptr; + Value *ExitMainLoopAt = nullptr; + const SCEVConstant *MinusOneS = + cast<SCEVConstant>(SE.getConstant(IVTy, -1, true /* isSigned */)); + + if (NeedsPreLoop) { + const SCEV *ExitPreLoopAtSCEV = nullptr; + + if (Increasing) + ExitPreLoopAtSCEV = *SR.LowLimit; + else { + if (CanBeMin(SE, *SR.HighLimit, IsSignedPredicate)) { + DEBUG(dbgs() << "irce: could not prove no-overflow when computing " + << "preloop exit limit. HighLimit = " << *(*SR.HighLimit) + << "\n"); + return false; + } + ExitPreLoopAtSCEV = SE.getAddExpr(*SR.HighLimit, MinusOneS); + } + + if (!isSafeToExpandAt(ExitPreLoopAtSCEV, InsertPt, SE)) { + DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" + << " preloop exit limit " << *ExitPreLoopAtSCEV + << " at block " << InsertPt->getParent()->getName() << "\n"); + return false; + } + + ExitPreLoopAt = Expander.expandCodeFor(ExitPreLoopAtSCEV, IVTy, InsertPt); + ExitPreLoopAt->setName("exit.preloop.at"); + } + + if (NeedsPostLoop) { + const SCEV *ExitMainLoopAtSCEV = nullptr; + + if (Increasing) + ExitMainLoopAtSCEV = *SR.HighLimit; + else { + if (CanBeMin(SE, *SR.LowLimit, IsSignedPredicate)) { + DEBUG(dbgs() << "irce: could not prove no-overflow when computing " + << "mainloop exit limit. LowLimit = " << *(*SR.LowLimit) + << "\n"); + return false; + } + ExitMainLoopAtSCEV = SE.getAddExpr(*SR.LowLimit, MinusOneS); + } + + if (!isSafeToExpandAt(ExitMainLoopAtSCEV, InsertPt, SE)) { + DEBUG(dbgs() << "irce: could not prove that it is safe to expand the" + << " main loop exit limit " << *ExitMainLoopAtSCEV + << " at block " << InsertPt->getParent()->getName() << "\n"); + return false; + } + + ExitMainLoopAt = Expander.expandCodeFor(ExitMainLoopAtSCEV, IVTy, InsertPt); + ExitMainLoopAt->setName("exit.mainloop.at"); + } + + // We clone these ahead of time so that we don't have to deal with changing + // and temporarily invalid IR as we transform the loops. 
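+ //
+ // Conceptually (an illustrative sketch for an increasing induction
+ // variable; the decreasing case is symmetric), the transformed control
+ // flow is:
+ //
+ // preloop: for (iv = start; iv < exit.preloop.at; ++iv) ...
+ // mainloop: for (; iv < exit.mainloop.at; ++iv) ...
+ // postloop: for (; iv < original loop limit; ++iv) ...
+ //
+ // where each loop resumes from the `indvar.end' of the previous one, and
+ // `<' is signed or unsigned according to the latch predicate.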
+ if (NeedsPreLoop) + cloneLoop(PreLoop, "preloop"); + if (NeedsPostLoop) + cloneLoop(PostLoop, "postloop"); + + RewrittenRangeInfo PreLoopRRI; + + if (NeedsPreLoop) { + Preheader->getTerminator()->replaceUsesOfWith(MainLoopStructure.Header, + PreLoop.Structure.Header); + + MainLoopPreheader = + createPreheader(MainLoopStructure, Preheader, "mainloop"); + PreLoopRRI = changeIterationSpaceEnd(PreLoop.Structure, Preheader, + ExitPreLoopAt, MainLoopPreheader); + rewriteIncomingValuesForPHIs(MainLoopStructure, MainLoopPreheader, + PreLoopRRI); + } + + BasicBlock *PostLoopPreheader = nullptr; + RewrittenRangeInfo PostLoopRRI; + + if (NeedsPostLoop) { + PostLoopPreheader = + createPreheader(PostLoop.Structure, Preheader, "postloop"); + PostLoopRRI = changeIterationSpaceEnd(MainLoopStructure, MainLoopPreheader, + ExitMainLoopAt, PostLoopPreheader); + rewriteIncomingValuesForPHIs(PostLoop.Structure, PostLoopPreheader, + PostLoopRRI); + } + + BasicBlock *NewMainLoopPreheader = + MainLoopPreheader != Preheader ? MainLoopPreheader : nullptr; + BasicBlock *NewBlocks[] = {PostLoopPreheader, PreLoopRRI.PseudoExit, + PreLoopRRI.ExitSelector, PostLoopRRI.PseudoExit, + PostLoopRRI.ExitSelector, NewMainLoopPreheader}; + + // Some of the above may be nullptr, filter them out before passing to + // addToParentLoopIfNeeded. + auto NewBlocksEnd = + std::remove(std::begin(NewBlocks), std::end(NewBlocks), nullptr); + + addToParentLoopIfNeeded(makeArrayRef(std::begin(NewBlocks), NewBlocksEnd)); + + DT.recalculate(F); + + // We need to first add all the pre and post loop blocks into the loop + // structures (as part of createClonedLoopStructure), and then update the + // LCSSA form and LoopSimplifyForm. This is necessary for correctly updating + // LI when LoopSimplifyForm is generated. + Loop *PreL = nullptr, *PostL = nullptr; + if (!PreLoop.Blocks.empty()) { + PreL = createClonedLoopStructure( + &OriginalLoop, OriginalLoop.getParentLoop(), PreLoop.Map); + } + + if (!PostLoop.Blocks.empty()) { + PostL = createClonedLoopStructure( + &OriginalLoop, OriginalLoop.getParentLoop(), PostLoop.Map); + } + + // This function canonicalizes the loop into Loop-Simplify and LCSSA forms. + auto CanonicalizeLoop = [&] (Loop *L, bool IsOriginalLoop) { + formLCSSARecursively(*L, DT, &LI, &SE); + simplifyLoop(L, &DT, &LI, &SE, nullptr, true); + // Pre/post loops are slow paths, we do not need to perform any loop + // optimizations on them. + if (!IsOriginalLoop) + DisableAllLoopOptsOnLoop(*L); + }; + if (PreL) + CanonicalizeLoop(PreL, false); + if (PostL) + CanonicalizeLoop(PostL, false); + CanonicalizeLoop(&OriginalLoop, true); + + return true; +} + +/// Computes and returns a range of values for the induction variable (IndVar) +/// in which the range check can be safely elided. If it cannot compute such a +/// range, returns None. +Optional<InductiveRangeCheck::Range> +InductiveRangeCheck::computeSafeIterationSpace( + ScalarEvolution &SE, const SCEVAddRecExpr *IndVar, + bool IsLatchSigned) const { + // IndVar is of the form "A + B * I" (where "I" is the canonical induction + // variable, that may or may not exist as a real llvm::Value in the loop) and + // this inductive range check is a range check on the "C + D * I" ("C" is + // getBegin() and "D" is getStep()). We rewrite the value being range + // checked to "M + N * IndVar" where "N" = "D * B^(-1)" and "M" = "C - NA". + // + // The actual inequalities we solve are of the form + // + // 0 <= M + 1 * IndVar < L given L >= 0 (i.e. 
N == 1)
+ //
+ // Here L stands for upper limit of the safe iteration space.
+ // The inequality is satisfied by (0 - M) <= IndVar < (L - M). To avoid
+ // overflows when calculating (0 - M) and (L - M) we, depending on type of
+ // IV's iteration space, limit the calculations by borders of the iteration
+ // space. For example, if IndVar is unsigned, (0 - M) overflows for any M > 0.
+ // If we figured out that "anything greater than (-M) is safe", we strengthen
+ // this to "everything greater than 0 is safe", assuming that values between
+ // -M and 0 just do not exist in unsigned iteration space, and we don't want
+ // to deal with overflown values.
+
+ if (!IndVar->isAffine())
+ return None;
+
+ const SCEV *A = IndVar->getStart();
+ const SCEVConstant *B = dyn_cast<SCEVConstant>(IndVar->getStepRecurrence(SE));
+ if (!B)
+ return None;
+ assert(!B->isZero() && "Recurrence with zero step?");
+
+ const SCEV *C = getBegin();
+ const SCEVConstant *D = dyn_cast<SCEVConstant>(getStep());
+ if (D != B)
+ return None;
+
+ assert(!D->getValue()->isZero() && "Recurrence with zero step?");
+ unsigned BitWidth = cast<IntegerType>(IndVar->getType())->getBitWidth();
+ const SCEV *SIntMax = SE.getConstant(APInt::getSignedMaxValue(BitWidth));
+
+ // Subtract Y from X so that it does not go through the border of the IV
+ // iteration space. Mathematically, it is equivalent to:
+ //
+ // ClampedSubstract(X, Y) = min(max(X - Y, INT_MIN), INT_MAX). [1]
+ //
+ // In [1], 'X - Y' is a mathematical subtraction (the result is not bounded
+ // to any width of bit grid). But after we take min/max, the result is
+ // guaranteed to be within [INT_MIN, INT_MAX].
+ //
+ // In [1], INT_MAX and INT_MIN are respectively signed and unsigned max/min
+ // values, depending on type of latch condition that defines IV iteration
+ // space.
+ auto ClampedSubstract = [&](const SCEV *X, const SCEV *Y) {
+ assert(SE.isKnownNonNegative(X) &&
+ "We can only subtract from values in [0; SINT_MAX]!");
+ if (IsLatchSigned) {
+ // X is a number from signed range, Y is interpreted as signed.
+ // Even if Y is SINT_MAX, (X - Y) does not reach SINT_MIN. So the only
+ // thing we should care about is that we didn't cross SINT_MAX.
+ // So, if Y is positive, we subtract Y safely.
+ // Rule 1: Y > 0 ---> Y.
+ // If 0 <= -Y <= (SINT_MAX - X), we subtract Y safely.
+ // Rule 2: Y >=s (X - SINT_MAX) ---> Y.
+ // If 0 <= (SINT_MAX - X) < -Y, we can only subtract (X - SINT_MAX).
+ // Rule 3: Y <s (X - SINT_MAX) ---> (X - SINT_MAX).
+ // It gives us smax(Y, X - SINT_MAX) to subtract in all cases.
+ const SCEV *XMinusSIntMax = SE.getMinusSCEV(X, SIntMax);
+ return SE.getMinusSCEV(X, SE.getSMaxExpr(Y, XMinusSIntMax),
+ SCEV::FlagNSW);
+ } else
+ // X is a number from unsigned range, Y is interpreted as signed.
+ // Even if Y is SINT_MIN, (X - Y) does not reach UINT_MAX. So the only
+ // thing we should care about is that we didn't cross zero.
+ // So, if Y is negative, we subtract Y safely.
+ // Rule 1: Y <s 0 ---> Y.
+ // If 0 <= Y <= X, we subtract Y safely.
+ // Rule 2: Y <=s X ---> Y.
+ // If 0 <= X < Y, we should stop at 0 and can only subtract X.
+ // Rule 3: Y >s X ---> X.
+ // It gives us smin(X, Y) to subtract in all cases.
+ return SE.getMinusSCEV(X, SE.getSMinExpr(X, Y), SCEV::FlagNUW);
+ };
+ const SCEV *M = SE.getMinusSCEV(C, A);
+ const SCEV *Zero = SE.getZero(M->getType());
+ const SCEV *Begin = ClampedSubstract(Zero, M);
+ const SCEV *L = nullptr;
+
+ // We strengthen "0 <= I" to "0 <= I < INT_SMAX" and "I < L" to "0 <= I < L".
+  // We can potentially do much better here.
+  if (const SCEV *EndLimit = getEnd())
+    L = EndLimit;
+  else {
+    assert(Kind == InductiveRangeCheck::RANGE_CHECK_LOWER && "invariant!");
+    L = SIntMax;
+  }
+  const SCEV *End = ClampedSubtract(L, M);
+  return InductiveRangeCheck::Range(Begin, End);
+}
+
+static Optional<InductiveRangeCheck::Range>
+IntersectSignedRange(ScalarEvolution &SE,
+                     const Optional<InductiveRangeCheck::Range> &R1,
+                     const InductiveRangeCheck::Range &R2) {
+  if (R2.isEmpty(SE, /* IsSigned */ true))
+    return None;
+  if (!R1.hasValue())
+    return R2;
+  auto &R1Value = R1.getValue();
+  // We never return empty ranges from this function, and R1 is supposed to be
+  // a result of intersection. Thus, R1 is never empty.
+  assert(!R1Value.isEmpty(SE, /* IsSigned */ true) &&
+         "We should never have empty R1!");
+
+  // TODO: we could widen the smaller range and have this work; but for now we
+  // bail out to keep things simple.
+  if (R1Value.getType() != R2.getType())
+    return None;
+
+  const SCEV *NewBegin = SE.getSMaxExpr(R1Value.getBegin(), R2.getBegin());
+  const SCEV *NewEnd = SE.getSMinExpr(R1Value.getEnd(), R2.getEnd());
+
+  // If the resulting range is empty, just return None.
+  auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd);
+  if (Ret.isEmpty(SE, /* IsSigned */ true))
+    return None;
+  return Ret;
+}
+
+static Optional<InductiveRangeCheck::Range>
+IntersectUnsignedRange(ScalarEvolution &SE,
+                       const Optional<InductiveRangeCheck::Range> &R1,
+                       const InductiveRangeCheck::Range &R2) {
+  if (R2.isEmpty(SE, /* IsSigned */ false))
+    return None;
+  if (!R1.hasValue())
+    return R2;
+  auto &R1Value = R1.getValue();
+  // We never return empty ranges from this function, and R1 is supposed to be
+  // a result of intersection. Thus, R1 is never empty.
+  assert(!R1Value.isEmpty(SE, /* IsSigned */ false) &&
+         "We should never have empty R1!");
+
+  // TODO: we could widen the smaller range and have this work; but for now we
+  // bail out to keep things simple.
+  if (R1Value.getType() != R2.getType())
+    return None;
+
+  const SCEV *NewBegin = SE.getUMaxExpr(R1Value.getBegin(), R2.getBegin());
+  const SCEV *NewEnd = SE.getUMinExpr(R1Value.getEnd(), R2.getEnd());
+
+  // If the resulting range is empty, just return None.
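+  // A hedged worked example (editorial aside; values invented): under the
+  // unsigned interpretation, intersecting [0, 100) with [50, 200) gives
+  // [umax(0, 50), umin(100, 200)) = [50, 100), while intersecting [0, 50)
+  // with [60, 100) gives [60, 50), which is empty, so None is returned.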
+  auto Ret = InductiveRangeCheck::Range(NewBegin, NewEnd);
+  if (Ret.isEmpty(SE, /* IsSigned */ false))
+    return None;
+  return Ret;
+}
+
+bool InductiveRangeCheckElimination::runOnLoop(Loop *L, LPPassManager &LPM) {
+  if (skipLoop(L))
+    return false;
+
+  if (L->getBlocks().size() >= LoopSizeCutoff) {
+    DEBUG(dbgs() << "irce: giving up constraining loop, too large\n");
+    return false;
+  }
+
+  BasicBlock *Preheader = L->getLoopPreheader();
+  if (!Preheader) {
+    DEBUG(dbgs() << "irce: loop has no preheader, leaving\n");
+    return false;
+  }
+
+  LLVMContext &Context = Preheader->getContext();
+  SmallVector<InductiveRangeCheck, 16> RangeChecks;
+  ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+  BranchProbabilityInfo &BPI =
+      getAnalysis<BranchProbabilityInfoWrapperPass>().getBPI();
+
+  for (auto BBI : L->getBlocks())
+    if (BranchInst *TBI = dyn_cast<BranchInst>(BBI->getTerminator()))
+      InductiveRangeCheck::extractRangeChecksFromBranch(TBI, L, SE, BPI,
+                                                        RangeChecks);
+
+  if (RangeChecks.empty())
+    return false;
+
+  auto PrintRecognizedRangeChecks = [&](raw_ostream &OS) {
+    OS << "irce: looking at loop "; L->print(OS);
+    OS << "irce: loop has " << RangeChecks.size()
+       << " inductive range checks: \n";
+    for (InductiveRangeCheck &IRC : RangeChecks)
+      IRC.print(OS);
+  };
+
+  DEBUG(PrintRecognizedRangeChecks(dbgs()));
+
+  if (PrintRangeChecks)
+    PrintRecognizedRangeChecks(errs());
+
+  const char *FailureReason = nullptr;
+  Optional<LoopStructure> MaybeLoopStructure =
+      LoopStructure::parseLoopStructure(SE, BPI, *L, FailureReason);
+  if (!MaybeLoopStructure.hasValue()) {
+    DEBUG(dbgs() << "irce: could not parse loop structure: " << FailureReason
+                 << "\n");
+    return false;
+  }
+  LoopStructure LS = MaybeLoopStructure.getValue();
+  const SCEVAddRecExpr *IndVar = cast<SCEVAddRecExpr>(
+      SE.getMinusSCEV(SE.getSCEV(LS.IndVarBase), SE.getSCEV(LS.IndVarStep)));
+
+  Optional<InductiveRangeCheck::Range> SafeIterRange;
+  Instruction *ExprInsertPt = Preheader->getTerminator();
+
+  SmallVector<InductiveRangeCheck, 4> RangeChecksToEliminate;
+  // Based on the type of the latch predicate, we interpret the IV iteration
+  // range as a signed or an unsigned range. We use different min/max
+  // functions (signed or unsigned) when intersecting this range with safe
+  // iteration ranges implied by range checks.
+  auto IntersectRange =
+      LS.IsSignedPredicate ?
IntersectSignedRange : IntersectUnsignedRange;
+
+  IRBuilder<> B(ExprInsertPt);
+  for (InductiveRangeCheck &IRC : RangeChecks) {
+    auto Result = IRC.computeSafeIterationSpace(SE, IndVar,
+                                                LS.IsSignedPredicate);
+    if (Result.hasValue()) {
+      auto MaybeSafeIterRange =
+          IntersectRange(SE, SafeIterRange, Result.getValue());
+      if (MaybeSafeIterRange.hasValue()) {
+        assert(
+            !MaybeSafeIterRange.getValue().isEmpty(SE, LS.IsSignedPredicate) &&
+            "We should never return empty ranges!");
+        RangeChecksToEliminate.push_back(IRC);
+        SafeIterRange = MaybeSafeIterRange.getValue();
+      }
+    }
+  }
+
+  if (!SafeIterRange.hasValue())
+    return false;
+
+  auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  LoopConstrainer LC(*L, getAnalysis<LoopInfoWrapperPass>().getLoopInfo(), LPM,
+                     LS, SE, DT, SafeIterRange.getValue());
+  bool Changed = LC.run();
+
+  if (Changed) {
+    auto PrintConstrainedLoopInfo = [L]() {
+      dbgs() << "irce: in function ";
+      dbgs() << L->getHeader()->getParent()->getName() << ": ";
+      dbgs() << "constrained ";
+      L->print(dbgs());
+    };
+
+    DEBUG(PrintConstrainedLoopInfo());
+
+    if (PrintChangedLoops)
+      PrintConstrainedLoopInfo();
+
+    // Optimize away the now-redundant range checks.
+    for (InductiveRangeCheck &IRC : RangeChecksToEliminate) {
+      ConstantInt *FoldedRangeCheck = IRC.getPassingDirection()
+                                          ? ConstantInt::getTrue(Context)
+                                          : ConstantInt::getFalse(Context);
+      IRC.getCheckUse()->set(FoldedRangeCheck);
+    }
+  }
+
+  return Changed;
+}
+
+Pass *llvm::createInductiveRangeCheckEliminationPass() {
+  return new InductiveRangeCheckElimination;
+}
diff --git a/contrib/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/contrib/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
new file mode 100644
index 000000000000..7d66c0f73821
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -0,0 +1,1012 @@
+//===- InferAddressSpaces.cpp - Infer address spaces ----------------------===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// CUDA C/C++ includes memory space designation as variable type qualifiers
+// (such as __global__ and __shared__). Knowing the space of a memory access
+// allows CUDA compilers to emit faster PTX loads and stores. For example, a
+// load from shared memory can be translated to `ld.shared`, which is roughly
+// 10% faster than a generic `ld` on an NVIDIA Tesla K40c.
+//
+// Unfortunately, type qualifiers only apply to variable declarations, so CUDA
+// compilers must infer the memory space of an address expression from
+// type-qualified variables.
+//
+// LLVM IR uses non-zero (so-called specific) address spaces to represent
+// memory spaces (e.g. addrspace(3) means shared memory). The Clang frontend
+// places only type-qualified variables in specific address spaces, and then
+// conservatively `addrspacecast`s each type-qualified variable to addrspace(0)
+// (the so-called generic address space) for other instructions to use.
+//
+// For example, Clang translates the following CUDA code
+//   __shared__ float a[10];
+//   float v = a[i];
+// to
+//   %0 = addrspacecast [10 x float] addrspace(3)* @a to [10 x float]*
+//   %1 = gep [10 x float], [10 x float]* %0, i64 0, i64 %i
+//   %v = load float, float* %1 ; emits ld.f32
+// @a is in addrspace(3) since it's type-qualified, but its use from %1 is
+// redirected to %0 (the generic version of @a).
+//
+// The optimization implemented in this file propagates specific address
+// spaces from type-qualified variable declarations to their users. For
+// example, it optimizes the above IR to
+//   %1 = gep [10 x float] addrspace(3)* @a, i64 0, i64 %i
+//   %v = load float addrspace(3)* %1 ; emits ld.shared.f32
+// propagating the addrspace(3) from @a to %1. As a result, the NVPTX codegen
+// is able to emit ld.shared.f32 for %v.
+//
+// Address space inference works in two steps. First, it uses a data-flow
+// analysis to infer as many generic pointers as possible to point to only one
+// specific address space. In the above example, it can prove that %1 only
+// points to addrspace(3). This algorithm was published in
+//   CUDA: Compiling and optimizing for a GPU platform
+//   Chakrabarti, Grover, Aarts, Kong, Kudlur, Lin, Marathe, Murphy, Wang
+//   ICCS 2012
+//
+// Then, address space inference replaces all refinable generic pointers with
+// equivalent specific pointers.
+//
+// The major challenge of implementing this optimization is handling PHINodes,
+// which may create loops in the data flow graph. This brings two
+// complications.
+//
+// First, the data flow analysis in Step 1 needs to be circular. For example,
+//     %generic.input = addrspacecast float addrspace(3)* %input to float*
+//   loop:
+//     %y = phi [ %generic.input, %y2 ]
+//     %y2 = getelementptr %y, 1
+//     %v = load %y2
+//     br ..., label %loop, ...
+// proving %y specific requires proving both %generic.input and %y2 specific,
+// but proving %y2 specific circles back to %y. To address this complication,
+// the data flow analysis operates on a lattice:
+//   uninitialized > specific address spaces > generic.
+// All address expressions (our implementation only considers phi, bitcast,
+// addrspacecast, and getelementptr) start with the uninitialized address
+// space. The monotone transfer function moves the address space of a pointer
+// down a lattice path from uninitialized to specific and then to generic. A
+// join operation of two different specific address spaces pushes the
+// expression down to the generic address space. The analysis completes once
+// it reaches a fixed point.
+//
+// Second, IR rewriting in Step 2 also needs to be circular. For example,
+// converting %y to addrspace(3) requires the compiler to know the converted
+// %y2, but converting %y2 needs the converted %y. To address this
+// complication, we break these cycles using "undef" placeholders. When
+// converting an instruction `I` to a new address space, if its operand `Op`
+// is not converted yet, we let `I` temporarily use `undef` and fix all the
+// uses of undef later.
+// For instance, our algorithm first converts %y to
+//   %y' = phi float addrspace(3)* [ %input, undef ]
+// Then, it converts %y2 to
+//   %y2' = getelementptr %y', 1
+// Finally, it fixes the undef in %y' so that
+//   %y' = phi float addrspace(3)* [ %input, %y2' ]
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/None.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/ValueMapper.h"
+#include <cassert>
+#include <iterator>
+#include <limits>
+#include <utility>
+#include <vector>
+
+#define DEBUG_TYPE "infer-address-spaces"
+
+using namespace llvm;
+
+static const unsigned UninitializedAddressSpace =
+    std::numeric_limits<unsigned>::max();
+
+namespace {
+
+using ValueToAddrSpaceMapTy = DenseMap<const Value *, unsigned>;
+
+/// \brief Infers specific address spaces for flat (generic) pointers and
+/// rewrites their users accordingly.
+class InferAddressSpaces : public FunctionPass {
+  /// Target-specific address space, uses of which should be replaced if
+  /// possible.
+  unsigned FlatAddrSpace;
+
+public:
+  static char ID;
+
+  InferAddressSpaces() : FunctionPass(ID) {}
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+  }
+
+  bool runOnFunction(Function &F) override;
+
+private:
+  // Returns the new address space of V if updated; otherwise, returns None.
+  Optional<unsigned>
+  updateAddressSpace(const Value &V,
+                     const ValueToAddrSpaceMapTy &InferredAddrSpace) const;
+
+  // Tries to infer the specific address space of each address expression in
+  // Postorder.
+  void inferAddressSpaces(ArrayRef<WeakTrackingVH> Postorder,
+                          ValueToAddrSpaceMapTy *InferredAddrSpace) const;
+
+  bool isSafeToCastConstAddrSpace(Constant *C, unsigned NewAS) const;
+
+  // Changes the flat address expressions in function F to point to specific
+  // address spaces if InferredAddrSpace says so. Postorder is the postorder
+  // of all flat expressions in the use-def graph of function F.
+ bool rewriteWithNewAddressSpaces( + const TargetTransformInfo &TTI, ArrayRef<WeakTrackingVH> Postorder, + const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const; + + void appendsFlatAddressExpressionToPostorderStack( + Value *V, std::vector<std::pair<Value *, bool>> &PostorderStack, + DenseSet<Value *> &Visited) const; + + bool rewriteIntrinsicOperands(IntrinsicInst *II, + Value *OldV, Value *NewV) const; + void collectRewritableIntrinsicOperands( + IntrinsicInst *II, + std::vector<std::pair<Value *, bool>> &PostorderStack, + DenseSet<Value *> &Visited) const; + + std::vector<WeakTrackingVH> collectFlatAddressExpressions(Function &F) const; + + Value *cloneValueWithNewAddressSpace( + Value *V, unsigned NewAddrSpace, + const ValueToValueMapTy &ValueWithNewAddrSpace, + SmallVectorImpl<const Use *> *UndefUsesToFix) const; + unsigned joinAddressSpaces(unsigned AS1, unsigned AS2) const; +}; + +} // end anonymous namespace + +char InferAddressSpaces::ID = 0; + +namespace llvm { + +void initializeInferAddressSpacesPass(PassRegistry &); + +} // end namespace llvm + +INITIALIZE_PASS(InferAddressSpaces, DEBUG_TYPE, "Infer address spaces", + false, false) + +// Returns true if V is an address expression. +// TODO: Currently, we consider only phi, bitcast, addrspacecast, and +// getelementptr operators. +static bool isAddressExpression(const Value &V) { + if (!isa<Operator>(V)) + return false; + + switch (cast<Operator>(V).getOpcode()) { + case Instruction::PHI: + case Instruction::BitCast: + case Instruction::AddrSpaceCast: + case Instruction::GetElementPtr: + case Instruction::Select: + return true; + default: + return false; + } +} + +// Returns the pointer operands of V. +// +// Precondition: V is an address expression. +static SmallVector<Value *, 2> getPointerOperands(const Value &V) { + const Operator &Op = cast<Operator>(V); + switch (Op.getOpcode()) { + case Instruction::PHI: { + auto IncomingValues = cast<PHINode>(Op).incoming_values(); + return SmallVector<Value *, 2>(IncomingValues.begin(), + IncomingValues.end()); + } + case Instruction::BitCast: + case Instruction::AddrSpaceCast: + case Instruction::GetElementPtr: + return {Op.getOperand(0)}; + case Instruction::Select: + return {Op.getOperand(1), Op.getOperand(2)}; + default: + llvm_unreachable("Unexpected instruction type."); + } +} + +// TODO: Move logic to TTI? +bool InferAddressSpaces::rewriteIntrinsicOperands(IntrinsicInst *II, + Value *OldV, + Value *NewV) const { + Module *M = II->getParent()->getParent()->getParent(); + + switch (II->getIntrinsicID()) { + case Intrinsic::amdgcn_atomic_inc: + case Intrinsic::amdgcn_atomic_dec:{ + const ConstantInt *IsVolatile = dyn_cast<ConstantInt>(II->getArgOperand(4)); + if (!IsVolatile || !IsVolatile->isZero()) + return false; + + LLVM_FALLTHROUGH; + } + case Intrinsic::objectsize: { + Type *DestTy = II->getType(); + Type *SrcTy = NewV->getType(); + Function *NewDecl = + Intrinsic::getDeclaration(M, II->getIntrinsicID(), {DestTy, SrcTy}); + II->setArgOperand(0, NewV); + II->setCalledFunction(NewDecl); + return true; + } + default: + return false; + } +} + +// TODO: Move logic to TTI? 
+void InferAddressSpaces::collectRewritableIntrinsicOperands(
+    IntrinsicInst *II, std::vector<std::pair<Value *, bool>> &PostorderStack,
+    DenseSet<Value *> &Visited) const {
+  switch (II->getIntrinsicID()) {
+  case Intrinsic::objectsize:
+  case Intrinsic::amdgcn_atomic_inc:
+  case Intrinsic::amdgcn_atomic_dec:
+    appendsFlatAddressExpressionToPostorderStack(II->getArgOperand(0),
+                                                 PostorderStack, Visited);
+    break;
+  default:
+    break;
+  }
+}
+
+// If V is an unvisited flat address expression, appends V to PostorderStack
+// and marks it as visited.
+void InferAddressSpaces::appendsFlatAddressExpressionToPostorderStack(
+    Value *V, std::vector<std::pair<Value *, bool>> &PostorderStack,
+    DenseSet<Value *> &Visited) const {
+  assert(V->getType()->isPointerTy());
+
+  // Generic addressing expressions may be hidden in nested constant
+  // expressions.
+  if (ConstantExpr *CE = dyn_cast<ConstantExpr>(V)) {
+    // TODO: Look in non-address parts, like icmp operands.
+    if (isAddressExpression(*CE) && Visited.insert(CE).second)
+      PostorderStack.push_back(std::make_pair(CE, false));
+
+    return;
+  }
+
+  if (isAddressExpression(*V) &&
+      V->getType()->getPointerAddressSpace() == FlatAddrSpace) {
+    if (Visited.insert(V).second) {
+      PostorderStack.push_back(std::make_pair(V, false));
+
+      Operator *Op = cast<Operator>(V);
+      for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I) {
+        if (ConstantExpr *CE = dyn_cast<ConstantExpr>(Op->getOperand(I))) {
+          if (isAddressExpression(*CE) && Visited.insert(CE).second)
+            PostorderStack.emplace_back(CE, false);
+        }
+      }
+    }
+  }
+}
+
+// Returns all flat address expressions in function F. The elements are
+// ordered in postorder.
+std::vector<WeakTrackingVH>
+InferAddressSpaces::collectFlatAddressExpressions(Function &F) const {
+  // This function implements a non-recursive postorder traversal of a partial
+  // use-def graph of function F.
+  std::vector<std::pair<Value *, bool>> PostorderStack;
+  // The set of visited expressions.
+  DenseSet<Value *> Visited;
+
+  auto PushPtrOperand = [&](Value *Ptr) {
+    appendsFlatAddressExpressionToPostorderStack(Ptr, PostorderStack,
+                                                 Visited);
+  };
+
+  // Look at operations that may be interesting to accelerate by moving to a
+  // known address space. We primarily aim at loads and stores, but pure
+  // addressing calculations may also be faster.
+  for (Instruction &I : instructions(F)) {
+    if (auto *GEP = dyn_cast<GetElementPtrInst>(&I)) {
+      if (!GEP->getType()->isVectorTy())
+        PushPtrOperand(GEP->getPointerOperand());
+    } else if (auto *LI = dyn_cast<LoadInst>(&I))
+      PushPtrOperand(LI->getPointerOperand());
+    else if (auto *SI = dyn_cast<StoreInst>(&I))
+      PushPtrOperand(SI->getPointerOperand());
+    else if (auto *RMW = dyn_cast<AtomicRMWInst>(&I))
+      PushPtrOperand(RMW->getPointerOperand());
+    else if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(&I))
+      PushPtrOperand(CmpX->getPointerOperand());
+    else if (auto *MI = dyn_cast<MemIntrinsic>(&I)) {
+      // For memset/memcpy/memmove, any pointer operand can be replaced.
+      PushPtrOperand(MI->getRawDest());
+
+      // Handle the second operand for memcpy/memmove.
+ if (auto *MTI = dyn_cast<MemTransferInst>(MI)) + PushPtrOperand(MTI->getRawSource()); + } else if (auto *II = dyn_cast<IntrinsicInst>(&I)) + collectRewritableIntrinsicOperands(II, PostorderStack, Visited); + else if (ICmpInst *Cmp = dyn_cast<ICmpInst>(&I)) { + // FIXME: Handle vectors of pointers + if (Cmp->getOperand(0)->getType()->isPointerTy()) { + PushPtrOperand(Cmp->getOperand(0)); + PushPtrOperand(Cmp->getOperand(1)); + } + } else if (auto *ASC = dyn_cast<AddrSpaceCastInst>(&I)) { + if (!ASC->getType()->isVectorTy()) + PushPtrOperand(ASC->getPointerOperand()); + } + } + + std::vector<WeakTrackingVH> Postorder; // The resultant postorder. + while (!PostorderStack.empty()) { + Value *TopVal = PostorderStack.back().first; + // If the operands of the expression on the top are already explored, + // adds that expression to the resultant postorder. + if (PostorderStack.back().second) { + if (TopVal->getType()->getPointerAddressSpace() == FlatAddrSpace) + Postorder.push_back(TopVal); + PostorderStack.pop_back(); + continue; + } + // Otherwise, adds its operands to the stack and explores them. + PostorderStack.back().second = true; + for (Value *PtrOperand : getPointerOperands(*TopVal)) { + appendsFlatAddressExpressionToPostorderStack(PtrOperand, PostorderStack, + Visited); + } + } + return Postorder; +} + +// A helper function for cloneInstructionWithNewAddressSpace. Returns the clone +// of OperandUse.get() in the new address space. If the clone is not ready yet, +// returns an undef in the new address space as a placeholder. +static Value *operandWithNewAddressSpaceOrCreateUndef( + const Use &OperandUse, unsigned NewAddrSpace, + const ValueToValueMapTy &ValueWithNewAddrSpace, + SmallVectorImpl<const Use *> *UndefUsesToFix) { + Value *Operand = OperandUse.get(); + + Type *NewPtrTy = + Operand->getType()->getPointerElementType()->getPointerTo(NewAddrSpace); + + if (Constant *C = dyn_cast<Constant>(Operand)) + return ConstantExpr::getAddrSpaceCast(C, NewPtrTy); + + if (Value *NewOperand = ValueWithNewAddrSpace.lookup(Operand)) + return NewOperand; + + UndefUsesToFix->push_back(&OperandUse); + return UndefValue::get(NewPtrTy); +} + +// Returns a clone of `I` with its operands converted to those specified in +// ValueWithNewAddrSpace. Due to potential cycles in the data flow graph, an +// operand whose address space needs to be modified might not exist in +// ValueWithNewAddrSpace. In that case, uses undef as a placeholder operand and +// adds that operand use to UndefUsesToFix so that caller can fix them later. +// +// Note that we do not necessarily clone `I`, e.g., if it is an addrspacecast +// from a pointer whose type already matches. Therefore, this function returns a +// Value* instead of an Instruction*. +static Value *cloneInstructionWithNewAddressSpace( + Instruction *I, unsigned NewAddrSpace, + const ValueToValueMapTy &ValueWithNewAddrSpace, + SmallVectorImpl<const Use *> *UndefUsesToFix) { + Type *NewPtrType = + I->getType()->getPointerElementType()->getPointerTo(NewAddrSpace); + + if (I->getOpcode() == Instruction::AddrSpaceCast) { + Value *Src = I->getOperand(0); + // Because `I` is flat, the source address space must be specific. + // Therefore, the inferred address space must be the source space, according + // to our algorithm. + assert(Src->getType()->getPointerAddressSpace() == NewAddrSpace); + if (Src->getType() != NewPtrType) + return new BitCastInst(Src, NewPtrType); + return Src; + } + + // Computes the converted pointer operands. 
+ SmallVector<Value *, 4> NewPointerOperands; + for (const Use &OperandUse : I->operands()) { + if (!OperandUse.get()->getType()->isPointerTy()) + NewPointerOperands.push_back(nullptr); + else + NewPointerOperands.push_back(operandWithNewAddressSpaceOrCreateUndef( + OperandUse, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix)); + } + + switch (I->getOpcode()) { + case Instruction::BitCast: + return new BitCastInst(NewPointerOperands[0], NewPtrType); + case Instruction::PHI: { + assert(I->getType()->isPointerTy()); + PHINode *PHI = cast<PHINode>(I); + PHINode *NewPHI = PHINode::Create(NewPtrType, PHI->getNumIncomingValues()); + for (unsigned Index = 0; Index < PHI->getNumIncomingValues(); ++Index) { + unsigned OperandNo = PHINode::getOperandNumForIncomingValue(Index); + NewPHI->addIncoming(NewPointerOperands[OperandNo], + PHI->getIncomingBlock(Index)); + } + return NewPHI; + } + case Instruction::GetElementPtr: { + GetElementPtrInst *GEP = cast<GetElementPtrInst>(I); + GetElementPtrInst *NewGEP = GetElementPtrInst::Create( + GEP->getSourceElementType(), NewPointerOperands[0], + SmallVector<Value *, 4>(GEP->idx_begin(), GEP->idx_end())); + NewGEP->setIsInBounds(GEP->isInBounds()); + return NewGEP; + } + case Instruction::Select: + assert(I->getType()->isPointerTy()); + return SelectInst::Create(I->getOperand(0), NewPointerOperands[1], + NewPointerOperands[2], "", nullptr, I); + default: + llvm_unreachable("Unexpected opcode"); + } +} + +// Similar to cloneInstructionWithNewAddressSpace, returns a clone of the +// constant expression `CE` with its operands replaced as specified in +// ValueWithNewAddrSpace. +static Value *cloneConstantExprWithNewAddressSpace( + ConstantExpr *CE, unsigned NewAddrSpace, + const ValueToValueMapTy &ValueWithNewAddrSpace) { + Type *TargetType = + CE->getType()->getPointerElementType()->getPointerTo(NewAddrSpace); + + if (CE->getOpcode() == Instruction::AddrSpaceCast) { + // Because CE is flat, the source address space must be specific. + // Therefore, the inferred address space must be the source space according + // to our algorithm. + assert(CE->getOperand(0)->getType()->getPointerAddressSpace() == + NewAddrSpace); + return ConstantExpr::getBitCast(CE->getOperand(0), TargetType); + } + + if (CE->getOpcode() == Instruction::BitCast) { + if (Value *NewOperand = ValueWithNewAddrSpace.lookup(CE->getOperand(0))) + return ConstantExpr::getBitCast(cast<Constant>(NewOperand), TargetType); + return ConstantExpr::getAddrSpaceCast(CE, TargetType); + } + + if (CE->getOpcode() == Instruction::Select) { + Constant *Src0 = CE->getOperand(1); + Constant *Src1 = CE->getOperand(2); + if (Src0->getType()->getPointerAddressSpace() == + Src1->getType()->getPointerAddressSpace()) { + + return ConstantExpr::getSelect( + CE->getOperand(0), ConstantExpr::getAddrSpaceCast(Src0, TargetType), + ConstantExpr::getAddrSpaceCast(Src1, TargetType)); + } + } + + // Computes the operands of the new constant expression. + bool IsNew = false; + SmallVector<Constant *, 4> NewOperands; + for (unsigned Index = 0; Index < CE->getNumOperands(); ++Index) { + Constant *Operand = CE->getOperand(Index); + // If the address space of `Operand` needs to be modified, the new operand + // with the new address space should already be in ValueWithNewAddrSpace + // because (1) the constant expressions we consider (i.e. addrspacecast, + // bitcast, and getelementptr) do not incur cycles in the data flow graph + // and (2) this function is called on constant expressions in postorder. 
+    if (Value *NewOperand = ValueWithNewAddrSpace.lookup(Operand)) {
+      IsNew = true;
+      NewOperands.push_back(cast<Constant>(NewOperand));
+    } else {
+      // Otherwise, reuses the old operand.
+      NewOperands.push_back(Operand);
+    }
+  }
+
+  // If !IsNew, we will replace the Value with itself. However, replaced
+  // values are assumed to be wrapped in an addrspacecast later, so drop it
+  // now.
+  if (!IsNew)
+    return nullptr;
+
+  if (CE->getOpcode() == Instruction::GetElementPtr) {
+    // Needs to specify the source type while constructing a getelementptr
+    // constant expression.
+    return CE->getWithOperands(
+        NewOperands, TargetType, /*OnlyIfReduced=*/false,
+        NewOperands[0]->getType()->getPointerElementType());
+  }
+
+  return CE->getWithOperands(NewOperands, TargetType);
+}
+
+// Returns a clone of the value `V`, with its operands replaced as specified
+// in ValueWithNewAddrSpace. This function is called on every flat address
+// expression whose address space needs to be modified, in postorder.
+//
+// See cloneInstructionWithNewAddressSpace for the meaning of UndefUsesToFix.
+Value *InferAddressSpaces::cloneValueWithNewAddressSpace(
+    Value *V, unsigned NewAddrSpace,
+    const ValueToValueMapTy &ValueWithNewAddrSpace,
+    SmallVectorImpl<const Use *> *UndefUsesToFix) const {
+  // All values in Postorder are flat address expressions.
+  assert(isAddressExpression(*V) &&
+         V->getType()->getPointerAddressSpace() == FlatAddrSpace);
+
+  if (Instruction *I = dyn_cast<Instruction>(V)) {
+    Value *NewV = cloneInstructionWithNewAddressSpace(
+        I, NewAddrSpace, ValueWithNewAddrSpace, UndefUsesToFix);
+    if (Instruction *NewI = dyn_cast<Instruction>(NewV)) {
+      if (NewI->getParent() == nullptr) {
+        NewI->insertBefore(I);
+        NewI->takeName(I);
+      }
+    }
+    return NewV;
+  }
+
+  return cloneConstantExprWithNewAddressSpace(
+      cast<ConstantExpr>(V), NewAddrSpace, ValueWithNewAddrSpace);
+}
+
+// Defines the join operation on the address space lattice (see the file
+// header comments).
+unsigned InferAddressSpaces::joinAddressSpaces(unsigned AS1,
+                                               unsigned AS2) const {
+  if (AS1 == FlatAddrSpace || AS2 == FlatAddrSpace)
+    return FlatAddrSpace;
+
+  if (AS1 == UninitializedAddressSpace)
+    return AS2;
+  if (AS2 == UninitializedAddressSpace)
+    return AS1;
+
+  // The join of two different specific address spaces is flat.
+  return (AS1 == AS2) ? AS1 : FlatAddrSpace;
+}
+
+bool InferAddressSpaces::runOnFunction(Function &F) {
+  if (skipFunction(F))
+    return false;
+
+  const TargetTransformInfo &TTI =
+      getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+  FlatAddrSpace = TTI.getFlatAddressSpace();
+  if (FlatAddrSpace == UninitializedAddressSpace)
+    return false;
+
+  // Collects all flat address expressions in postorder.
+  std::vector<WeakTrackingVH> Postorder = collectFlatAddressExpressions(F);
+
+  // Runs a data-flow analysis to refine the address spaces of every
+  // expression in Postorder.
+  ValueToAddrSpaceMapTy InferredAddrSpace;
+  inferAddressSpaces(Postorder, &InferredAddrSpace);
+
+  // Changes the address spaces of the flat address expressions that are
+  // inferred to point to a specific address space.
+  return rewriteWithNewAddressSpaces(TTI, Postorder, InferredAddrSpace, &F);
+}
+
+// Constants need to be tracked through RAUW to handle cases with nested
+// constant expressions, so wrap values in WeakTrackingVH.
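+//
+// A hedged illustration of the fixed point (editorial aside, reusing the
+// %y/%y2 example from the file header): all three values start
+// uninitialized. Visiting %generic.input infers addrspace(3) from its cast
+// source; visiting %y joins addrspace(3) with uninitialized (%y2) and infers
+// addrspace(3); visiting %y2 then joins addrspace(3) from %y and also infers
+// addrspace(3). A further pass changes nothing, so the analysis terminates.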
+void InferAddressSpaces::inferAddressSpaces(
+    ArrayRef<WeakTrackingVH> Postorder,
+    ValueToAddrSpaceMapTy *InferredAddrSpace) const {
+  SetVector<Value *> Worklist(Postorder.begin(), Postorder.end());
+  // Initially, all expressions are in the uninitialized address space.
+  for (Value *V : Postorder)
+    (*InferredAddrSpace)[V] = UninitializedAddressSpace;
+
+  while (!Worklist.empty()) {
+    Value *V = Worklist.pop_back_val();
+
+    // Tries to update the address space of V according to the address spaces
+    // of its operands.
+    DEBUG(dbgs() << "Updating the address space of\n  " << *V << '\n');
+    Optional<unsigned> NewAS = updateAddressSpace(*V, *InferredAddrSpace);
+    if (!NewAS.hasValue())
+      continue;
+    // If any update is made, add the users of V to the worklist because
+    // their address spaces may be updated as well.
+    DEBUG(dbgs() << "  to " << NewAS.getValue() << '\n');
+    (*InferredAddrSpace)[V] = NewAS.getValue();
+
+    for (Value *User : V->users()) {
+      // Skip if User is already in the worklist.
+      if (Worklist.count(User))
+        continue;
+
+      auto Pos = InferredAddrSpace->find(User);
+      // Our algorithm only updates the address spaces of flat address
+      // expressions, which are those in InferredAddrSpace.
+      if (Pos == InferredAddrSpace->end())
+        continue;
+
+      // Function updateAddressSpace moves the address space down a lattice
+      // path. Therefore, there is nothing to do if User is already inferred
+      // as flat (the bottom element in the lattice).
+      if (Pos->second == FlatAddrSpace)
+        continue;
+
+      Worklist.insert(User);
+    }
+  }
+}
+
+Optional<unsigned> InferAddressSpaces::updateAddressSpace(
+    const Value &V, const ValueToAddrSpaceMapTy &InferredAddrSpace) const {
+  assert(InferredAddrSpace.count(&V));
+
+  // The new inferred address space equals the join of the address spaces
+  // of all its pointer operands.
+  unsigned NewAS = UninitializedAddressSpace;
+
+  const Operator &Op = cast<Operator>(V);
+  if (Op.getOpcode() == Instruction::Select) {
+    Value *Src0 = Op.getOperand(1);
+    Value *Src1 = Op.getOperand(2);
+
+    auto I = InferredAddrSpace.find(Src0);
+    unsigned Src0AS = (I != InferredAddrSpace.end()) ?
+        I->second : Src0->getType()->getPointerAddressSpace();
+
+    auto J = InferredAddrSpace.find(Src1);
+    unsigned Src1AS = (J != InferredAddrSpace.end()) ?
+        J->second : Src1->getType()->getPointerAddressSpace();
+
+    auto *C0 = dyn_cast<Constant>(Src0);
+    auto *C1 = dyn_cast<Constant>(Src1);
+
+    // If one of the inputs is a constant, we may be able to do a constant
+    // addrspacecast of it. Defer inferring the address space until the input
+    // address space is known.
+    if ((C1 && Src0AS == UninitializedAddressSpace) ||
+        (C0 && Src1AS == UninitializedAddressSpace))
+      return None;
+
+    if (C0 && isSafeToCastConstAddrSpace(C0, Src1AS))
+      NewAS = Src1AS;
+    else if (C1 && isSafeToCastConstAddrSpace(C1, Src0AS))
+      NewAS = Src0AS;
+    else
+      NewAS = joinAddressSpaces(Src0AS, Src1AS);
+  } else {
+    for (Value *PtrOperand : getPointerOperands(V)) {
+      auto I = InferredAddrSpace.find(PtrOperand);
+      unsigned OperandAS = I != InferredAddrSpace.end() ?
+          I->second : PtrOperand->getType()->getPointerAddressSpace();
+
+      // join(flat, *) = flat. So we can break if NewAS is already flat.
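+      // For instance (editorial note, assuming flat is addrspace(0)):
+      // join(uninitialized, 3) = 3, join(3, 3) = 3, and join(3, 5) = 0, so
+      // once NewAS reaches the flat space no further operand can change it.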
+      NewAS = joinAddressSpaces(NewAS, OperandAS);
+      if (NewAS == FlatAddrSpace)
+        break;
+    }
+  }
+
+  unsigned OldAS = InferredAddrSpace.lookup(&V);
+  assert(OldAS != FlatAddrSpace);
+  if (OldAS == NewAS)
+    return None;
+  return NewAS;
+}
+
+/// Returns true if \p U is the pointer operand of a memory instruction with
+/// a single pointer operand that can have its address space changed by simply
+/// mutating the use to a new value. If the memory instruction is volatile,
+/// return true only if the target allows the memory instruction to be
+/// volatile in the new address space.
+static bool isSimplePointerUseValidToReplace(const TargetTransformInfo &TTI,
+                                             Use &U, unsigned AddrSpace) {
+  User *Inst = U.getUser();
+  unsigned OpNo = U.getOperandNo();
+  bool VolatileIsAllowed = false;
+  if (auto *I = dyn_cast<Instruction>(Inst))
+    VolatileIsAllowed = TTI.hasVolatileVariant(I, AddrSpace);
+
+  if (auto *LI = dyn_cast<LoadInst>(Inst))
+    return OpNo == LoadInst::getPointerOperandIndex() &&
+           (VolatileIsAllowed || !LI->isVolatile());
+
+  if (auto *SI = dyn_cast<StoreInst>(Inst))
+    return OpNo == StoreInst::getPointerOperandIndex() &&
+           (VolatileIsAllowed || !SI->isVolatile());
+
+  if (auto *RMW = dyn_cast<AtomicRMWInst>(Inst))
+    return OpNo == AtomicRMWInst::getPointerOperandIndex() &&
+           (VolatileIsAllowed || !RMW->isVolatile());
+
+  if (auto *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst))
+    return OpNo == AtomicCmpXchgInst::getPointerOperandIndex() &&
+           (VolatileIsAllowed || !CmpX->isVolatile());
+
+  return false;
+}
+
+/// Update memory intrinsic uses that require more complex processing than
+/// simple memory instructions. These require re-mangling and may have
+/// multiple pointer operands.
+static bool handleMemIntrinsicPtrUse(MemIntrinsic *MI, Value *OldV,
+                                     Value *NewV) {
+  IRBuilder<> B(MI);
+  MDNode *TBAA = MI->getMetadata(LLVMContext::MD_tbaa);
+  MDNode *ScopeMD = MI->getMetadata(LLVMContext::MD_alias_scope);
+  MDNode *NoAliasMD = MI->getMetadata(LLVMContext::MD_noalias);
+
+  if (auto *MSI = dyn_cast<MemSetInst>(MI)) {
+    B.CreateMemSet(NewV, MSI->getValue(),
+                   MSI->getLength(), MSI->getAlignment(),
+                   false, // isVolatile
+                   TBAA, ScopeMD, NoAliasMD);
+  } else if (auto *MTI = dyn_cast<MemTransferInst>(MI)) {
+    Value *Src = MTI->getRawSource();
+    Value *Dest = MTI->getRawDest();
+
+    // Be careful in case this is a self-to-self copy.
+    if (Src == OldV)
+      Src = NewV;
+
+    if (Dest == OldV)
+      Dest = NewV;
+
+    if (isa<MemCpyInst>(MTI)) {
+      MDNode *TBAAStruct = MTI->getMetadata(LLVMContext::MD_tbaa_struct);
+      B.CreateMemCpy(Dest, Src, MTI->getLength(),
+                     MTI->getAlignment(),
+                     false, // isVolatile
+                     TBAA, TBAAStruct, ScopeMD, NoAliasMD);
+    } else {
+      assert(isa<MemMoveInst>(MTI));
+      B.CreateMemMove(Dest, Src, MTI->getLength(),
+                      MTI->getAlignment(),
+                      false, // isVolatile
+                      TBAA, ScopeMD, NoAliasMD);
+    }
+  } else
+    llvm_unreachable("unhandled MemIntrinsic");
+
+  MI->eraseFromParent();
+  return true;
+}
+
+// Returns true if it is OK to change the address space of constant \p C with
+// a ConstantExpr addrspacecast.
+bool InferAddressSpaces::isSafeToCastConstAddrSpace(Constant *C,
+                                                    unsigned NewAS) const {
+  assert(NewAS != UninitializedAddressSpace);
+
+  unsigned SrcAS = C->getType()->getPointerAddressSpace();
+  if (SrcAS == NewAS || isa<UndefValue>(C))
+    return true;
+
+  // Prevent illegal casts between different non-flat address spaces.
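+  // e.g. (an editorial example; address-space numbering assumed NVPTX-like):
+  // casting a constant in addrspace(3) (shared) straight to addrspace(1)
+  // (global) has no defined meaning, so only casts involving the flat space
+  // are let through below.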
+ if (SrcAS != FlatAddrSpace && NewAS != FlatAddrSpace) + return false; + + if (isa<ConstantPointerNull>(C)) + return true; + + if (auto *Op = dyn_cast<Operator>(C)) { + // If we already have a constant addrspacecast, it should be safe to cast it + // off. + if (Op->getOpcode() == Instruction::AddrSpaceCast) + return isSafeToCastConstAddrSpace(cast<Constant>(Op->getOperand(0)), NewAS); + + if (Op->getOpcode() == Instruction::IntToPtr && + Op->getType()->getPointerAddressSpace() == FlatAddrSpace) + return true; + } + + return false; +} + +static Value::use_iterator skipToNextUser(Value::use_iterator I, + Value::use_iterator End) { + User *CurUser = I->getUser(); + ++I; + + while (I != End && I->getUser() == CurUser) + ++I; + + return I; +} + +bool InferAddressSpaces::rewriteWithNewAddressSpaces( + const TargetTransformInfo &TTI, ArrayRef<WeakTrackingVH> Postorder, + const ValueToAddrSpaceMapTy &InferredAddrSpace, Function *F) const { + // For each address expression to be modified, creates a clone of it with its + // pointer operands converted to the new address space. Since the pointer + // operands are converted, the clone is naturally in the new address space by + // construction. + ValueToValueMapTy ValueWithNewAddrSpace; + SmallVector<const Use *, 32> UndefUsesToFix; + for (Value* V : Postorder) { + unsigned NewAddrSpace = InferredAddrSpace.lookup(V); + if (V->getType()->getPointerAddressSpace() != NewAddrSpace) { + ValueWithNewAddrSpace[V] = cloneValueWithNewAddressSpace( + V, NewAddrSpace, ValueWithNewAddrSpace, &UndefUsesToFix); + } + } + + if (ValueWithNewAddrSpace.empty()) + return false; + + // Fixes all the undef uses generated by cloneInstructionWithNewAddressSpace. + for (const Use *UndefUse : UndefUsesToFix) { + User *V = UndefUse->getUser(); + User *NewV = cast<User>(ValueWithNewAddrSpace.lookup(V)); + unsigned OperandNo = UndefUse->getOperandNo(); + assert(isa<UndefValue>(NewV->getOperand(OperandNo))); + NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(UndefUse->get())); + } + + SmallVector<Instruction *, 16> DeadInstructions; + + // Replaces the uses of the old address expressions with the new ones. + for (const WeakTrackingVH &WVH : Postorder) { + assert(WVH && "value was unexpectedly deleted"); + Value *V = WVH; + Value *NewV = ValueWithNewAddrSpace.lookup(V); + if (NewV == nullptr) + continue; + + DEBUG(dbgs() << "Replacing the uses of " << *V + << "\n with\n " << *NewV << '\n'); + + if (Constant *C = dyn_cast<Constant>(V)) { + Constant *Replace = ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), + C->getType()); + if (C != Replace) { + DEBUG(dbgs() << "Inserting replacement const cast: " + << Replace << ": " << *Replace << '\n'); + C->replaceAllUsesWith(Replace); + V = Replace; + } + } + + Value::use_iterator I, E, Next; + for (I = V->use_begin(), E = V->use_end(); I != E; ) { + Use &U = *I; + + // Some users may see the same pointer operand in multiple operands. Skip + // to the next instruction. + I = skipToNextUser(I, E); + + if (isSimplePointerUseValidToReplace( + TTI, U, V->getType()->getPointerAddressSpace())) { + // If V is used as the pointer operand of a compatible memory operation, + // sets the pointer operand to NewV. This replacement does not change + // the element type, so the resultant load/store is still valid. + U.set(NewV); + continue; + } + + User *CurUser = U.getUser(); + // Handle more complex cases like intrinsic that need to be remangled. 
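+      // e.g. (editorial note): rewriting the dest pointer of an
+      // llvm.memcpy.p0i8.p0i8.i64 call to addrspace(3) requires re-creating
+      // the call against llvm.memcpy.p3i8.p0i8.i64; handleMemIntrinsicPtrUse
+      // and rewriteIntrinsicOperands below take care of such re-mangling.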
+ if (auto *MI = dyn_cast<MemIntrinsic>(CurUser)) { + if (!MI->isVolatile() && handleMemIntrinsicPtrUse(MI, V, NewV)) + continue; + } + + if (auto *II = dyn_cast<IntrinsicInst>(CurUser)) { + if (rewriteIntrinsicOperands(II, V, NewV)) + continue; + } + + if (isa<Instruction>(CurUser)) { + if (ICmpInst *Cmp = dyn_cast<ICmpInst>(CurUser)) { + // If we can infer that both pointers are in the same addrspace, + // transform e.g. + // %cmp = icmp eq float* %p, %q + // into + // %cmp = icmp eq float addrspace(3)* %new_p, %new_q + + unsigned NewAS = NewV->getType()->getPointerAddressSpace(); + int SrcIdx = U.getOperandNo(); + int OtherIdx = (SrcIdx == 0) ? 1 : 0; + Value *OtherSrc = Cmp->getOperand(OtherIdx); + + if (Value *OtherNewV = ValueWithNewAddrSpace.lookup(OtherSrc)) { + if (OtherNewV->getType()->getPointerAddressSpace() == NewAS) { + Cmp->setOperand(OtherIdx, OtherNewV); + Cmp->setOperand(SrcIdx, NewV); + continue; + } + } + + // Even if the type mismatches, we can cast the constant. + if (auto *KOtherSrc = dyn_cast<Constant>(OtherSrc)) { + if (isSafeToCastConstAddrSpace(KOtherSrc, NewAS)) { + Cmp->setOperand(SrcIdx, NewV); + Cmp->setOperand(OtherIdx, + ConstantExpr::getAddrSpaceCast(KOtherSrc, NewV->getType())); + continue; + } + } + } + + if (AddrSpaceCastInst *ASC = dyn_cast<AddrSpaceCastInst>(CurUser)) { + unsigned NewAS = NewV->getType()->getPointerAddressSpace(); + if (ASC->getDestAddressSpace() == NewAS) { + if (ASC->getType()->getPointerElementType() != + NewV->getType()->getPointerElementType()) { + NewV = CastInst::Create(Instruction::BitCast, NewV, + ASC->getType(), "", ASC); + } + ASC->replaceAllUsesWith(NewV); + DeadInstructions.push_back(ASC); + continue; + } + } + + // Otherwise, replaces the use with flat(NewV). + if (Instruction *I = dyn_cast<Instruction>(V)) { + BasicBlock::iterator InsertPos = std::next(I->getIterator()); + while (isa<PHINode>(InsertPos)) + ++InsertPos; + U.set(new AddrSpaceCastInst(NewV, V->getType(), "", &*InsertPos)); + } else { + U.set(ConstantExpr::getAddrSpaceCast(cast<Constant>(NewV), + V->getType())); + } + } + } + + if (V->use_empty()) { + if (Instruction *I = dyn_cast<Instruction>(V)) + DeadInstructions.push_back(I); + } + } + + for (Instruction *I : DeadInstructions) + RecursivelyDeleteTriviallyDeadInstructions(I); + + return true; +} + +FunctionPass *llvm::createInferAddressSpacesPass() { + return new InferAddressSpaces(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/JumpThreading.cpp b/contrib/llvm/lib/Transforms/Scalar/JumpThreading.cpp new file mode 100644 index 000000000000..141c9938bf8b --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/JumpThreading.cpp @@ -0,0 +1,2545 @@ +//===- JumpThreading.cpp - Thread control through conditional blocks ------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Jump Threading pass. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/JumpThreading.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/BlockFrequencyInfo.h" +#include "llvm/Analysis/BranchProbabilityInfo.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LazyValueInfo.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantRange.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/BlockFrequency.h" +#include "llvm/Support/BranchProbability.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/SSAUpdater.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <memory> +#include <utility> + +using namespace llvm; +using namespace jumpthreading; + +#define DEBUG_TYPE "jump-threading" + +STATISTIC(NumThreads, "Number of jumps threaded"); +STATISTIC(NumFolds, "Number of terminators folded"); +STATISTIC(NumDupes, "Number of branch blocks duplicated to eliminate phi"); + +static cl::opt<unsigned> +BBDuplicateThreshold("jump-threading-threshold", + cl::desc("Max block size to duplicate for jump threading"), + cl::init(6), cl::Hidden); + +static cl::opt<unsigned> +ImplicationSearchThreshold( + "jump-threading-implication-search-threshold", + cl::desc("The number of predecessors to search for a stronger " + "condition to use to thread over a weaker condition"), + cl::init(3), cl::Hidden); + +static cl::opt<bool> PrintLVIAfterJumpThreading( + "print-lvi-after-jump-threading", + cl::desc("Print the LazyValueInfo cache after JumpThreading"), cl::init(false), + cl::Hidden); + +namespace { + + /// This pass performs 'jump threading', which looks at blocks that have + /// multiple predecessors and multiple successors. If one or more of the + /// predecessors of the block can be proven to always jump to one of the + /// successors, we forward the edge from the predecessor to the successor by + /// duplicating the contents of this block. 
+  /// An example of when this can occur is code like this:
+  ///
+  ///   if () { ...
+  ///     X = 4;
+  ///   }
+  ///   if (X < 3) {
+  ///
+  /// In this case, the unconditional branch at the end of the first if can be
+  /// revectored to the false side of the second if.
+  class JumpThreading : public FunctionPass {
+    JumpThreadingPass Impl;
+
+  public:
+    static char ID; // Pass identification
+
+    JumpThreading(int T = -1) : FunctionPass(ID), Impl(T) {
+      initializeJumpThreadingPass(*PassRegistry::getPassRegistry());
+    }
+
+    bool runOnFunction(Function &F) override;
+
+    void getAnalysisUsage(AnalysisUsage &AU) const override {
+      if (PrintLVIAfterJumpThreading)
+        AU.addRequired<DominatorTreeWrapperPass>();
+      AU.addRequired<AAResultsWrapperPass>();
+      AU.addRequired<LazyValueInfoWrapperPass>();
+      AU.addPreserved<GlobalsAAWrapperPass>();
+      AU.addRequired<TargetLibraryInfoWrapperPass>();
+    }
+
+    void releaseMemory() override { Impl.releaseMemory(); }
+  };
+
+} // end anonymous namespace
+
+char JumpThreading::ID = 0;
+
+INITIALIZE_PASS_BEGIN(JumpThreading, "jump-threading",
+                      "Jump Threading", false, false)
+INITIALIZE_PASS_DEPENDENCY(LazyValueInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_END(JumpThreading, "jump-threading",
+                    "Jump Threading", false, false)
+
+// Public interface to the Jump Threading pass
+FunctionPass *llvm::createJumpThreadingPass(int Threshold) {
+  return new JumpThreading(Threshold);
+}
+
+JumpThreadingPass::JumpThreadingPass(int T) {
+  BBDupThreshold = (T == -1) ? BBDuplicateThreshold : unsigned(T);
+}
+
+// Update branch probability information according to conditional
+// branch probability. This is usually made possible for cloned branches
+// in inline instances by the context-specific profile in the caller.
+// For instance,
+//
+//  [Block PredBB]
+//  [Branch PredBr]
+//  if (t) {
+//     Block A;
+//  } else {
+//     Block B;
+//  }
+//
+//  [Block BB]
+//  cond = PN([true, %A], [..., %B]); // PHI node
+//  [Branch CondBr]
+//  if (cond) {
+//    ...  // P(cond == true) = 1%
+//  }
+//
+// Here we know that when block A is taken, cond must be true, which means
+//    P(cond == true | A) = 1
+//
+// Given that P(cond == true) = P(cond == true | A) * P(A) +
+//                              P(cond == true | B) * P(B)
+// we get:
+//    P(cond == true) = P(A) + P(cond == true | B) * P(B)
+//
+// which gives us:
+//    P(A) is less than P(cond == true), i.e.
+//    P(t == true) <= P(cond == true)
+//
+// In other words, if we know P(cond == true) is unlikely, we know
+// that P(t == true) is also unlikely.
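+//
+// A hedged numeric illustration (editorial; weights invented): if CondBr
+// carries !prof weights {true: 1, false: 99}, then P(cond == true) = 1%.
+// For the incoming edge from block A, whose PHI operand is the constant
+// true, BP = 1/100; since 1% < 50%, we may annotate PredBr so that its edge
+// leading towards A is taken with probability at most 1%.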
+// +static void updatePredecessorProfileMetadata(PHINode *PN, BasicBlock *BB) { + BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator()); + if (!CondBr) + return; + + BranchProbability BP; + uint64_t TrueWeight, FalseWeight; + if (!CondBr->extractProfMetadata(TrueWeight, FalseWeight)) + return; + + // Returns the outgoing edge of the dominating predecessor block + // that leads to the PhiNode's incoming block: + auto GetPredOutEdge = + [](BasicBlock *IncomingBB, + BasicBlock *PhiBB) -> std::pair<BasicBlock *, BasicBlock *> { + auto *PredBB = IncomingBB; + auto *SuccBB = PhiBB; + while (true) { + BranchInst *PredBr = dyn_cast<BranchInst>(PredBB->getTerminator()); + if (PredBr && PredBr->isConditional()) + return {PredBB, SuccBB}; + auto *SinglePredBB = PredBB->getSinglePredecessor(); + if (!SinglePredBB) + return {nullptr, nullptr}; + SuccBB = PredBB; + PredBB = SinglePredBB; + } + }; + + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + Value *PhiOpnd = PN->getIncomingValue(i); + ConstantInt *CI = dyn_cast<ConstantInt>(PhiOpnd); + + if (!CI || !CI->getType()->isIntegerTy(1)) + continue; + + BP = (CI->isOne() ? BranchProbability::getBranchProbability( + TrueWeight, TrueWeight + FalseWeight) + : BranchProbability::getBranchProbability( + FalseWeight, TrueWeight + FalseWeight)); + + auto PredOutEdge = GetPredOutEdge(PN->getIncomingBlock(i), BB); + if (!PredOutEdge.first) + return; + + BasicBlock *PredBB = PredOutEdge.first; + BranchInst *PredBr = cast<BranchInst>(PredBB->getTerminator()); + + uint64_t PredTrueWeight, PredFalseWeight; + // FIXME: We currently only set the profile data when it is missing. + // With PGO, this can be used to refine even existing profile data with + // context information. This needs to be done after more performance + // testing. + if (PredBr->extractProfMetadata(PredTrueWeight, PredFalseWeight)) + continue; + + // We can not infer anything useful when BP >= 50%, because BP is the + // upper bound probability value. + if (BP >= BranchProbability(50, 100)) + continue; + + SmallVector<uint32_t, 2> Weights; + if (PredBr->getSuccessor(0) == PredOutEdge.second) { + Weights.push_back(BP.getNumerator()); + Weights.push_back(BP.getCompl().getNumerator()); + } else { + Weights.push_back(BP.getCompl().getNumerator()); + Weights.push_back(BP.getNumerator()); + } + PredBr->setMetadata(LLVMContext::MD_prof, + MDBuilder(PredBr->getParent()->getContext()) + .createBranchWeights(Weights)); + } +} + +/// runOnFunction - Toplevel algorithm. 
+bool JumpThreading::runOnFunction(Function &F) {
+  if (skipFunction(F))
+    return false;
+  auto TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+  auto LVI = &getAnalysis<LazyValueInfoWrapperPass>().getLVI();
+  auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
+  std::unique_ptr<BlockFrequencyInfo> BFI;
+  std::unique_ptr<BranchProbabilityInfo> BPI;
+  bool HasProfileData = F.hasProfileData();
+  if (HasProfileData) {
+    LoopInfo LI{DominatorTree(F)};
+    BPI.reset(new BranchProbabilityInfo(F, LI, TLI));
+    BFI.reset(new BlockFrequencyInfo(F, *BPI, LI));
+  }
+
+  bool Changed = Impl.runImpl(F, TLI, LVI, AA, HasProfileData, std::move(BFI),
+                              std::move(BPI));
+  if (PrintLVIAfterJumpThreading) {
+    dbgs() << "LVI for function '" << F.getName() << "':\n";
+    LVI->printLVI(F, getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
+                  dbgs());
+  }
+  return Changed;
+}
+
+PreservedAnalyses JumpThreadingPass::run(Function &F,
+                                         FunctionAnalysisManager &AM) {
+  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+  auto &LVI = AM.getResult<LazyValueAnalysis>(F);
+  auto &AA = AM.getResult<AAManager>(F);
+
+  std::unique_ptr<BlockFrequencyInfo> BFI;
+  std::unique_ptr<BranchProbabilityInfo> BPI;
+  bool HasProfileData = F.hasProfileData();
+  if (HasProfileData) {
+    LoopInfo LI{DominatorTree(F)};
+    BPI.reset(new BranchProbabilityInfo(F, LI, &TLI));
+    BFI.reset(new BlockFrequencyInfo(F, *BPI, LI));
+  }
+
+  bool Changed = runImpl(F, &TLI, &LVI, &AA, HasProfileData, std::move(BFI),
+                         std::move(BPI));
+
+  if (!Changed)
+    return PreservedAnalyses::all();
+  PreservedAnalyses PA;
+  PA.preserve<GlobalsAA>();
+  return PA;
+}
+
+bool JumpThreadingPass::runImpl(Function &F, TargetLibraryInfo *TLI_,
+                                LazyValueInfo *LVI_, AliasAnalysis *AA_,
+                                bool HasProfileData_,
+                                std::unique_ptr<BlockFrequencyInfo> BFI_,
+                                std::unique_ptr<BranchProbabilityInfo> BPI_) {
+  DEBUG(dbgs() << "Jump threading on function '" << F.getName() << "'\n");
+  TLI = TLI_;
+  LVI = LVI_;
+  AA = AA_;
+  BFI.reset();
+  BPI.reset();
+  // When profile data is available, we need to update edge weights after
+  // successful jump threading, which requires both BPI and BFI being
+  // available.
+  HasProfileData = HasProfileData_;
+  auto *GuardDecl = F.getParent()->getFunction(
+      Intrinsic::getName(Intrinsic::experimental_guard));
+  HasGuards = GuardDecl && !GuardDecl->use_empty();
+  if (HasProfileData) {
+    BPI = std::move(BPI_);
+    BFI = std::move(BFI_);
+  }
+
+  // Remove unreachable blocks from the function, as they may result in an
+  // infinite loop. We do threading if we found something profitable. Jump
+  // threading a branch can create other opportunities. If these opportunities
+  // form a cycle, i.e. if any jump threading is undoing previous threading in
+  // the path, then we will loop forever. We take care of this issue by not
+  // jump threading for back edges. This works for normal cases but not for
+  // unreachable blocks, as they may have a cycle with no back edge.
+  bool EverChanged = false;
+  EverChanged |= removeUnreachableBlocks(F, LVI);
+
+  FindLoopHeaders(F);
+
+  bool Changed;
+  do {
+    Changed = false;
+    for (Function::iterator I = F.begin(), E = F.end(); I != E;) {
+      BasicBlock *BB = &*I;
+      // Thread all of the branches we can over this block.
+      while (ProcessBlock(BB))
+        Changed = true;
+
+      ++I;
+
+      // If the block is trivially dead, zap it. This eliminates the successor
+      // edges, which simplifies the CFG.
+      if (pred_empty(BB) &&
+          BB != &BB->getParent()->getEntryBlock()) {
+        DEBUG(dbgs() << "  JT: Deleting dead block '" << BB->getName()
+                     << "' with terminator: " << *BB->getTerminator() << '\n');
+        LoopHeaders.erase(BB);
+        LVI->eraseBlock(BB);
+        DeleteDeadBlock(BB);
+        Changed = true;
+        continue;
+      }
+
+      BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
+
+      // Can't thread an unconditional jump, but if the block is "almost
+      // empty", we can replace uses of it with uses of the successor and make
+      // this dead.
+      // We should not eliminate the loop header or latch either, because
+      // eliminating a loop header or latch might later prevent LoopSimplify
+      // from transforming nested loops into simplified form. We will rely on
+      // later passes in the backend to clean up empty blocks.
+      if (BI && BI->isUnconditional() &&
+          BB != &BB->getParent()->getEntryBlock() &&
+          // If the terminator is the only non-phi instruction, try to nuke it.
+          BB->getFirstNonPHIOrDbg()->isTerminator() && !LoopHeaders.count(BB) &&
+          !LoopHeaders.count(BI->getSuccessor(0))) {
+        // FIXME: It is always conservatively correct to drop the info
+        // for a block even if it doesn't get erased. This isn't totally
+        // awesome, but it allows us to use AssertingVH to prevent nasty
+        // dangling pointer issues within LazyValueInfo.
+        LVI->eraseBlock(BB);
+        if (TryToSimplifyUncondBranchFromEmptyBlock(BB))
+          Changed = true;
+      }
+    }
+    EverChanged |= Changed;
+  } while (Changed);
+
+  LoopHeaders.clear();
+  return EverChanged;
+}
+
+// Replace uses of Cond with ToVal when it is safe to do so. If all uses are
+// replaced, we can remove Cond. We cannot blindly replace all uses of Cond
+// because we may incorrectly replace uses when guards/assumes are uses of
+// `Cond` and we used the guards/assumes to reason about the `Cond` value
+// at the end of the block. RAUW unconditionally replaces all uses,
+// including the guards/assumes themselves and the uses before the
+// guard/assume.
+static void ReplaceFoldableUses(Instruction *Cond, Value *ToVal) {
+  assert(Cond->getType() == ToVal->getType());
+  auto *BB = Cond->getParent();
+  // We can unconditionally replace all uses in non-local blocks (i.e. uses
+  // strictly dominated by BB), since LVI information is true from the
+  // terminator of BB.
+  replaceNonLocalUsesWith(Cond, ToVal);
+  for (Instruction &I : reverse(*BB)) {
+    // Reached the Cond whose uses we are trying to replace, so there are no
+    // more uses.
+    if (&I == Cond)
+      break;
+    // We only replace uses in instructions that are guaranteed to reach the
+    // end of BB, where we know Cond is ToVal.
+    if (!isGuaranteedToTransferExecutionToSuccessor(&I))
+      break;
+    I.replaceUsesOfWith(Cond, ToVal);
+  }
+  if (Cond->use_empty() && !Cond->mayHaveSideEffects())
+    Cond->eraseFromParent();
+}
+
+/// Return the cost of duplicating a piece of this block, from the first
+/// non-phi up to (but not including) the StopAt instruction, in order to
+/// thread across it. Stop scanning the block when exceeding the threshold.
+/// If duplication is impossible, returns ~0U.
+static unsigned getJumpThreadDuplicationCost(BasicBlock *BB,
+                                             Instruction *StopAt,
+                                             unsigned Threshold) {
+  assert(StopAt->getParent() == BB && "Not an instruction from proper BB?");
+  // Ignore PHI nodes, these will be flattened when duplication happens.
+  BasicBlock::const_iterator I(BB->getFirstNonPHI());
+
+  // FIXME: THREADING will delete values that are just used to compute the
+  // branch, so they shouldn't count against the duplication cost.
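+  // A hedged worked example (editorial; instruction mix invented): a block
+  // containing two adds, one pointer-to-pointer bitcast, one llvm.dbg.value,
+  // and a switch terminator scores Size = 2 (the bitcast and the debug
+  // intrinsic are free), gets Bonus = 6 for the switch, and so reports a
+  // final cost of 0.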
+ + unsigned Bonus = 0; + if (BB->getTerminator() == StopAt) { + // Threading through a switch statement is particularly profitable. If this + // block ends in a switch, decrease its cost to make it more likely to + // happen. + if (isa<SwitchInst>(StopAt)) + Bonus = 6; + + // The same holds for indirect branches, but slightly more so. + if (isa<IndirectBrInst>(StopAt)) + Bonus = 8; + } + + // Bump the threshold up so the early exit from the loop doesn't skip the + // terminator-based Size adjustment at the end. + Threshold += Bonus; + + // Sum up the cost of each instruction until we get to the terminator. Don't + // include the terminator because the copy won't include it. + unsigned Size = 0; + for (; &*I != StopAt; ++I) { + + // Stop scanning the block if we've reached the threshold. + if (Size > Threshold) + return Size; + + // Debugger intrinsics don't incur code size. + if (isa<DbgInfoIntrinsic>(I)) continue; + + // If this is a pointer->pointer bitcast, it is free. + if (isa<BitCastInst>(I) && I->getType()->isPointerTy()) + continue; + + // Bail out if this instruction gives back a token type, it is not possible + // to duplicate it if it is used outside this BB. + if (I->getType()->isTokenTy() && I->isUsedOutsideOfBlock(BB)) + return ~0U; + + // All other instructions count for at least one unit. + ++Size; + + // Calls are more expensive. If they are non-intrinsic calls, we model them + // as having cost of 4. If they are a non-vector intrinsic, we model them + // as having cost of 2 total, and if they are a vector intrinsic, we model + // them as having cost 1. + if (const CallInst *CI = dyn_cast<CallInst>(I)) { + if (CI->cannotDuplicate() || CI->isConvergent()) + // Blocks with NoDuplicate are modelled as having infinite cost, so they + // are never duplicated. + return ~0U; + else if (!isa<IntrinsicInst>(CI)) + Size += 3; + else if (!CI->getType()->isVectorTy()) + Size += 1; + } + } + + return Size > Bonus ? Size - Bonus : 0; +} + +/// FindLoopHeaders - We do not want jump threading to turn proper loop +/// structures into irreducible loops. Doing this breaks up the loop nesting +/// hierarchy and pessimizes later transformations. To prevent this from +/// happening, we first have to find the loop headers. Here we approximate this +/// by finding targets of backedges in the CFG. +/// +/// Note that there definitely are cases when we want to allow threading of +/// edges across a loop header. For example, threading a jump from outside the +/// loop (the preheader) to an exit block of the loop is definitely profitable. +/// It is also almost always profitable to thread backedges from within the loop +/// to exit blocks, and is often profitable to thread backedges to other blocks +/// within the loop (forming a nested loop). This simple analysis is not rich +/// enough to track all of these properties and keep it up-to-date as the CFG +/// mutates, so we don't allow any of these transformations. +void JumpThreadingPass::FindLoopHeaders(Function &F) { + SmallVector<std::pair<const BasicBlock*,const BasicBlock*>, 32> Edges; + FindFunctionBackedges(F, Edges); + + for (const auto &Edge : Edges) + LoopHeaders.insert(Edge.second); +} + +/// getKnownConstant - Helper method to determine if we can thread over a +/// terminator with the given value as its condition, and if so what value to +/// use for that. What kind of value this is depends on whether we want an +/// integer or a block address, but an undef is always accepted. 
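+/// For example (illustrative, not from the original source): with
+/// WantInteger, 'i32 7' and 'i1 undef' are accepted but a blockaddress is
+/// not; with WantBlockAddress, 'blockaddress(@f, %bb)' is accepted (looking
+/// through pointer casts) but an integer constant is not.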
+/// Returns null if Val is null or not an appropriate constant.
+static Constant *getKnownConstant(Value *Val, ConstantPreference Preference) {
+  if (!Val)
+    return nullptr;
+
+  // Undef is "known" enough.
+  if (UndefValue *U = dyn_cast<UndefValue>(Val))
+    return U;
+
+  if (Preference == WantBlockAddress)
+    return dyn_cast<BlockAddress>(Val->stripPointerCasts());
+
+  return dyn_cast<ConstantInt>(Val);
+}
+
+/// ComputeValueKnownInPredecessors - Given a basic block BB and a value V, see
+/// if we can infer that the value is a known ConstantInt/BlockAddress or undef
+/// in any of our predecessors. If so, return the known list of value and pred
+/// BB in the result vector.
+///
+/// This returns true if there were any known values.
+bool JumpThreadingPass::ComputeValueKnownInPredecessors(
+    Value *V, BasicBlock *BB, PredValueInfo &Result,
+    ConstantPreference Preference, Instruction *CxtI) {
+  // This method walks up use-def chains recursively. Because of this, we
+  // could get into an infinite loop going around loops in the use-def chain.
+  // To prevent this, keep track of what (value, block) pairs we've already
+  // visited and terminate the search if we loop back to them.
+  if (!RecursionSet.insert(std::make_pair(V, BB)).second)
+    return false;
+
+  // An RAII helper to remove this pair from the recursion set once the
+  // recursion stack pops back out again.
+  RecursionSetRemover remover(RecursionSet, std::make_pair(V, BB));
+
+  // If V is a constant, then it is known in all predecessors.
+  if (Constant *KC = getKnownConstant(V, Preference)) {
+    for (BasicBlock *Pred : predecessors(BB))
+      Result.push_back(std::make_pair(KC, Pred));
+
+    return !Result.empty();
+  }
+
+  // If V is a non-instruction value, or an instruction in a different block,
+  // then it can't be derived from a PHI.
+  Instruction *I = dyn_cast<Instruction>(V);
+  if (!I || I->getParent() != BB) {
+
+    // Okay, if this is a live-in value, see if it has a known value at the end
+    // of any of our predecessors.
+    //
+    // FIXME: This should be an edge property, not a block end property.
+    // TODO: Per PR2563, we could infer value range information about a
+    // predecessor based on its terminator.
+    //
+    // FIXME: change this to use the more-rich 'getPredicateOnEdge' method if
+    // "I" is a non-local compare-with-a-constant instruction. This would be
+    // able to handle value inequalities better, for example if the compare is
+    // "X < 4" and "X < 3" is known true but "X < 4" itself is not available.
+    // Perhaps getConstantOnEdge should be smart enough to do this?
+
+    for (BasicBlock *P : predecessors(BB)) {
+      // If the value is known by LazyValueInfo to be a constant in a
+      // predecessor, use that information to try to thread this block.
+      Constant *PredCst = LVI->getConstantOnEdge(V, P, BB, CxtI);
+      if (Constant *KC = getKnownConstant(PredCst, Preference))
+        Result.push_back(std::make_pair(KC, P));
+    }
+
+    return !Result.empty();
+  }
+
+  // If I is a PHI node, then we know the incoming values for any constants.
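+  // Illustrative example (block and value names assumed): for
+  //   %p = phi i32 [ 7, %pred1 ], [ %x, %pred2 ]
+  // the pair (7, %pred1) is recorded directly, and LVI may still prove a
+  // constant for %x on the %pred2 edge.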
+  if (PHINode *PN = dyn_cast<PHINode>(I)) {
+    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+      Value *InVal = PN->getIncomingValue(i);
+      if (Constant *KC = getKnownConstant(InVal, Preference)) {
+        Result.push_back(std::make_pair(KC, PN->getIncomingBlock(i)));
+      } else {
+        Constant *CI = LVI->getConstantOnEdge(InVal,
+                                              PN->getIncomingBlock(i),
+                                              BB, CxtI);
+        if (Constant *KC = getKnownConstant(CI, Preference))
+          Result.push_back(std::make_pair(KC, PN->getIncomingBlock(i)));
+      }
+    }
+
+    return !Result.empty();
+  }
+
+  // Handle Cast instructions. Only look through a cast when the source
+  // operand is a PHI or Cmp and the source type is i1, to keep compilation
+  // time down.
+  if (CastInst *CI = dyn_cast<CastInst>(I)) {
+    Value *Source = CI->getOperand(0);
+    if (!Source->getType()->isIntegerTy(1))
+      return false;
+    if (!isa<PHINode>(Source) && !isa<CmpInst>(Source))
+      return false;
+    ComputeValueKnownInPredecessors(Source, BB, Result, Preference, CxtI);
+    if (Result.empty())
+      return false;
+
+    // Convert the known values.
+    for (auto &R : Result)
+      R.first = ConstantExpr::getCast(CI->getOpcode(), R.first, CI->getType());
+
+    return true;
+  }
+
+  // Handle some boolean conditions.
+  if (I->getType()->getPrimitiveSizeInBits() == 1) {
+    assert(Preference == WantInteger && "One-bit non-integer type?");
+    // X | true -> true
+    // X & false -> false
+    if (I->getOpcode() == Instruction::Or ||
+        I->getOpcode() == Instruction::And) {
+      PredValueInfoTy LHSVals, RHSVals;
+
+      ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals,
+                                      WantInteger, CxtI);
+      ComputeValueKnownInPredecessors(I->getOperand(1), BB, RHSVals,
+                                      WantInteger, CxtI);
+
+      if (LHSVals.empty() && RHSVals.empty())
+        return false;
+
+      ConstantInt *InterestingVal;
+      if (I->getOpcode() == Instruction::Or)
+        InterestingVal = ConstantInt::getTrue(I->getContext());
+      else
+        InterestingVal = ConstantInt::getFalse(I->getContext());
+
+      SmallPtrSet<BasicBlock*, 4> LHSKnownBBs;
+
+      // Scan for the sentinel. If we find an undef, force it to the
+      // interesting value: x|undef -> true and x&undef -> false.
+      for (const auto &LHSVal : LHSVals)
+        if (LHSVal.first == InterestingVal || isa<UndefValue>(LHSVal.first)) {
+          Result.emplace_back(InterestingVal, LHSVal.second);
+          LHSKnownBBs.insert(LHSVal.second);
+        }
+      for (const auto &RHSVal : RHSVals)
+        if (RHSVal.first == InterestingVal || isa<UndefValue>(RHSVal.first)) {
+          // If we already inferred a value for this block on the LHS, don't
+          // re-add it.
+          if (!LHSKnownBBs.count(RHSVal.second))
+            Result.emplace_back(InterestingVal, RHSVal.second);
+        }
+
+      return !Result.empty();
+    }
+
+    // Handle the NOT form of XOR.
+    if (I->getOpcode() == Instruction::Xor &&
+        isa<ConstantInt>(I->getOperand(1)) &&
+        cast<ConstantInt>(I->getOperand(1))->isOne()) {
+      ComputeValueKnownInPredecessors(I->getOperand(0), BB, Result,
+                                      WantInteger, CxtI);
+      if (Result.empty())
+        return false;
+
+      // Invert the known values.
+      for (auto &R : Result)
+        R.first = ConstantExpr::getNot(R.first);
+
+      return true;
+    }
+
+  // Try to simplify some other binary operator values.
+  } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) {
+    assert(Preference != WantBlockAddress
+           && "A binary operator creating a block address?");
+    if (ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1))) {
+      PredValueInfoTy LHSVals;
+      ComputeValueKnownInPredecessors(BO->getOperand(0), BB, LHSVals,
+                                      WantInteger, CxtI);
+
+      // Try to use constant folding to simplify the binary operator.
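+      // E.g. (illustrative): if BO is 'add i32 %x, 1' and %x is known to be
+      // 7 in some predecessor, the folded value recorded for that edge is 8.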
+      for (const auto &LHSVal : LHSVals) {
+        Constant *V = LHSVal.first;
+        Constant *Folded = ConstantExpr::get(BO->getOpcode(), V, CI);
+
+        if (Constant *KC = getKnownConstant(Folded, WantInteger))
+          Result.push_back(std::make_pair(KC, LHSVal.second));
+      }
+    }
+
+    return !Result.empty();
+  }
+
+  // Handle compare with phi operand, where the PHI is defined in this block.
+  if (CmpInst *Cmp = dyn_cast<CmpInst>(I)) {
+    assert(Preference == WantInteger && "Compares only produce integers");
+    Type *CmpType = Cmp->getType();
+    Value *CmpLHS = Cmp->getOperand(0);
+    Value *CmpRHS = Cmp->getOperand(1);
+    CmpInst::Predicate Pred = Cmp->getPredicate();
+
+    PHINode *PN = dyn_cast<PHINode>(CmpLHS);
+    if (PN && PN->getParent() == BB) {
+      const DataLayout &DL = PN->getModule()->getDataLayout();
+      // We can do this simplification if any comparisons fold to true or
+      // false. See if any do.
+      for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) {
+        BasicBlock *PredBB = PN->getIncomingBlock(i);
+        Value *LHS = PN->getIncomingValue(i);
+        Value *RHS = CmpRHS->DoPHITranslation(BB, PredBB);
+
+        Value *Res = SimplifyCmpInst(Pred, LHS, RHS, {DL});
+        if (!Res) {
+          if (!isa<Constant>(RHS))
+            continue;
+
+          LazyValueInfo::Tristate
+            ResT = LVI->getPredicateOnEdge(Pred, LHS,
+                                           cast<Constant>(RHS), PredBB, BB,
+                                           CxtI ? CxtI : Cmp);
+          if (ResT == LazyValueInfo::Unknown)
+            continue;
+          Res = ConstantInt::get(Type::getInt1Ty(LHS->getContext()), ResT);
+        }
+
+        if (Constant *KC = getKnownConstant(Res, WantInteger))
+          Result.push_back(std::make_pair(KC, PredBB));
+      }
+
+      return !Result.empty();
+    }
+
+    // If comparing a live-in value against a constant, see if we know the
+    // live-in value on any predecessors.
+    if (isa<Constant>(CmpRHS) && !CmpType->isVectorTy()) {
+      Constant *CmpConst = cast<Constant>(CmpRHS);
+
+      if (!isa<Instruction>(CmpLHS) ||
+          cast<Instruction>(CmpLHS)->getParent() != BB) {
+        for (BasicBlock *P : predecessors(BB)) {
+          // If the value is known by LazyValueInfo to be a constant in a
+          // predecessor, use that information to try to thread this block.
+          LazyValueInfo::Tristate Res =
+            LVI->getPredicateOnEdge(Pred, CmpLHS,
+                                    CmpConst, P, BB, CxtI ? CxtI : Cmp);
+          if (Res == LazyValueInfo::Unknown)
+            continue;
+
+          Constant *ResC = ConstantInt::get(CmpType, Res);
+          Result.push_back(std::make_pair(ResC, P));
+        }
+
+        return !Result.empty();
+      }
+
+      // InstCombine can fold some forms of constant range checks into
+      // (icmp (add (x, C1)), C2). See if we have such a thing with
+      // x as a live-in.
+      {
+        using namespace PatternMatch;
+
+        Value *AddLHS;
+        ConstantInt *AddConst;
+        if (isa<ConstantInt>(CmpConst) &&
+            match(CmpLHS, m_Add(m_Value(AddLHS), m_ConstantInt(AddConst)))) {
+          if (!isa<Instruction>(AddLHS) ||
+              cast<Instruction>(AddLHS)->getParent() != BB) {
+            for (BasicBlock *P : predecessors(BB)) {
+              // If the value is known by LazyValueInfo to be a ConstantRange
+              // in a predecessor, use that information to try to thread this
+              // block.
+              ConstantRange CR = LVI->getConstantRangeOnEdge(
+                  AddLHS, P, BB, CxtI ? CxtI : cast<Instruction>(CmpLHS));
+              // Propagate the range through the addition.
+              CR = CR.add(AddConst->getValue());
+
+              // Get the range where the compare returns true.
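+              // E.g. (illustrative): for 'icmp ult i32 %a, 10' the true
+              // region is [0, 10); makeExactICmpRegion returns exactly the
+              // LHS values for which the predicate holds against the
+              // constant RHS.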
+              ConstantRange CmpRange = ConstantRange::makeExactICmpRegion(
+                  Pred, cast<ConstantInt>(CmpConst)->getValue());
+
+              Constant *ResC;
+              if (CmpRange.contains(CR))
+                ResC = ConstantInt::getTrue(CmpType);
+              else if (CmpRange.inverse().contains(CR))
+                ResC = ConstantInt::getFalse(CmpType);
+              else
+                continue;
+
+              Result.push_back(std::make_pair(ResC, P));
+            }
+
+            return !Result.empty();
+          }
+        }
+      }
+
+      // Try to find a constant value for the LHS of a comparison,
+      // and evaluate it statically if we can.
+      PredValueInfoTy LHSVals;
+      ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals,
+                                      WantInteger, CxtI);
+
+      for (const auto &LHSVal : LHSVals) {
+        Constant *V = LHSVal.first;
+        Constant *Folded = ConstantExpr::getCompare(Pred, V, CmpConst);
+        if (Constant *KC = getKnownConstant(Folded, WantInteger))
+          Result.push_back(std::make_pair(KC, LHSVal.second));
+      }
+
+      return !Result.empty();
+    }
+  }
+
+  if (SelectInst *SI = dyn_cast<SelectInst>(I)) {
+    // Handle select instructions where at least one operand is a known
+    // constant and we can figure out the condition value for any predecessor
+    // block.
+    Constant *TrueVal = getKnownConstant(SI->getTrueValue(), Preference);
+    Constant *FalseVal = getKnownConstant(SI->getFalseValue(), Preference);
+    PredValueInfoTy Conds;
+    if ((TrueVal || FalseVal) &&
+        ComputeValueKnownInPredecessors(SI->getCondition(), BB, Conds,
+                                        WantInteger, CxtI)) {
+      for (auto &C : Conds) {
+        Constant *Cond = C.first;
+
+        // Figure out what value to use for the condition.
+        bool KnownCond;
+        if (ConstantInt *CI = dyn_cast<ConstantInt>(Cond)) {
+          // A known boolean.
+          KnownCond = CI->isOne();
+        } else {
+          assert(isa<UndefValue>(Cond) && "Unexpected condition value");
+          // Either operand will do, so be sure to pick the one that's a known
+          // constant.
+          // FIXME: Do this more cleverly if both values are known constants?
+          KnownCond = (TrueVal != nullptr);
+        }
+
+        // See if the select has a known constant value for this predecessor.
+        if (Constant *Val = KnownCond ? TrueVal : FalseVal)
+          Result.push_back(std::make_pair(Val, C.second));
+      }
+
+      return !Result.empty();
+    }
+  }
+
+  // If all else fails, see if LVI can figure out a constant value for us.
+  Constant *CI = LVI->getConstant(V, BB, CxtI);
+  if (Constant *KC = getKnownConstant(CI, Preference)) {
+    for (BasicBlock *Pred : predecessors(BB))
+      Result.push_back(std::make_pair(KC, Pred));
+  }
+
+  return !Result.empty();
+}
+
+/// GetBestDestForJumpOnUndef - If we determine that the specified block ends
+/// in an undefined jump, decide which block is best to revector to.
+///
+/// Since we can pick an arbitrary destination, we pick the successor with the
+/// fewest predecessors. This should reduce the in-degree of the others.
+static unsigned GetBestDestForJumpOnUndef(BasicBlock *BB) {
+  TerminatorInst *BBTerm = BB->getTerminator();
+  unsigned MinSucc = 0;
+  BasicBlock *TestBB = BBTerm->getSuccessor(MinSucc);
+  // Compute the successor with the minimum number of predecessors.
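+  // E.g. (illustrative): for 'br i1 undef, label %a, label %b' where %a has
+  // three predecessors and %b has one, we pick %b.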
+  unsigned MinNumPreds = std::distance(pred_begin(TestBB), pred_end(TestBB));
+  for (unsigned i = 1, e = BBTerm->getNumSuccessors(); i != e; ++i) {
+    TestBB = BBTerm->getSuccessor(i);
+    unsigned NumPreds = std::distance(pred_begin(TestBB), pred_end(TestBB));
+    if (NumPreds < MinNumPreds) {
+      MinSucc = i;
+      MinNumPreds = NumPreds;
+    }
+  }
+
+  return MinSucc;
+}
+
+static bool hasAddressTakenAndUsed(BasicBlock *BB) {
+  if (!BB->hasAddressTaken()) return false;
+
+  // If the block has its address taken, it may be a tree of dead constants
+  // hanging off of it. These shouldn't keep the block alive.
+  BlockAddress *BA = BlockAddress::get(BB);
+  BA->removeDeadConstantUsers();
+  return !BA->use_empty();
+}
+
+/// ProcessBlock - If there are any predecessors whose control can be threaded
+/// through to a successor, transform them now.
+bool JumpThreadingPass::ProcessBlock(BasicBlock *BB) {
+  // If the block is trivially dead, just return and let the caller nuke it.
+  // This simplifies other transformations.
+  if (pred_empty(BB) &&
+      BB != &BB->getParent()->getEntryBlock())
+    return false;
+
+  // If this block has a single predecessor, and if that pred has a single
+  // successor, merge the blocks. This encourages recursive jump threading
+  // because now the condition in this block can be threaded through
+  // predecessors of our predecessor block.
+  if (BasicBlock *SinglePred = BB->getSinglePredecessor()) {
+    const TerminatorInst *TI = SinglePred->getTerminator();
+    if (!TI->isExceptional() && TI->getNumSuccessors() == 1 &&
+        SinglePred != BB && !hasAddressTakenAndUsed(BB)) {
+      // If SinglePred was a loop header, BB becomes one.
+      if (LoopHeaders.erase(SinglePred))
+        LoopHeaders.insert(BB);
+
+      LVI->eraseBlock(SinglePred);
+      MergeBasicBlockIntoOnlyPred(BB);
+
+      // Now that BB is merged into SinglePred (i.e. SinglePred code followed
+      // by BB code within one basic block `BB`), we need to invalidate the
+      // LVI information associated with BB, because the LVI information need
+      // not be true for all of BB after the merge. For example,
+      // before the merge, LVI info and code is as follows:
+      // SinglePred: <LVI info1 for %p val>
+      // %y = use of %p
+      // call @exit() // need not transfer execution to successor.
+      // assume(%p) // from this point on %p is true
+      // br label %BB
+      // BB: <LVI info2 for %p val, i.e. %p is true>
+      // %x = use of %p
+      // br label exit
+      //
+      // Note that this LVI info for blocks BB and SinglePred is correct for
+      // %p (info2 and info1 respectively). After the merge and the deletion
+      // of LVI info1 for SinglePred, we have the following code:
+      // BB: <LVI info2 for %p val>
+      // %y = use of %p
+      // call @exit()
+      // assume(%p)
+      // %x = use of %p <-- LVI info2 is correct from here onwards.
+      // br label exit
+      // LVI info2 for BB is incorrect at the beginning of BB.
+
+      // Invalidate LVI information for BB if the LVI is not provably true for
+      // all of BB.
+      if (any_of(*BB, [](Instruction &I) {
+            return !isGuaranteedToTransferExecutionToSuccessor(&I);
+          }))
+        LVI->eraseBlock(BB);
+      return true;
+    }
+  }
+
+  if (TryToUnfoldSelectInCurrBB(BB))
+    return true;
+
+  // Look if we can propagate guards to predecessors.
+  if (HasGuards && ProcessGuards(BB))
+    return true;
+
+  // What kind of constant we're looking for.
+  ConstantPreference Preference = WantInteger;
+
+  // Look to see if the terminator is a conditional branch, switch or indirect
+  // branch; if not, we can't thread it.
+  Value *Condition;
+  Instruction *Terminator = BB->getTerminator();
+  if (BranchInst *BI = dyn_cast<BranchInst>(Terminator)) {
+    // Can't thread an unconditional jump.
+    if (BI->isUnconditional()) return false;
+    Condition = BI->getCondition();
+  } else if (SwitchInst *SI = dyn_cast<SwitchInst>(Terminator)) {
+    Condition = SI->getCondition();
+  } else if (IndirectBrInst *IB = dyn_cast<IndirectBrInst>(Terminator)) {
+    // Can't thread indirect branch with no successors.
+    if (IB->getNumSuccessors() == 0) return false;
+    Condition = IB->getAddress()->stripPointerCasts();
+    Preference = WantBlockAddress;
+  } else {
+    return false; // Must be an invoke.
+  }
+
+  // Run constant folding to see if we can reduce the condition to a simple
+  // constant.
+  if (Instruction *I = dyn_cast<Instruction>(Condition)) {
+    Value *SimpleVal =
+        ConstantFoldInstruction(I, BB->getModule()->getDataLayout(), TLI);
+    if (SimpleVal) {
+      I->replaceAllUsesWith(SimpleVal);
+      if (isInstructionTriviallyDead(I, TLI))
+        I->eraseFromParent();
+      Condition = SimpleVal;
+    }
+  }
+
+  // If the terminator is branching on an undef, we can pick any of the
+  // successors to branch to. Let GetBestDestForJumpOnUndef decide.
+  if (isa<UndefValue>(Condition)) {
+    unsigned BestSucc = GetBestDestForJumpOnUndef(BB);
+
+    // Fold the branch/switch.
+    TerminatorInst *BBTerm = BB->getTerminator();
+    for (unsigned i = 0, e = BBTerm->getNumSuccessors(); i != e; ++i) {
+      if (i == BestSucc) continue;
+      BBTerm->getSuccessor(i)->removePredecessor(BB, true);
+    }
+
+    DEBUG(dbgs() << "  In block '" << BB->getName()
+          << "' folding undef terminator: " << *BBTerm << '\n');
+    BranchInst::Create(BBTerm->getSuccessor(BestSucc), BBTerm);
+    BBTerm->eraseFromParent();
+    return true;
+  }
+
+  // If the terminator of this block is branching on a constant, simplify the
+  // terminator to an unconditional branch. This can occur due to threading in
+  // other blocks.
+  if (getKnownConstant(Condition, Preference)) {
+    DEBUG(dbgs() << "  In block '" << BB->getName()
+          << "' folding terminator: " << *BB->getTerminator() << '\n');
+    ++NumFolds;
+    ConstantFoldTerminator(BB, true);
+    return true;
+  }
+
+  Instruction *CondInst = dyn_cast<Instruction>(Condition);
+
+  // All the rest of our checks depend on the condition being an instruction.
+  if (!CondInst) {
+    // FIXME: Unify this with code below.
+    if (ProcessThreadableEdges(Condition, BB, Preference, Terminator))
+      return true;
+    return false;
+  }
+
+  if (CmpInst *CondCmp = dyn_cast<CmpInst>(CondInst)) {
+    // If we're branching on a conditional, LVI might be able to determine
+    // its value at the branch instruction. We only handle comparisons
+    // against a constant at this time.
+    // TODO: This should be extended to handle switches as well.
+    BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator());
+    Constant *CondConst = dyn_cast<Constant>(CondCmp->getOperand(1));
+    if (CondBr && CondConst) {
+      // We should have returned as soon as we turned a conditional branch
+      // into an unconditional one, because it's no longer interesting as far
+      // as jump threading is concerned.
+      assert(CondBr->isConditional() && "Threading on unconditional terminator");
+
+      LazyValueInfo::Tristate Ret =
+        LVI->getPredicateAt(CondCmp->getPredicate(), CondCmp->getOperand(0),
+                            CondConst, CondBr);
+      if (Ret != LazyValueInfo::Unknown) {
+        unsigned ToRemove = Ret == LazyValueInfo::True ? 1 : 0;
+        unsigned ToKeep = Ret == LazyValueInfo::True ?
0 : 1; + CondBr->getSuccessor(ToRemove)->removePredecessor(BB, true); + BranchInst::Create(CondBr->getSuccessor(ToKeep), CondBr); + CondBr->eraseFromParent(); + if (CondCmp->use_empty()) + CondCmp->eraseFromParent(); + // We can safely replace *some* uses of the CondInst if it has + // exactly one value as returned by LVI. RAUW is incorrect in the + // presence of guards and assumes, that have the `Cond` as the use. This + // is because we use the guards/assume to reason about the `Cond` value + // at the end of block, but RAUW unconditionally replaces all uses + // including the guards/assumes themselves and the uses before the + // guard/assume. + else if (CondCmp->getParent() == BB) { + auto *CI = Ret == LazyValueInfo::True ? + ConstantInt::getTrue(CondCmp->getType()) : + ConstantInt::getFalse(CondCmp->getType()); + ReplaceFoldableUses(CondCmp, CI); + } + return true; + } + + // We did not manage to simplify this branch, try to see whether + // CondCmp depends on a known phi-select pattern. + if (TryToUnfoldSelect(CondCmp, BB)) + return true; + } + } + + // Check for some cases that are worth simplifying. Right now we want to look + // for loads that are used by a switch or by the condition for the branch. If + // we see one, check to see if it's partially redundant. If so, insert a PHI + // which can then be used to thread the values. + Value *SimplifyValue = CondInst; + if (CmpInst *CondCmp = dyn_cast<CmpInst>(SimplifyValue)) + if (isa<Constant>(CondCmp->getOperand(1))) + SimplifyValue = CondCmp->getOperand(0); + + // TODO: There are other places where load PRE would be profitable, such as + // more complex comparisons. + if (LoadInst *LI = dyn_cast<LoadInst>(SimplifyValue)) + if (SimplifyPartiallyRedundantLoad(LI)) + return true; + + // Before threading, try to propagate profile data backwards: + if (PHINode *PN = dyn_cast<PHINode>(CondInst)) + if (PN->getParent() == BB && isa<BranchInst>(BB->getTerminator())) + updatePredecessorProfileMetadata(PN, BB); + + // Handle a variety of cases where we are branching on something derived from + // a PHI node in the current block. If we can prove that any predecessors + // compute a predictable value based on a PHI node, thread those predecessors. + if (ProcessThreadableEdges(CondInst, BB, Preference, Terminator)) + return true; + + // If this is an otherwise-unfoldable branch on a phi node in the current + // block, see if we can simplify. + if (PHINode *PN = dyn_cast<PHINode>(CondInst)) + if (PN->getParent() == BB && isa<BranchInst>(BB->getTerminator())) + return ProcessBranchOnPHI(PN); + + // If this is an otherwise-unfoldable branch on a XOR, see if we can simplify. + if (CondInst->getOpcode() == Instruction::Xor && + CondInst->getParent() == BB && isa<BranchInst>(BB->getTerminator())) + return ProcessBranchOnXOR(cast<BinaryOperator>(CondInst)); + + // Search for a stronger dominating condition that can be used to simplify a + // conditional branch leaving BB. 
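+  // E.g. (illustrative): if every path into BB passes through the true edge
+  // of a branch on 'icmp sgt i32 %x, 10', then a branch in BB on
+  // 'icmp sgt i32 %x, 5' is implied true and can be folded.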
+  if (ProcessImpliedCondition(BB))
+    return true;
+
+  return false;
+}
+
+bool JumpThreadingPass::ProcessImpliedCondition(BasicBlock *BB) {
+  auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
+  if (!BI || !BI->isConditional())
+    return false;
+
+  Value *Cond = BI->getCondition();
+  BasicBlock *CurrentBB = BB;
+  BasicBlock *CurrentPred = BB->getSinglePredecessor();
+  unsigned Iter = 0;
+
+  auto &DL = BB->getModule()->getDataLayout();
+
+  while (CurrentPred && Iter++ < ImplicationSearchThreshold) {
+    auto *PBI = dyn_cast<BranchInst>(CurrentPred->getTerminator());
+    if (!PBI || !PBI->isConditional())
+      return false;
+    if (PBI->getSuccessor(0) != CurrentBB && PBI->getSuccessor(1) != CurrentBB)
+      return false;
+
+    bool CondIsTrue = PBI->getSuccessor(0) == CurrentBB;
+    Optional<bool> Implication =
+        isImpliedCondition(PBI->getCondition(), Cond, DL, CondIsTrue);
+    if (Implication) {
+      BI->getSuccessor(*Implication ? 1 : 0)->removePredecessor(BB);
+      BranchInst::Create(BI->getSuccessor(*Implication ? 0 : 1), BI);
+      BI->eraseFromParent();
+      return true;
+    }
+    CurrentBB = CurrentPred;
+    CurrentPred = CurrentBB->getSinglePredecessor();
+  }
+
+  return false;
+}
+
+/// Return true if Op is an instruction defined in the given block.
+static bool isOpDefinedInBlock(Value *Op, BasicBlock *BB) {
+  if (Instruction *OpInst = dyn_cast<Instruction>(Op))
+    if (OpInst->getParent() == BB)
+      return true;
+  return false;
+}
+
+/// SimplifyPartiallyRedundantLoad - If LI is an obviously partially redundant
+/// load instruction, eliminate it by replacing it with a PHI node. This is an
+/// important optimization that encourages jump threading, and needs to be run
+/// interlaced with other jump threading tasks.
+bool JumpThreadingPass::SimplifyPartiallyRedundantLoad(LoadInst *LI) {
+  // Don't hack volatile and ordered loads.
+  if (!LI->isUnordered()) return false;
+
+  // If the load is defined in a block with exactly one predecessor, it can't
+  // be partially redundant.
+  BasicBlock *LoadBB = LI->getParent();
+  if (LoadBB->getSinglePredecessor())
+    return false;
+
+  // If the load is defined in an EH pad, it can't be partially redundant,
+  // because the edges between the invoke and the EH pad cannot have other
+  // instructions between them.
+  if (LoadBB->isEHPad())
+    return false;
+
+  Value *LoadedPtr = LI->getOperand(0);
+
+  // If the loaded operand is defined in the LoadBB and it's not a phi,
+  // it can't be available in predecessors.
+  if (isOpDefinedInBlock(LoadedPtr, LoadBB) && !isa<PHINode>(LoadedPtr))
+    return false;
+
+  // Scan a few instructions up from the load, to see if it is obviously live
+  // at the entry to its block.
+  BasicBlock::iterator BBIt(LI);
+  bool IsLoadCSE;
+  if (Value *AvailableVal = FindAvailableLoadedValue(
+          LI, LoadBB, BBIt, DefMaxInstsToScan, AA, &IsLoadCSE)) {
+    // If the value of the load is locally available within the block, just use
+    // it. This frequently occurs for reg2mem'd allocas.
+
+    if (IsLoadCSE) {
+      LoadInst *NLI = cast<LoadInst>(AvailableVal);
+      combineMetadataForCSE(NLI, LI);
+    }
+
+    // If the returned value is the load itself, replace with an undef. This
+    // can only happen in dead loops.
+    if (AvailableVal == LI) AvailableVal = UndefValue::get(LI->getType());
+    if (AvailableVal->getType() != LI->getType())
+      AvailableVal =
+          CastInst::CreateBitOrPointerCast(AvailableVal, LI->getType(), "", LI);
+    LI->replaceAllUsesWith(AvailableVal);
+    LI->eraseFromParent();
+    return true;
+  }
+
+  // Otherwise, if we scanned the whole block and got to the top of the block,
+  // we know the block is locally transparent to the load. If not, something
+  // might clobber its value.
+  if (BBIt != LoadBB->begin())
+    return false;
+
+  // If all of the loads and stores that feed the value have the same AA tags,
+  // then we can propagate them onto any newly inserted loads.
+  AAMDNodes AATags;
+  LI->getAAMetadata(AATags);
+
+  SmallPtrSet<BasicBlock*, 8> PredsScanned;
+
+  using AvailablePredsTy = SmallVector<std::pair<BasicBlock *, Value *>, 8>;
+
+  AvailablePredsTy AvailablePreds;
+  BasicBlock *OneUnavailablePred = nullptr;
+  SmallVector<LoadInst*, 8> CSELoads;
+
+  // If we got here, the loaded value is transparent through to the start of
+  // the block. Check to see if it is available in any of the predecessor
+  // blocks.
+  for (BasicBlock *PredBB : predecessors(LoadBB)) {
+    // If we already scanned this predecessor, skip it.
+    if (!PredsScanned.insert(PredBB).second)
+      continue;
+
+    BBIt = PredBB->end();
+    unsigned NumScanedInst = 0;
+    Value *PredAvailable = nullptr;
+    // NOTE: We don't CSE a load that is volatile or anything stronger than
+    // unordered; that should have been checked when we entered the function.
+    assert(LI->isUnordered() && "Attempting to CSE volatile or atomic loads");
+    // If this is a load on a phi pointer, phi-translate it and search
+    // for available load/store to the pointer in predecessors.
+    Value *Ptr = LoadedPtr->DoPHITranslation(LoadBB, PredBB);
+    PredAvailable = FindAvailablePtrLoadStore(
+        Ptr, LI->getType(), LI->isAtomic(), PredBB, BBIt, DefMaxInstsToScan,
+        AA, &IsLoadCSE, &NumScanedInst);
+
+    // If PredBB has a single predecessor, continue scanning through the
+    // single predecessor.
+    BasicBlock *SinglePredBB = PredBB;
+    while (!PredAvailable && SinglePredBB && BBIt == SinglePredBB->begin() &&
+           NumScanedInst < DefMaxInstsToScan) {
+      SinglePredBB = SinglePredBB->getSinglePredecessor();
+      if (SinglePredBB) {
+        BBIt = SinglePredBB->end();
+        PredAvailable = FindAvailablePtrLoadStore(
+            Ptr, LI->getType(), LI->isAtomic(), SinglePredBB, BBIt,
+            (DefMaxInstsToScan - NumScanedInst), AA, &IsLoadCSE,
+            &NumScanedInst);
+      }
+    }
+
+    if (!PredAvailable) {
+      OneUnavailablePred = PredBB;
+      continue;
+    }
+
+    if (IsLoadCSE)
+      CSELoads.push_back(cast<LoadInst>(PredAvailable));
+
+    // If so, this load is partially redundant. Remember this info so that we
+    // can create a PHI node.
+    AvailablePreds.push_back(std::make_pair(PredBB, PredAvailable));
+  }
+
+  // If the loaded value isn't available in any predecessor, it isn't
+  // partially redundant.
+  if (AvailablePreds.empty()) return false;
+
+  // Okay, the loaded value is available in at least one (and maybe all!)
+  // predecessors. If the value is unavailable in more than one unique
+  // predecessor, we want to insert a merge block for those common
+  // predecessors. This ensures that we only have to insert one reload, thus
+  // not increasing code size.
+  BasicBlock *UnavailablePred = nullptr;
+
+  // If the value is unavailable in one of the predecessors, we will end up
+  // inserting a new instruction into them. It is only valid if all the
+  // instructions before LI are guaranteed to transfer execution to their
+  // successors, or if LI is safe to speculate.
+  // TODO: If this logic becomes more complex, and we will perform PRE
+  // insertion farther than to a predecessor, we need to reuse the code from
+  // GVN's PRE. It requires dominator tree analysis, so for this simple case
+  // it is overkill.
+  if (PredsScanned.size() != AvailablePreds.size() &&
+      !isSafeToSpeculativelyExecute(LI))
+    for (auto I = LoadBB->begin(); &*I != LI; ++I)
+      if (!isGuaranteedToTransferExecutionToSuccessor(&*I))
+        return false;
+
+  // If there is exactly one predecessor where the value is unavailable, the
+  // already computed 'OneUnavailablePred' block is it. If it ends in an
+  // unconditional branch, we know that it isn't a critical edge.
+  if (PredsScanned.size() == AvailablePreds.size()+1 &&
+      OneUnavailablePred->getTerminator()->getNumSuccessors() == 1) {
+    UnavailablePred = OneUnavailablePred;
+  } else if (PredsScanned.size() != AvailablePreds.size()) {
+    // Otherwise, we had multiple unavailable predecessors or we had a
+    // critical edge from the one.
+    SmallVector<BasicBlock*, 8> PredsToSplit;
+    SmallPtrSet<BasicBlock*, 8> AvailablePredSet;
+
+    for (const auto &AvailablePred : AvailablePreds)
+      AvailablePredSet.insert(AvailablePred.first);
+
+    // Add all the unavailable predecessors to the PredsToSplit list.
+    for (BasicBlock *P : predecessors(LoadBB)) {
+      // If the predecessor is an indirect goto, we can't split the edge.
+      if (isa<IndirectBrInst>(P->getTerminator()))
+        return false;
+
+      if (!AvailablePredSet.count(P))
+        PredsToSplit.push_back(P);
+    }
+
+    // Split them out to their own block.
+    UnavailablePred = SplitBlockPreds(LoadBB, PredsToSplit, "thread-pre-split");
+  }
+
+  // If the value isn't available in all predecessors, then there will be
+  // exactly one where it isn't available. Insert a load on that edge and add
+  // it to the AvailablePreds list.
+  if (UnavailablePred) {
+    assert(UnavailablePred->getTerminator()->getNumSuccessors() == 1 &&
+           "Can't handle critical edge here!");
+    LoadInst *NewVal = new LoadInst(
+        LoadedPtr->DoPHITranslation(LoadBB, UnavailablePred),
+        LI->getName() + ".pr", false, LI->getAlignment(), LI->getOrdering(),
+        LI->getSyncScopeID(), UnavailablePred->getTerminator());
+    NewVal->setDebugLoc(LI->getDebugLoc());
+    if (AATags)
+      NewVal->setAAMetadata(AATags);
+
+    AvailablePreds.push_back(std::make_pair(UnavailablePred, NewVal));
+  }
+
+  // Now we know that each predecessor of this block has a value in
+  // AvailablePreds, sort them for efficient access as we're walking the preds.
+  array_pod_sort(AvailablePreds.begin(), AvailablePreds.end());
+
+  // Create a PHI node at the start of the block for the PRE'd load value.
+  pred_iterator PB = pred_begin(LoadBB), PE = pred_end(LoadBB);
+  PHINode *PN = PHINode::Create(LI->getType(), std::distance(PB, PE), "",
+                                &LoadBB->front());
+  PN->takeName(LI);
+  PN->setDebugLoc(LI->getDebugLoc());
+
+  // Insert new entries into the PHI for each predecessor. A single block may
+  // have multiple entries here.
+  for (pred_iterator PI = PB; PI != PE; ++PI) {
+    BasicBlock *P = *PI;
+    AvailablePredsTy::iterator I =
+        std::lower_bound(AvailablePreds.begin(), AvailablePreds.end(),
+                         std::make_pair(P, (Value*)nullptr));
+
+    assert(I != AvailablePreds.end() && I->first == P &&
+           "Didn't find entry for predecessor!");
+
+    // If we have an available predecessor but it requires casting, insert the
+    // cast in the predecessor and use the cast.
Note that we have to update the + // AvailablePreds vector as we go so that all of the PHI entries for this + // predecessor use the same bitcast. + Value *&PredV = I->second; + if (PredV->getType() != LI->getType()) + PredV = CastInst::CreateBitOrPointerCast(PredV, LI->getType(), "", + P->getTerminator()); + + PN->addIncoming(PredV, I->first); + } + + for (LoadInst *PredLI : CSELoads) { + combineMetadataForCSE(PredLI, LI); + } + + LI->replaceAllUsesWith(PN); + LI->eraseFromParent(); + + return true; +} + +/// FindMostPopularDest - The specified list contains multiple possible +/// threadable destinations. Pick the one that occurs the most frequently in +/// the list. +static BasicBlock * +FindMostPopularDest(BasicBlock *BB, + const SmallVectorImpl<std::pair<BasicBlock *, + BasicBlock *>> &PredToDestList) { + assert(!PredToDestList.empty()); + + // Determine popularity. If there are multiple possible destinations, we + // explicitly choose to ignore 'undef' destinations. We prefer to thread + // blocks with known and real destinations to threading undef. We'll handle + // them later if interesting. + DenseMap<BasicBlock*, unsigned> DestPopularity; + for (const auto &PredToDest : PredToDestList) + if (PredToDest.second) + DestPopularity[PredToDest.second]++; + + // Find the most popular dest. + DenseMap<BasicBlock*, unsigned>::iterator DPI = DestPopularity.begin(); + BasicBlock *MostPopularDest = DPI->first; + unsigned Popularity = DPI->second; + SmallVector<BasicBlock*, 4> SamePopularity; + + for (++DPI; DPI != DestPopularity.end(); ++DPI) { + // If the popularity of this entry isn't higher than the popularity we've + // seen so far, ignore it. + if (DPI->second < Popularity) + ; // ignore. + else if (DPI->second == Popularity) { + // If it is the same as what we've seen so far, keep track of it. + SamePopularity.push_back(DPI->first); + } else { + // If it is more popular, remember it. + SamePopularity.clear(); + MostPopularDest = DPI->first; + Popularity = DPI->second; + } + } + + // Okay, now we know the most popular destination. If there is more than one + // destination, we need to determine one. This is arbitrary, but we need + // to make a deterministic decision. Pick the first one that appears in the + // successor list. + if (!SamePopularity.empty()) { + SamePopularity.push_back(MostPopularDest); + TerminatorInst *TI = BB->getTerminator(); + for (unsigned i = 0; ; ++i) { + assert(i != TI->getNumSuccessors() && "Didn't find any successor!"); + + if (!is_contained(SamePopularity, TI->getSuccessor(i))) + continue; + + MostPopularDest = TI->getSuccessor(i); + break; + } + } + + // Okay, we have finally picked the most popular destination. + return MostPopularDest; +} + +bool JumpThreadingPass::ProcessThreadableEdges(Value *Cond, BasicBlock *BB, + ConstantPreference Preference, + Instruction *CxtI) { + // If threading this would thread across a loop header, don't even try to + // thread the edge. + if (LoopHeaders.count(BB)) + return false; + + PredValueInfoTy PredValues; + if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues, Preference, CxtI)) + return false; + + assert(!PredValues.empty() && + "ComputeValueKnownInPredecessors returned true with no values"); + + DEBUG(dbgs() << "IN BB: " << *BB; + for (const auto &PredValue : PredValues) { + dbgs() << " BB '" << BB->getName() << "': FOUND condition = " + << *PredValue.first + << " for pred '" << PredValue.second->getName() << "'.\n"; + }); + + // Decide what we want to thread through. 
Convert our list of known values to
+  // a list of known destinations for each pred. This also discards duplicate
+  // predecessors and keeps track of the undefined inputs (which are
+  // represented as a null dest in the PredToDestList).
+  SmallPtrSet<BasicBlock*, 16> SeenPreds;
+  SmallVector<std::pair<BasicBlock*, BasicBlock*>, 16> PredToDestList;
+
+  BasicBlock *OnlyDest = nullptr;
+  BasicBlock *MultipleDestSentinel = (BasicBlock*)(intptr_t)~0ULL;
+  Constant *OnlyVal = nullptr;
+  Constant *MultipleVal = (Constant *)(intptr_t)~0ULL;
+
+  unsigned PredWithKnownDest = 0;
+  for (const auto &PredValue : PredValues) {
+    BasicBlock *Pred = PredValue.second;
+    if (!SeenPreds.insert(Pred).second)
+      continue; // Duplicate predecessor entry.
+
+    Constant *Val = PredValue.first;
+
+    BasicBlock *DestBB;
+    if (isa<UndefValue>(Val))
+      DestBB = nullptr;
+    else if (BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator())) {
+      assert(isa<ConstantInt>(Val) && "Expecting a constant integer");
+      DestBB = BI->getSuccessor(cast<ConstantInt>(Val)->isZero());
+    } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB->getTerminator())) {
+      assert(isa<ConstantInt>(Val) && "Expecting a constant integer");
+      DestBB = SI->findCaseValue(cast<ConstantInt>(Val))->getCaseSuccessor();
+    } else {
+      assert(isa<IndirectBrInst>(BB->getTerminator())
+             && "Unexpected terminator");
+      assert(isa<BlockAddress>(Val) && "Expecting a constant blockaddress");
+      DestBB = cast<BlockAddress>(Val)->getBasicBlock();
+    }
+
+    // If we have exactly one destination, remember it for efficiency below.
+    if (PredToDestList.empty()) {
+      OnlyDest = DestBB;
+      OnlyVal = Val;
+    } else {
+      if (OnlyDest != DestBB)
+        OnlyDest = MultipleDestSentinel;
+      // It's possible we have the same destination but a different value,
+      // e.g. the default case in a switchinst.
+      if (Val != OnlyVal)
+        OnlyVal = MultipleVal;
+    }
+
+    // We know where this predecessor is going.
+    ++PredWithKnownDest;
+
+    // If the predecessor ends with an indirect goto, we can't change its
+    // destination.
+    if (isa<IndirectBrInst>(Pred->getTerminator()))
+      continue;
+
+    PredToDestList.push_back(std::make_pair(Pred, DestBB));
+  }
+
+  // If all edges were unthreadable, we fail.
+  if (PredToDestList.empty())
+    return false;
+
+  // If all the predecessors go to a single known successor, we want to fold,
+  // not thread. By doing so, we do not need to duplicate the current block,
+  // and we do not miss opportunities in case we don't/can't duplicate.
+  if (OnlyDest && OnlyDest != MultipleDestSentinel) {
+    if (PredWithKnownDest ==
+        (size_t)std::distance(pred_begin(BB), pred_end(BB))) {
+      bool SeenFirstBranchToOnlyDest = false;
+      for (BasicBlock *SuccBB : successors(BB)) {
+        if (SuccBB == OnlyDest && !SeenFirstBranchToOnlyDest)
+          SeenFirstBranchToOnlyDest = true; // Don't modify the first branch.
+        else
+          SuccBB->removePredecessor(BB, true); // This is unreachable successor.
+      }
+
+      // Finally update the terminator.
+      TerminatorInst *Term = BB->getTerminator();
+      BranchInst::Create(OnlyDest, Term);
+      Term->eraseFromParent();
+
+      // If the condition is now dead due to the removal of the old terminator,
+      // erase it.
+      if (auto *CondInst = dyn_cast<Instruction>(Cond)) {
+        if (CondInst->use_empty() && !CondInst->mayHaveSideEffects())
+          CondInst->eraseFromParent();
+        // We can safely replace *some* uses of the CondInst if it has
+        // exactly one value as returned by LVI. RAUW is incorrect in the
+        // presence of guards and assumes, which have `Cond` as a use.
This + // is because we use the guards/assume to reason about the `Cond` value + // at the end of block, but RAUW unconditionally replaces all uses + // including the guards/assumes themselves and the uses before the + // guard/assume. + else if (OnlyVal && OnlyVal != MultipleVal && + CondInst->getParent() == BB) + ReplaceFoldableUses(CondInst, OnlyVal); + } + return true; + } + } + + // Determine which is the most common successor. If we have many inputs and + // this block is a switch, we want to start by threading the batch that goes + // to the most popular destination first. If we only know about one + // threadable destination (the common case) we can avoid this. + BasicBlock *MostPopularDest = OnlyDest; + + if (MostPopularDest == MultipleDestSentinel) + MostPopularDest = FindMostPopularDest(BB, PredToDestList); + + // Now that we know what the most popular destination is, factor all + // predecessors that will jump to it into a single predecessor. + SmallVector<BasicBlock*, 16> PredsToFactor; + for (const auto &PredToDest : PredToDestList) + if (PredToDest.second == MostPopularDest) { + BasicBlock *Pred = PredToDest.first; + + // This predecessor may be a switch or something else that has multiple + // edges to the block. Factor each of these edges by listing them + // according to # occurrences in PredsToFactor. + for (BasicBlock *Succ : successors(Pred)) + if (Succ == BB) + PredsToFactor.push_back(Pred); + } + + // If the threadable edges are branching on an undefined value, we get to pick + // the destination that these predecessors should get to. + if (!MostPopularDest) + MostPopularDest = BB->getTerminator()-> + getSuccessor(GetBestDestForJumpOnUndef(BB)); + + // Ok, try to thread it! + return ThreadEdge(BB, PredsToFactor, MostPopularDest); +} + +/// ProcessBranchOnPHI - We have an otherwise unthreadable conditional branch on +/// a PHI node in the current block. See if there are any simplifications we +/// can do based on inputs to the phi node. +bool JumpThreadingPass::ProcessBranchOnPHI(PHINode *PN) { + BasicBlock *BB = PN->getParent(); + + // TODO: We could make use of this to do it once for blocks with common PHI + // values. + SmallVector<BasicBlock*, 1> PredBBs; + PredBBs.resize(1); + + // If any of the predecessor blocks end in an unconditional branch, we can + // *duplicate* the conditional branch into that block in order to further + // encourage jump threading and to eliminate cases where we have branch on a + // phi of an icmp (branch on icmp is much better). + for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i) { + BasicBlock *PredBB = PN->getIncomingBlock(i); + if (BranchInst *PredBr = dyn_cast<BranchInst>(PredBB->getTerminator())) + if (PredBr->isUnconditional()) { + PredBBs[0] = PredBB; + // Try to duplicate BB into PredBB. + if (DuplicateCondBranchOnPHIIntoPred(BB, PredBBs)) + return true; + } + } + + return false; +} + +/// ProcessBranchOnXOR - We have an otherwise unthreadable conditional branch on +/// a xor instruction in the current block. See if there are any +/// simplifications we can do based on inputs to the xor. +bool JumpThreadingPass::ProcessBranchOnXOR(BinaryOperator *BO) { + BasicBlock *BB = BO->getParent(); + + // If either the LHS or RHS of the xor is a constant, don't do this + // optimization. + if (isa<ConstantInt>(BO->getOperand(0)) || + isa<ConstantInt>(BO->getOperand(1))) + return false; + + // If the first instruction in BB isn't a phi, we won't be able to infer + // anything special about any particular predecessor. 
+ if (!isa<PHINode>(BB->front())) + return false; + + // If this BB is a landing pad, we won't be able to split the edge into it. + if (BB->isEHPad()) + return false; + + // If we have a xor as the branch input to this block, and we know that the + // LHS or RHS of the xor in any predecessor is true/false, then we can clone + // the condition into the predecessor and fix that value to true, saving some + // logical ops on that path and encouraging other paths to simplify. + // + // This copies something like this: + // + // BB: + // %X = phi i1 [1], [%X'] + // %Y = icmp eq i32 %A, %B + // %Z = xor i1 %X, %Y + // br i1 %Z, ... + // + // Into: + // BB': + // %Y = icmp ne i32 %A, %B + // br i1 %Y, ... + + PredValueInfoTy XorOpValues; + bool isLHS = true; + if (!ComputeValueKnownInPredecessors(BO->getOperand(0), BB, XorOpValues, + WantInteger, BO)) { + assert(XorOpValues.empty()); + if (!ComputeValueKnownInPredecessors(BO->getOperand(1), BB, XorOpValues, + WantInteger, BO)) + return false; + isLHS = false; + } + + assert(!XorOpValues.empty() && + "ComputeValueKnownInPredecessors returned true with no values"); + + // Scan the information to see which is most popular: true or false. The + // predecessors can be of the set true, false, or undef. + unsigned NumTrue = 0, NumFalse = 0; + for (const auto &XorOpValue : XorOpValues) { + if (isa<UndefValue>(XorOpValue.first)) + // Ignore undefs for the count. + continue; + if (cast<ConstantInt>(XorOpValue.first)->isZero()) + ++NumFalse; + else + ++NumTrue; + } + + // Determine which value to split on, true, false, or undef if neither. + ConstantInt *SplitVal = nullptr; + if (NumTrue > NumFalse) + SplitVal = ConstantInt::getTrue(BB->getContext()); + else if (NumTrue != 0 || NumFalse != 0) + SplitVal = ConstantInt::getFalse(BB->getContext()); + + // Collect all of the blocks that this can be folded into so that we can + // factor this once and clone it once. + SmallVector<BasicBlock*, 8> BlocksToFoldInto; + for (const auto &XorOpValue : XorOpValues) { + if (XorOpValue.first != SplitVal && !isa<UndefValue>(XorOpValue.first)) + continue; + + BlocksToFoldInto.push_back(XorOpValue.second); + } + + // If we inferred a value for all of the predecessors, then duplication won't + // help us. However, we can just replace the LHS or RHS with the constant. + if (BlocksToFoldInto.size() == + cast<PHINode>(BB->front()).getNumIncomingValues()) { + if (!SplitVal) { + // If all preds provide undef, just nuke the xor, because it is undef too. + BO->replaceAllUsesWith(UndefValue::get(BO->getType())); + BO->eraseFromParent(); + } else if (SplitVal->isZero()) { + // If all preds provide 0, replace the xor with the other input. + BO->replaceAllUsesWith(BO->getOperand(isLHS)); + BO->eraseFromParent(); + } else { + // If all preds provide 1, set the computed value to 1. + BO->setOperand(!isLHS, SplitVal); + } + + return true; + } + + // Try to duplicate BB into PredBB. + return DuplicateCondBranchOnPHIIntoPred(BB, BlocksToFoldInto); +} + +/// AddPHINodeEntriesForMappedBlock - We're adding 'NewPred' as a new +/// predecessor to the PHIBB block. If it has PHI nodes, add entries for +/// NewPred using the entries from OldPred (suitably mapped). +static void AddPHINodeEntriesForMappedBlock(BasicBlock *PHIBB, + BasicBlock *OldPred, + BasicBlock *NewPred, + DenseMap<Instruction*, Value*> &ValueMap) { + for (PHINode &PN : PHIBB->phis()) { + // Ok, we have a PHI node. Figure out what the incoming value was for the + // DestBlock. 
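+    // E.g. (illustrative): for 'phi i32 [ %v, %OldPred ]', IV is %v; if %v
+    // was cloned while threading, the lookup below substitutes the clone.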
+ Value *IV = PN.getIncomingValueForBlock(OldPred); + + // Remap the value if necessary. + if (Instruction *Inst = dyn_cast<Instruction>(IV)) { + DenseMap<Instruction*, Value*>::iterator I = ValueMap.find(Inst); + if (I != ValueMap.end()) + IV = I->second; + } + + PN.addIncoming(IV, NewPred); + } +} + +/// ThreadEdge - We have decided that it is safe and profitable to factor the +/// blocks in PredBBs to one predecessor, then thread an edge from it to SuccBB +/// across BB. Transform the IR to reflect this change. +bool JumpThreadingPass::ThreadEdge(BasicBlock *BB, + const SmallVectorImpl<BasicBlock *> &PredBBs, + BasicBlock *SuccBB) { + // If threading to the same block as we come from, we would infinite loop. + if (SuccBB == BB) { + DEBUG(dbgs() << " Not threading across BB '" << BB->getName() + << "' - would thread to self!\n"); + return false; + } + + // If threading this would thread across a loop header, don't thread the edge. + // See the comments above FindLoopHeaders for justifications and caveats. + if (LoopHeaders.count(BB) || LoopHeaders.count(SuccBB)) { + DEBUG({ + bool BBIsHeader = LoopHeaders.count(BB); + bool SuccIsHeader = LoopHeaders.count(SuccBB); + dbgs() << " Not threading across " + << (BBIsHeader ? "loop header BB '" : "block BB '") << BB->getName() + << "' to dest " << (SuccIsHeader ? "loop header BB '" : "block BB '") + << SuccBB->getName() << "' - it might create an irreducible loop!\n"; + }); + return false; + } + + unsigned JumpThreadCost = + getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold); + if (JumpThreadCost > BBDupThreshold) { + DEBUG(dbgs() << " Not threading BB '" << BB->getName() + << "' - Cost is too high: " << JumpThreadCost << "\n"); + return false; + } + + // And finally, do it! Start by factoring the predecessors if needed. + BasicBlock *PredBB; + if (PredBBs.size() == 1) + PredBB = PredBBs[0]; + else { + DEBUG(dbgs() << " Factoring out " << PredBBs.size() + << " common predecessors.\n"); + PredBB = SplitBlockPreds(BB, PredBBs, ".thr_comm"); + } + + // And finally, do it! + DEBUG(dbgs() << " Threading edge from '" << PredBB->getName() << "' to '" + << SuccBB->getName() << "' with cost: " << JumpThreadCost + << ", across block:\n " + << *BB << "\n"); + + LVI->threadEdge(PredBB, BB, SuccBB); + + // We are going to have to map operands from the original BB block to the new + // copy of the block 'NewBB'. If there are PHI nodes in BB, evaluate them to + // account for entry from PredBB. + DenseMap<Instruction*, Value*> ValueMapping; + + BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), + BB->getName()+".thread", + BB->getParent(), BB); + NewBB->moveAfter(PredBB); + + // Set the block frequency of NewBB. + if (HasProfileData) { + auto NewBBFreq = + BFI->getBlockFreq(PredBB) * BPI->getEdgeProbability(PredBB, BB); + BFI->setBlockFreq(NewBB, NewBBFreq.getFrequency()); + } + + BasicBlock::iterator BI = BB->begin(); + for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI) + ValueMapping[PN] = PN->getIncomingValueForBlock(PredBB); + + // Clone the non-phi instructions of BB into NewBB, keeping track of the + // mapping and using it to remap operands in the cloned instructions. + for (; !isa<TerminatorInst>(BI); ++BI) { + Instruction *New = BI->clone(); + New->setName(BI->getName()); + NewBB->getInstList().push_back(New); + ValueMapping[&*BI] = New; + + // Remap operands to patch up intra-block references. 
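+    // E.g. (illustrative): if '%a = add i32 %x, 1' was cloned as %a.thread,
+    // a cloned user '%c = icmp eq i32 %a, 0' is rewritten to use %a.thread.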
+    for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i)
+      if (Instruction *Inst = dyn_cast<Instruction>(New->getOperand(i))) {
+        DenseMap<Instruction*, Value*>::iterator I = ValueMapping.find(Inst);
+        if (I != ValueMapping.end())
+          New->setOperand(i, I->second);
+      }
+  }
+
+  // We didn't copy the terminator from BB over to NewBB, because there is now
+  // an unconditional jump to SuccBB. Insert the unconditional jump.
+  BranchInst *NewBI = BranchInst::Create(SuccBB, NewBB);
+  NewBI->setDebugLoc(BB->getTerminator()->getDebugLoc());
+
+  // Check to see if SuccBB has PHI nodes. If so, we need to add entries to
+  // the PHI nodes for NewBB now.
+  AddPHINodeEntriesForMappedBlock(SuccBB, BB, NewBB, ValueMapping);
+
+  // If there were values defined in BB that are used outside the block, then
+  // we now have to update all uses of the value to use either the original
+  // value, the cloned value, or some PHI derived value. This can require
+  // arbitrary PHI insertion, which we are prepared to do; clean these up now.
+  SSAUpdater SSAUpdate;
+  SmallVector<Use*, 16> UsesToRename;
+  for (Instruction &I : *BB) {
+    // Scan all uses of this instruction to see if it is used outside of its
+    // block, and if so, record them in UsesToRename.
+    for (Use &U : I.uses()) {
+      Instruction *User = cast<Instruction>(U.getUser());
+      if (PHINode *UserPN = dyn_cast<PHINode>(User)) {
+        if (UserPN->getIncomingBlock(U) == BB)
+          continue;
+      } else if (User->getParent() == BB)
+        continue;
+
+      UsesToRename.push_back(&U);
+    }
+
+    // If there are no uses outside the block, we're done with this
+    // instruction.
+    if (UsesToRename.empty())
+      continue;
+
+    DEBUG(dbgs() << "JT: Renaming non-local uses of: " << I << "\n");
+
+    // We found a use of I outside of BB. Rename all uses of I that are
+    // outside its block to be uses of the appropriate PHI node etc. See
+    // ValuesInBlocks with the two values we know.
+    SSAUpdate.Initialize(I.getType(), I.getName());
+    SSAUpdate.AddAvailableValue(BB, &I);
+    SSAUpdate.AddAvailableValue(NewBB, ValueMapping[&I]);
+
+    while (!UsesToRename.empty())
+      SSAUpdate.RewriteUse(*UsesToRename.pop_back_val());
+    DEBUG(dbgs() << "\n");
+  }
+
+  // Ok, NewBB is good to go. Update the terminator of PredBB to jump to
+  // NewBB instead of BB. This eliminates predecessors from BB, which requires
+  // us to simplify any PHI nodes in BB.
+  TerminatorInst *PredTerm = PredBB->getTerminator();
+  for (unsigned i = 0, e = PredTerm->getNumSuccessors(); i != e; ++i)
+    if (PredTerm->getSuccessor(i) == BB) {
+      BB->removePredecessor(PredBB, true);
+      PredTerm->setSuccessor(i, NewBB);
+    }
+
+  // At this point, the IR is fully up to date and consistent. Do a quick scan
+  // over the new instructions and zap any that are constants or dead. This
+  // frequently happens because of phi translation.
+  SimplifyInstructionsInBlock(NewBB, TLI);
+
+  // Update the edge weight from BB to SuccBB, which should be less than
+  // before.
+  UpdateBlockFreqAndEdgeWeight(PredBB, BB, NewBB, SuccBB);
+
+  // Threaded an edge!
+  ++NumThreads;
+  return true;
+}
+
+/// Create a new basic block that will be the predecessor of BB and successor
+/// of all blocks in Preds. When profile data is available, update the
+/// frequency of this new block.
+BasicBlock *JumpThreadingPass::SplitBlockPreds(BasicBlock *BB,
+                                               ArrayRef<BasicBlock *> Preds,
+                                               const char *Suffix) {
+  // Collect the frequencies of all predecessors of BB, which will be used to
+  // update the edge weight on BB->SuccBB.
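+  // That is, Freq(PredBB) = sum over all Pred in Preds of
+  //   Freq(Pred) * Prob(Pred -> BB),
+  // which is exactly what the loop below accumulates.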
+ BlockFrequency PredBBFreq(0); + if (HasProfileData) + for (auto Pred : Preds) + PredBBFreq += BFI->getBlockFreq(Pred) * BPI->getEdgeProbability(Pred, BB); + + BasicBlock *PredBB = SplitBlockPredecessors(BB, Preds, Suffix); + + // Set the block frequency of the newly created PredBB, which is the sum of + // frequencies of Preds. + if (HasProfileData) + BFI->setBlockFreq(PredBB, PredBBFreq.getFrequency()); + return PredBB; +} + +bool JumpThreadingPass::doesBlockHaveProfileData(BasicBlock *BB) { + const TerminatorInst *TI = BB->getTerminator(); + assert(TI->getNumSuccessors() > 1 && "not a split"); + + MDNode *WeightsNode = TI->getMetadata(LLVMContext::MD_prof); + if (!WeightsNode) + return false; + + MDString *MDName = cast<MDString>(WeightsNode->getOperand(0)); + if (MDName->getString() != "branch_weights") + return false; + + // Ensure there are weights for all of the successors. Note that the first + // operand to the metadata node is a name, not a weight. + return WeightsNode->getNumOperands() == TI->getNumSuccessors() + 1; +} + +/// Update the block frequency of BB and branch weight and the metadata on the +/// edge BB->SuccBB. This is done by scaling the weight of BB->SuccBB by 1 - +/// Freq(PredBB->BB) / Freq(BB->SuccBB). +void JumpThreadingPass::UpdateBlockFreqAndEdgeWeight(BasicBlock *PredBB, + BasicBlock *BB, + BasicBlock *NewBB, + BasicBlock *SuccBB) { + if (!HasProfileData) + return; + + assert(BFI && BPI && "BFI & BPI should have been created here"); + + // As the edge from PredBB to BB is deleted, we have to update the block + // frequency of BB. + auto BBOrigFreq = BFI->getBlockFreq(BB); + auto NewBBFreq = BFI->getBlockFreq(NewBB); + auto BB2SuccBBFreq = BBOrigFreq * BPI->getEdgeProbability(BB, SuccBB); + auto BBNewFreq = BBOrigFreq - NewBBFreq; + BFI->setBlockFreq(BB, BBNewFreq.getFrequency()); + + // Collect updated outgoing edges' frequencies from BB and use them to update + // edge probabilities. + SmallVector<uint64_t, 4> BBSuccFreq; + for (BasicBlock *Succ : successors(BB)) { + auto SuccFreq = (Succ == SuccBB) + ? BB2SuccBBFreq - NewBBFreq + : BBOrigFreq * BPI->getEdgeProbability(BB, Succ); + BBSuccFreq.push_back(SuccFreq.getFrequency()); + } + + uint64_t MaxBBSuccFreq = + *std::max_element(BBSuccFreq.begin(), BBSuccFreq.end()); + + SmallVector<BranchProbability, 4> BBSuccProbs; + if (MaxBBSuccFreq == 0) + BBSuccProbs.assign(BBSuccFreq.size(), + {1, static_cast<unsigned>(BBSuccFreq.size())}); + else { + for (uint64_t Freq : BBSuccFreq) + BBSuccProbs.push_back( + BranchProbability::getBranchProbability(Freq, MaxBBSuccFreq)); + // Normalize edge probabilities so that they sum up to one. + BranchProbability::normalizeProbabilities(BBSuccProbs.begin(), + BBSuccProbs.end()); + } + + // Update edge probabilities in BPI. + for (int I = 0, E = BBSuccProbs.size(); I < E; I++) + BPI->setEdgeProbability(BB, I, BBSuccProbs[I]); + + // Update the profile metadata as well. + // + // Don't do this if the profile of the transformed blocks was statically + // estimated. (This could occur despite the function having an entry + // frequency in completely cold parts of the CFG.) + // + // In this case we don't want to suggest to subsequent passes that the + // calculated weights are fully consistent. 
+  // Consider this graph:
+  //
+  //                 check_1
+  //             50% /       |
+  //             eq_1        | 50%
+  //                 \       |
+  //                 check_2
+  //             50% /       |
+  //             eq_2        | 50%
+  //                 \       |
+  //                 check_3
+  //             50% /       |
+  //             eq_3        | 50%
+  //                 \       |
+  //
+  // Assuming the blocks check_* all compare the same value against 1, 2 and 3,
+  // the overall probabilities are inconsistent; the total probability that the
+  // value is either 1, 2 or 3 is 150%.
+  //
+  // As a consequence, if we thread eq_1 -> check_2 to check_3, check_2->check_3
+  // becomes 0%.  This is even worse if the edge whose probability becomes 0% is
+  // the loop exit edge.  Then based solely on static estimation we would assume
+  // the loop was extremely hot.
+  //
+  // FIXME: Fix this locally as well so that BPI and BFI remain consistent.
+  // We shouldn't make edges extremely likely or unlikely based solely on
+  // static estimation.
+  if (BBSuccProbs.size() >= 2 && doesBlockHaveProfileData(BB)) {
+    SmallVector<uint32_t, 4> Weights;
+    for (auto Prob : BBSuccProbs)
+      Weights.push_back(Prob.getNumerator());
+
+    auto TI = BB->getTerminator();
+    TI->setMetadata(
+        LLVMContext::MD_prof,
+        MDBuilder(TI->getParent()->getContext()).createBranchWeights(Weights));
+  }
+}
+
+/// DuplicateCondBranchOnPHIIntoPred - PredBB contains an unconditional branch
+/// to BB which contains an i1 PHI node and a conditional branch on that PHI.
+/// If we can duplicate the contents of BB up into PredBB, do so now; this
+/// improves the odds that the branch will be on an analyzable instruction like
+/// a compare.
+bool JumpThreadingPass::DuplicateCondBranchOnPHIIntoPred(
+    BasicBlock *BB, const SmallVectorImpl<BasicBlock *> &PredBBs) {
+  assert(!PredBBs.empty() && "Can't handle an empty set");
+
+  // If BB is a loop header, then duplicating this block outside the loop would
+  // cause us to transform this into an irreducible loop, so don't do this.
+  // See the comments above FindLoopHeaders for justifications and caveats.
+  if (LoopHeaders.count(BB)) {
+    DEBUG(dbgs() << "  Not duplicating loop header '" << BB->getName()
+                 << "' into predecessor block '" << PredBBs[0]->getName()
+                 << "' - it might create an irreducible loop!\n");
+    return false;
+  }
+
+  unsigned DuplicationCost =
+      getJumpThreadDuplicationCost(BB, BB->getTerminator(), BBDupThreshold);
+  if (DuplicationCost > BBDupThreshold) {
+    DEBUG(dbgs() << "  Not duplicating BB '" << BB->getName()
+                 << "' - Cost is too high: " << DuplicationCost << "\n");
+    return false;
+  }
+
+  // And finally, do it!  Start by factoring the predecessors if needed.
+  BasicBlock *PredBB;
+  if (PredBBs.size() == 1)
+    PredBB = PredBBs[0];
+  else {
+    DEBUG(dbgs() << "  Factoring out " << PredBBs.size()
+                 << " common predecessors.\n");
+    PredBB = SplitBlockPreds(BB, PredBBs, ".thr_comm");
+  }
+
+  // Okay, we decided to do this!  Clone all the instructions in BB onto the end
+  // of PredBB.
+  DEBUG(dbgs() << "  Duplicating block '" << BB->getName() << "' into end of '"
+               << PredBB->getName() << "' to eliminate branch on phi.  Cost: "
+               << DuplicationCost << " block is:" << *BB << "\n");
+
+  // Unless PredBB ends with an unconditional branch, split the edge so that we
+  // can just clone the bits from BB into the end of the new PredBB.
+  BranchInst *OldPredBranch = dyn_cast<BranchInst>(PredBB->getTerminator());
+
+  if (!OldPredBranch || !OldPredBranch->isUnconditional()) {
+    PredBB = SplitEdge(PredBB, BB);
+    OldPredBranch = cast<BranchInst>(PredBB->getTerminator());
+  }
+
+  // We are going to have to map operands from the original BB block into the
+  // PredBB block.  Evaluate PHI nodes in BB.
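+  // A sketch with hypothetical IR: given
+  //   BB: %phi = phi i1 [ true, %PredBB ], [ %x, %Other ]
+  // we record ValueMapping[%phi] = true for the PredBB copy, so cloned users
+  // of %phi can frequently be constant folded by SimplifyInstruction below.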
+  DenseMap<Instruction*, Value*> ValueMapping;
+
+  BasicBlock::iterator BI = BB->begin();
+  for (; PHINode *PN = dyn_cast<PHINode>(BI); ++BI)
+    ValueMapping[PN] = PN->getIncomingValueForBlock(PredBB);
+  // Clone the non-phi instructions of BB into PredBB, keeping track of the
+  // mapping and using it to remap operands in the cloned instructions.
+  for (; BI != BB->end(); ++BI) {
+    Instruction *New = BI->clone();
+
+    // Remap operands to patch up intra-block references.
+    for (unsigned i = 0, e = New->getNumOperands(); i != e; ++i)
+      if (Instruction *Inst = dyn_cast<Instruction>(New->getOperand(i))) {
+        DenseMap<Instruction*, Value*>::iterator I = ValueMapping.find(Inst);
+        if (I != ValueMapping.end())
+          New->setOperand(i, I->second);
+      }
+
+    // If this instruction can be simplified after the operands are updated,
+    // just use the simplified value instead.  This frequently happens due to
+    // phi translation.
+    if (Value *IV = SimplifyInstruction(
+            New,
+            {BB->getModule()->getDataLayout(), TLI, nullptr, nullptr, New})) {
+      ValueMapping[&*BI] = IV;
+      if (!New->mayHaveSideEffects()) {
+        New->deleteValue();
+        New = nullptr;
+      }
+    } else {
+      ValueMapping[&*BI] = New;
+    }
+    if (New) {
+      // Otherwise, insert the new instruction into the block.
+      New->setName(BI->getName());
+      PredBB->getInstList().insert(OldPredBranch->getIterator(), New);
+    }
+  }
+
+  // Check to see if the targets of the branch had PHI nodes.  If so, we need to
+  // add entries to the PHI nodes for the branch from PredBB now.
+  BranchInst *BBBranch = cast<BranchInst>(BB->getTerminator());
+  AddPHINodeEntriesForMappedBlock(BBBranch->getSuccessor(0), BB, PredBB,
+                                  ValueMapping);
+  AddPHINodeEntriesForMappedBlock(BBBranch->getSuccessor(1), BB, PredBB,
+                                  ValueMapping);
+
+  // If there were values defined in BB that are used outside the block, then we
+  // now have to update all uses of the value to use either the original value,
+  // the cloned value, or some PHI derived value.  This can require arbitrary
+  // PHI insertion, which we are prepared to do; clean these up now.
+  SSAUpdater SSAUpdate;
+  SmallVector<Use*, 16> UsesToRename;
+  for (Instruction &I : *BB) {
+    // Scan all uses of this instruction to see if it is used outside of its
+    // block, and if so, record them in UsesToRename.
+    for (Use &U : I.uses()) {
+      Instruction *User = cast<Instruction>(U.getUser());
+      if (PHINode *UserPN = dyn_cast<PHINode>(User)) {
+        if (UserPN->getIncomingBlock(U) == BB)
+          continue;
+      } else if (User->getParent() == BB)
+        continue;
+
+      UsesToRename.push_back(&U);
+    }
+
+    // If there are no uses outside the block, we're done with this instruction.
+    if (UsesToRename.empty())
+      continue;
+
+    DEBUG(dbgs() << "JT: Renaming non-local uses of: " << I << "\n");
+
+    // We found a use of I outside of BB.  Rename all uses of I that are outside
+    // its block to be uses of the appropriate PHI node etc., seeding the
+    // SSAUpdater with the two values we know: the original in BB and its clone
+    // in PredBB.
+    SSAUpdate.Initialize(I.getType(), I.getName());
+    SSAUpdate.AddAvailableValue(BB, &I);
+    SSAUpdate.AddAvailableValue(PredBB, ValueMapping[&I]);
+
+    while (!UsesToRename.empty())
+      SSAUpdate.RewriteUse(*UsesToRename.pop_back_val());
+    DEBUG(dbgs() << "\n");
+  }
+
+  // PredBB no longer jumps to BB, so remove entries in the PHI node for the
+  // edge that we nuked.
+  BB->removePredecessor(PredBB, true);
+
+  // Remove the unconditional branch at the end of the PredBB block.
+  OldPredBranch->eraseFromParent();
+
+  ++NumDupes;
+  return true;
+}
+
+/// TryToUnfoldSelect - Look for blocks of the form
+/// bb1:
+///   %a = select
+///   br bb2
+///
+/// bb2:
+///   %p = phi [%a, %bb1] ...
+///   %c = icmp %p
+///   br i1 %c
+///
+/// And expand the select into a branch structure if one of its arms allows %c
+/// to be folded.  This later enables threading from bb1 over bb2.
+bool JumpThreadingPass::TryToUnfoldSelect(CmpInst *CondCmp, BasicBlock *BB) {
+  BranchInst *CondBr = dyn_cast<BranchInst>(BB->getTerminator());
+  PHINode *CondLHS = dyn_cast<PHINode>(CondCmp->getOperand(0));
+  Constant *CondRHS = cast<Constant>(CondCmp->getOperand(1));
+
+  if (!CondBr || !CondBr->isConditional() || !CondLHS ||
+      CondLHS->getParent() != BB)
+    return false;
+
+  for (unsigned I = 0, E = CondLHS->getNumIncomingValues(); I != E; ++I) {
+    BasicBlock *Pred = CondLHS->getIncomingBlock(I);
+    SelectInst *SI = dyn_cast<SelectInst>(CondLHS->getIncomingValue(I));
+
+    // Check whether one of the incoming values is a select in the
+    // corresponding predecessor.
+    if (!SI || SI->getParent() != Pred || !SI->hasOneUse())
+      continue;
+
+    BranchInst *PredTerm = dyn_cast<BranchInst>(Pred->getTerminator());
+    if (!PredTerm || !PredTerm->isUnconditional())
+      continue;
+
+    // Now check if one of the select values would allow us to constant fold the
+    // terminator in BB.  We don't do the transform if both sides fold; those
+    // cases will be threaded in any case.
+    LazyValueInfo::Tristate LHSFolds =
+        LVI->getPredicateOnEdge(CondCmp->getPredicate(), SI->getOperand(1),
+                                CondRHS, Pred, BB, CondCmp);
+    LazyValueInfo::Tristate RHSFolds =
+        LVI->getPredicateOnEdge(CondCmp->getPredicate(), SI->getOperand(2),
+                                CondRHS, Pred, BB, CondCmp);
+    if ((LHSFolds != LazyValueInfo::Unknown ||
+         RHSFolds != LazyValueInfo::Unknown) &&
+        LHSFolds != RHSFolds) {
+      // Expand the select.
+      //
+      // Pred --
+      //  |    v
+      //  |  NewBB
+      //  |    |
+      //  |-----
+      //  v
+      // BB
+      BasicBlock *NewBB = BasicBlock::Create(BB->getContext(), "select.unfold",
+                                             BB->getParent(), BB);
+      // Move the unconditional branch to NewBB.
+      PredTerm->removeFromParent();
+      NewBB->getInstList().insert(NewBB->end(), PredTerm);
+      // Create a conditional branch and update PHI nodes.
+      BranchInst::Create(NewBB, BB, SI->getCondition(), Pred);
+      CondLHS->setIncomingValue(I, SI->getFalseValue());
+      CondLHS->addIncoming(SI->getTrueValue(), NewBB);
+      // The select is now dead.
+      SI->eraseFromParent();
+
+      // Update any other PHI nodes in BB.
+      for (BasicBlock::iterator BI = BB->begin();
+           PHINode *Phi = dyn_cast<PHINode>(BI); ++BI)
+        if (Phi != CondLHS)
+          Phi->addIncoming(Phi->getIncomingValueForBlock(Pred), NewBB);
+      return true;
+    }
+  }
+  return false;
+}
+
+/// TryToUnfoldSelectInCurrBB - Look for PHI/Select or PHI/CMP/Select in the
+/// same BB in the form
+/// bb:
+///   %p = phi [false, %bb1], [true, %bb2], [false, %bb3], [true, %bb4], ...
+///   %s = select %p, trueval, falseval
+///
+/// or
+///
+/// bb:
+///   %p = phi [0, %bb1], [1, %bb2], [0, %bb3], [1, %bb4], ...
+///   %c = cmp %p, 0
+///   %s = select %c, trueval, falseval
+///
+/// And expand the select into a branch structure.  This later enables
+/// jump-threading over bb in this pass.
+///
+/// Using an approach similar to SimplifyCFG::FoldCondBranchOnPHI(), unfold
+/// the select if the associated PHI has at least one constant.  If the
+/// unfolded select is not jump-threaded, it will be folded again in later
+/// optimizations.
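+///
+/// A rough sketch of the unfolded result (hypothetical block names): the
+/// first form becomes
+/// bb:
+///   %p = phi [false, %bb1], [true, %bb2], ...
+///   br i1 %p, label %bb.then, label %bb.tail
+/// bb.then:
+///   br label %bb.tail
+/// bb.tail:
+///   %s' = phi [trueval, %bb.then], [falseval, %bb]
+/// which jump threading can then exploit through the constant incoming values
+/// of %p.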
+bool JumpThreadingPass::TryToUnfoldSelectInCurrBB(BasicBlock *BB) {
+  // If threading this would thread across a loop header, don't thread the edge.
+  // See the comments above FindLoopHeaders for justifications and caveats.
+  if (LoopHeaders.count(BB))
+    return false;
+
+  for (BasicBlock::iterator BI = BB->begin();
+       PHINode *PN = dyn_cast<PHINode>(BI); ++BI) {
+    // Look for a Phi having at least one constant incoming value.
+    if (llvm::all_of(PN->incoming_values(),
+                     [](Value *V) { return !isa<ConstantInt>(V); }))
+      continue;
+
+    auto isUnfoldCandidate = [BB](SelectInst *SI, Value *V) {
+      // Check if SI is in BB and uses V as its condition.
+      if (SI->getParent() != BB)
+        return false;
+      Value *Cond = SI->getCondition();
+      return (Cond && Cond == V && Cond->getType()->isIntegerTy(1));
+    };
+
+    SelectInst *SI = nullptr;
+    for (Use &U : PN->uses()) {
+      if (ICmpInst *Cmp = dyn_cast<ICmpInst>(U.getUser())) {
+        // Look for an ICmp in BB that compares PN with a constant and is the
+        // condition of a Select.
+        if (Cmp->getParent() == BB && Cmp->hasOneUse() &&
+            isa<ConstantInt>(Cmp->getOperand(1 - U.getOperandNo())))
+          if (SelectInst *SelectI = dyn_cast<SelectInst>(Cmp->user_back()))
+            if (isUnfoldCandidate(SelectI, Cmp->use_begin()->get())) {
+              SI = SelectI;
+              break;
+            }
+      } else if (SelectInst *SelectI = dyn_cast<SelectInst>(U.getUser())) {
+        // Look for a Select in BB that uses PN as its condition.
+        if (isUnfoldCandidate(SelectI, U.get())) {
+          SI = SelectI;
+          break;
+        }
+      }
+    }
+
+    if (!SI)
+      continue;
+    // Expand the select.
+    TerminatorInst *Term =
+        SplitBlockAndInsertIfThen(SI->getCondition(), SI, false);
+    PHINode *NewPN = PHINode::Create(SI->getType(), 2, "", SI);
+    NewPN->addIncoming(SI->getTrueValue(), Term->getParent());
+    NewPN->addIncoming(SI->getFalseValue(), BB);
+    SI->replaceAllUsesWith(NewPN);
+    SI->eraseFromParent();
+    return true;
+  }
+  return false;
+}
+
+/// Try to propagate a guard from the current BB into one of its predecessors
+/// in case another branch of execution implies that the condition of this
+/// guard is always true.  Currently we only process the simplest case that
+/// looks like:
+///
+/// Start:
+///   %cond = ...
+///   br i1 %cond, label %T1, label %F1
+/// T1:
+///   br label %Merge
+/// F1:
+///   br label %Merge
+/// Merge:
+///   %condGuard = ...
+///   call void(i1, ...) @llvm.experimental.guard( i1 %condGuard )[ "deopt"() ]
+///
+/// And cond either implies condGuard or !condGuard.  In this case all the
+/// instructions before the guard can be duplicated in both branches, and the
+/// guard is then threaded to one of them.
+bool JumpThreadingPass::ProcessGuards(BasicBlock *BB) {
+  using namespace PatternMatch;
+
+  // We only want to deal with two predecessors.
+  BasicBlock *Pred1, *Pred2;
+  auto PI = pred_begin(BB), PE = pred_end(BB);
+  if (PI == PE)
+    return false;
+  Pred1 = *PI++;
+  if (PI == PE)
+    return false;
+  Pred2 = *PI++;
+  if (PI != PE)
+    return false;
+  if (Pred1 == Pred2)
+    return false;
+
+  // Try to thread one of the guards of the block.
+  // TODO: Look deeper than the immediate predecessors?
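+  // The shape being matched is a two-predecessor diamond (sketch):
+  //
+  //        Parent
+  //        /    \
+  //     Pred1  Pred2
+  //        \    /
+  //          BB   <-- contains the guard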
+  auto *Parent = Pred1->getSinglePredecessor();
+  if (!Parent || Parent != Pred2->getSinglePredecessor())
+    return false;
+
+  if (auto *BI = dyn_cast<BranchInst>(Parent->getTerminator()))
+    for (auto &I : *BB)
+      if (match(&I, m_Intrinsic<Intrinsic::experimental_guard>()))
+        if (ThreadGuard(BB, cast<IntrinsicInst>(&I), BI))
+          return true;
+
+  return false;
+}
+
+/// Try to propagate the guard from BB, which is the lower block of a diamond,
+/// to one of its branches, in case the diamond's condition implies the guard's
+/// condition.
+bool JumpThreadingPass::ThreadGuard(BasicBlock *BB, IntrinsicInst *Guard,
+                                    BranchInst *BI) {
+  assert(BI->getNumSuccessors() == 2 && "Wrong number of successors?");
+  assert(BI->isConditional() && "Unconditional branch has 2 successors?");
+  Value *GuardCond = Guard->getArgOperand(0);
+  Value *BranchCond = BI->getCondition();
+  BasicBlock *TrueDest = BI->getSuccessor(0);
+  BasicBlock *FalseDest = BI->getSuccessor(1);
+
+  auto &DL = BB->getModule()->getDataLayout();
+  bool TrueDestIsSafe = false;
+  bool FalseDestIsSafe = false;
+
+  // True dest is safe if BranchCond => GuardCond.
+  auto Impl = isImpliedCondition(BranchCond, GuardCond, DL);
+  if (Impl && *Impl)
+    TrueDestIsSafe = true;
+  else {
+    // False dest is safe if !BranchCond => GuardCond.
+    Impl = isImpliedCondition(BranchCond, GuardCond, DL, /* LHSIsTrue */ false);
+    if (Impl && *Impl)
+      FalseDestIsSafe = true;
+  }
+
+  if (!TrueDestIsSafe && !FalseDestIsSafe)
+    return false;
+
+  BasicBlock *UnguardedBlock = TrueDestIsSafe ? TrueDest : FalseDest;
+  BasicBlock *GuardedBlock = FalseDestIsSafe ? TrueDest : FalseDest;
+
+  ValueToValueMapTy UnguardedMapping, GuardedMapping;
+  Instruction *AfterGuard = Guard->getNextNode();
+  unsigned Cost = getJumpThreadDuplicationCost(BB, AfterGuard, BBDupThreshold);
+  if (Cost > BBDupThreshold)
+    return false;
+  // Duplicate all instructions before the guard and the guard itself to the
+  // branch where implication is not proved.
+  GuardedBlock = DuplicateInstructionsInSplitBetween(
+      BB, GuardedBlock, AfterGuard, GuardedMapping);
+  assert(GuardedBlock && "Could not create the guarded block?");
+  // Duplicate all instructions before the guard in the unguarded branch.
+  // Since we have successfully duplicated the guarded block and this block
+  // has fewer instructions, we expect it to succeed.
+  UnguardedBlock = DuplicateInstructionsInSplitBetween(BB, UnguardedBlock,
+                                                       Guard, UnguardedMapping);
+  assert(UnguardedBlock && "Could not create the unguarded block?");
+  DEBUG(dbgs() << "Moved guard " << *Guard << " to block "
+               << GuardedBlock->getName() << "\n");
+
+  // Some instructions before the guard may still have uses.  For them, we need
+  // to create Phi nodes merging their copies in both guarded and unguarded
+  // branches.  Those instructions that have no uses can simply be removed.
+  SmallVector<Instruction *, 4> ToRemove;
+  for (auto BI = BB->begin(); &*BI != AfterGuard; ++BI)
+    if (!isa<PHINode>(&*BI))
+      ToRemove.push_back(&*BI);
+
+  Instruction *InsertionPoint = &*BB->getFirstInsertionPt();
+  assert(InsertionPoint && "Empty block?");
+  // Substitute with Phis & remove.
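+  // E.g. (illustrative): a value %v defined before the guard now has a copy
+  // in each duplicated block, and a remaining use of %v below is rewritten to
+  //   %v.phi = phi [ %v.unguarded, %UnguardedBlock ],
+  //                [ %v.guarded,   %GuardedBlock ]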
+ for (auto *Inst : reverse(ToRemove)) { + if (!Inst->use_empty()) { + PHINode *NewPN = PHINode::Create(Inst->getType(), 2); + NewPN->addIncoming(UnguardedMapping[Inst], UnguardedBlock); + NewPN->addIncoming(GuardedMapping[Inst], GuardedBlock); + NewPN->insertBefore(InsertionPoint); + Inst->replaceAllUsesWith(NewPN); + } + Inst->eraseFromParent(); + } + return true; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LICM.cpp b/contrib/llvm/lib/Transforms/Scalar/LICM.cpp new file mode 100644 index 000000000000..946474fef062 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LICM.cpp @@ -0,0 +1,1576 @@ +//===-- LICM.cpp - Loop Invariant Code Motion Pass ------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass performs loop invariant code motion, attempting to remove as much +// code from the body of a loop as possible. It does this by either hoisting +// code into the preheader block, or by sinking code to the exit blocks if it is +// safe. This pass also promotes must-aliased memory locations in the loop to +// live in registers, thus hoisting and sinking "invariant" loads and stores. +// +// This pass uses alias analysis for two purposes: +// +// 1. Moving loop invariant loads and calls out of loops. If we can determine +// that a load or call inside of a loop never aliases anything stored to, +// we can hoist it or sink it like any other instruction. +// 2. Scalar Promotion of Memory - If there is a store instruction inside of +// the loop, we try to move the store to happen AFTER the loop instead of +// inside of the loop. This can only happen if a few conditions are true: +// A. The pointer stored through is loop invariant +// B. There are no stores or loads in the loop which _may_ alias the +// pointer. There are no calls in the loop which mod/ref the pointer. +// If these conditions are true, we can promote the loads and stores in the +// loop of the pointer to use a temporary alloca'd variable. We then use +// the SSAUpdater to construct the appropriate SSA form for the value. 
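+//
+// As a small illustrative sketch (hypothetical IR, not from this file): a loop
+//
+//   loop:
+//     %v = load i32, i32* %p
+//     %v2 = add i32 %v, 1
+//     store i32 %v2, i32* %p
+//     ...
+//
+// is promoted so that %v becomes a register value carried by a PHI in the
+// loop, and a single store to %p is emitted in each exit block after the loop.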
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LICM.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/ConstantFolding.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/MemoryBuiltins.h" +#include "llvm/Analysis/MemorySSA.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PredIteratorCache.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/SSAUpdater.h" +#include <algorithm> +#include <utility> +using namespace llvm; + +#define DEBUG_TYPE "licm" + +STATISTIC(NumSunk, "Number of instructions sunk out of loop"); +STATISTIC(NumHoisted, "Number of instructions hoisted out of loop"); +STATISTIC(NumMovedLoads, "Number of load insts hoisted or sunk"); +STATISTIC(NumMovedCalls, "Number of call insts hoisted or sunk"); +STATISTIC(NumPromoted, "Number of memory locations promoted to registers"); + +/// Memory promotion is enabled by default. 
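+/// (For testing it can be turned off with an invocation along the lines of
+/// "opt -licm -disable-licm-promotion"; illustrative usage of the flag
+/// defined just below.)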
+static cl::opt<bool>
+    DisablePromotion("disable-licm-promotion", cl::Hidden, cl::init(false),
+                     cl::desc("Disable memory promotion in LICM pass"));
+
+static cl::opt<uint32_t> MaxNumUsesTraversed(
+    "licm-max-num-uses-traversed", cl::Hidden, cl::init(8),
+    cl::desc("Max num uses visited for identifying load "
+             "invariance in loop using invariant start (default = 8)"));
+
+static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI);
+static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop,
+                                  const LoopSafetyInfo *SafetyInfo,
+                                  TargetTransformInfo *TTI, bool &FreeInLoop);
+static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop,
+                  const LoopSafetyInfo *SafetyInfo,
+                  OptimizationRemarkEmitter *ORE);
+static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT,
+                 const Loop *CurLoop, LoopSafetyInfo *SafetyInfo,
+                 OptimizationRemarkEmitter *ORE, bool FreeInLoop);
+static bool isSafeToExecuteUnconditionally(Instruction &Inst,
+                                           const DominatorTree *DT,
+                                           const Loop *CurLoop,
+                                           const LoopSafetyInfo *SafetyInfo,
+                                           OptimizationRemarkEmitter *ORE,
+                                           const Instruction *CtxI = nullptr);
+static bool pointerInvalidatedByLoop(Value *V, uint64_t Size,
+                                     const AAMDNodes &AAInfo,
+                                     AliasSetTracker *CurAST);
+static Instruction *
+CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN,
+                            const LoopInfo *LI,
+                            const LoopSafetyInfo *SafetyInfo);
+
+namespace {
+struct LoopInvariantCodeMotion {
+  bool runOnLoop(Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT,
+                 TargetLibraryInfo *TLI, TargetTransformInfo *TTI,
+                 ScalarEvolution *SE, MemorySSA *MSSA,
+                 OptimizationRemarkEmitter *ORE, bool DeleteAST);
+
+  DenseMap<Loop *, AliasSetTracker *> &getLoopToAliasSetMap() {
+    return LoopToAliasSetMap;
+  }
+
+private:
+  DenseMap<Loop *, AliasSetTracker *> LoopToAliasSetMap;
+
+  AliasSetTracker *collectAliasInfoForLoop(Loop *L, LoopInfo *LI,
+                                           AliasAnalysis *AA);
+};
+
+struct LegacyLICMPass : public LoopPass {
+  static char ID; // Pass identification, replacement for typeid
+  LegacyLICMPass() : LoopPass(ID) {
+    initializeLegacyLICMPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
+    if (skipLoop(L)) {
+      // If we have run LICM on a previous loop but now we are skipping
+      // (because we've hit the opt-bisect limit), we need to clear the
+      // loop alias information.
+      for (auto &LTAS : LICM.getLoopToAliasSetMap())
+        delete LTAS.second;
+      LICM.getLoopToAliasSetMap().clear();
+      return false;
+    }
+
+    auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>();
+    MemorySSA *MSSA = EnableMSSALoopDependency
+                          ? (&getAnalysis<MemorySSAWrapperPass>().getMSSA())
+                          : nullptr;
+    // For the old PM, we can't use OptimizationRemarkEmitter as an analysis
+    // pass. Function analyses need to be preserved across loop transformations
+    // but ORE cannot be preserved (see comment before the pass definition).
+    OptimizationRemarkEmitter ORE(L->getHeader()->getParent());
+    return LICM.runOnLoop(L,
+                          &getAnalysis<AAResultsWrapperPass>().getAAResults(),
+                          &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(),
+                          &getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
+                          &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(),
+                          &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
+                              *L->getHeader()->getParent()),
+                          SE ? &SE->getSE() : nullptr, MSSA, &ORE, false);
+  }
+
+  /// This transformation requires natural loop information & requires that
+  /// loop preheaders be inserted into the CFG...
+  ///
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+    if (EnableMSSALoopDependency)
+      AU.addRequired<MemorySSAWrapperPass>();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    getLoopAnalysisUsage(AU);
+  }
+
+  using llvm::Pass::doFinalization;
+
+  bool doFinalization() override {
+    assert(LICM.getLoopToAliasSetMap().empty() &&
+           "Didn't free loop alias sets");
+    return false;
+  }
+
+private:
+  LoopInvariantCodeMotion LICM;
+
+  /// cloneBasicBlockAnalysis - Simple Analysis hook. Clone alias set info.
+  void cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To,
+                               Loop *L) override;
+
+  /// deleteAnalysisValue - Simple Analysis hook. Delete value V from alias
+  /// set.
+  void deleteAnalysisValue(Value *V, Loop *L) override;
+
+  /// Simple Analysis hook. Delete loop L from alias set map.
+  void deleteAnalysisLoop(Loop *L) override;
+};
+} // namespace
+
+PreservedAnalyses LICMPass::run(Loop &L, LoopAnalysisManager &AM,
+                                LoopStandardAnalysisResults &AR, LPMUpdater &) {
+  const auto &FAM =
+      AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager();
+  Function *F = L.getHeader()->getParent();
+
+  auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F);
+  // FIXME: This should probably be optional rather than required.
+  if (!ORE)
+    report_fatal_error("LICM: OptimizationRemarkEmitterAnalysis not "
+                       "cached at a higher level");
+
+  LoopInvariantCodeMotion LICM;
+  if (!LICM.runOnLoop(&L, &AR.AA, &AR.LI, &AR.DT, &AR.TLI, &AR.TTI, &AR.SE,
+                      AR.MSSA, ORE, true))
+    return PreservedAnalyses::all();
+
+  auto PA = getLoopPassPreservedAnalyses();
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
+
+char LegacyLICMPass::ID = 0;
+INITIALIZE_PASS_BEGIN(LegacyLICMPass, "licm", "Loop Invariant Code Motion",
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
+INITIALIZE_PASS_END(LegacyLICMPass, "licm", "Loop Invariant Code Motion", false,
+                    false)
+
+Pass *llvm::createLICMPass() { return new LegacyLICMPass(); }
+
+/// Hoist expressions out of the specified loop. Note that alias info for
+/// inner loops is not preserved, so it is not a good idea to run LICM
+/// multiple times on one loop.
+/// We should delete the AST for inner loops in the new pass manager to avoid
+/// a memory leak.
+///
+bool LoopInvariantCodeMotion::runOnLoop(
+    Loop *L, AliasAnalysis *AA, LoopInfo *LI, DominatorTree *DT,
+    TargetLibraryInfo *TLI, TargetTransformInfo *TTI, ScalarEvolution *SE,
+    MemorySSA *MSSA, OptimizationRemarkEmitter *ORE, bool DeleteAST) {
+  bool Changed = false;
+
+  assert(L->isLCSSAForm(*DT) && "Loop is not in LCSSA form.");
+
+  AliasSetTracker *CurAST = collectAliasInfoForLoop(L, LI, AA);
+
+  // Get the preheader block to move instructions into...
+  BasicBlock *Preheader = L->getLoopPreheader();
+
+  // Compute loop safety information.
+  LoopSafetyInfo SafetyInfo;
+  computeLoopSafetyInfo(&SafetyInfo, L);
+
+  // We want to visit all of the instructions in this loop... that are not part
+  // of our subloops (they have already had their invariants hoisted out of
+  // their loop, into this loop, so there is no need to process the BODIES of
+  // the subloops).
+  //
+  // Traverse the body of the loop in depth first order on the dominator tree so
+  // that we are guaranteed to see definitions before we see uses.  This allows
+  // us to sink instructions in one pass, without iteration.  After sinking
+  // instructions, we perform another pass to hoist them out of the loop.
+  //
+  if (L->hasDedicatedExits())
+    Changed |= sinkRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, TTI, L,
+                          CurAST, &SafetyInfo, ORE);
+  if (Preheader)
+    Changed |= hoistRegion(DT->getNode(L->getHeader()), AA, LI, DT, TLI, L,
+                           CurAST, &SafetyInfo, ORE);
+
+  // Now that all loop invariants have been removed from the loop, promote any
+  // memory references to scalars that we can.
+  // Don't sink stores from loops without dedicated block exits.  Exits
+  // containing indirect branches are not transformed by loop simplify, so make
+  // sure we catch that.  An additional load may be generated in the preheader
+  // for the SSA updater, so also avoid sinking when no preheader is available.
+  if (!DisablePromotion && Preheader && L->hasDedicatedExits()) {
+    // Figure out the loop exits and their insertion points.
+    SmallVector<BasicBlock *, 8> ExitBlocks;
+    L->getUniqueExitBlocks(ExitBlocks);
+
+    // We can't insert into a catchswitch.
+    bool HasCatchSwitch = llvm::any_of(ExitBlocks, [](BasicBlock *Exit) {
+      return isa<CatchSwitchInst>(Exit->getTerminator());
+    });
+
+    if (!HasCatchSwitch) {
+      SmallVector<Instruction *, 8> InsertPts;
+      InsertPts.reserve(ExitBlocks.size());
+      for (BasicBlock *ExitBlock : ExitBlocks)
+        InsertPts.push_back(&*ExitBlock->getFirstInsertionPt());
+
+      PredIteratorCache PIC;
+
+      bool Promoted = false;
+
+      // Loop over all of the alias sets in the tracker object.
+      for (AliasSet &AS : *CurAST) {
+        // We can promote this alias set if it has a store, if it is a "Must"
+        // alias set, if the pointer is loop invariant, and if we are not
+        // eliminating any volatile loads or stores.
+        if (AS.isForwardingAliasSet() || !AS.isMod() || !AS.isMustAlias() ||
+            AS.isVolatile() || !L->isLoopInvariant(AS.begin()->getValue()))
+          continue;
+
+        assert(
+            !AS.empty() &&
+            "Must alias set should have at least one pointer element in it!");
+
+        SmallSetVector<Value *, 8> PointerMustAliases;
+        for (const auto &ASI : AS)
+          PointerMustAliases.insert(ASI.getValue());
+
+        Promoted |= promoteLoopAccessesToScalars(PointerMustAliases, ExitBlocks,
+                                                 InsertPts, PIC, LI, DT, TLI, L,
+                                                 CurAST, &SafetyInfo, ORE);
+      }
+
+      // Once we have promoted values across the loop body, we have to
+      // recursively reform LCSSA as any nested loop may now have values defined
+      // within the loop used in the outer loop.
+      // FIXME: This is really heavy handed. It would be a bit better to use an
+      // SSAUpdater strategy during promotion that was LCSSA aware and reformed
+      // it as it went.
+      if (Promoted)
+        formLCSSARecursively(*L, *DT, LI, SE);
+
+      Changed |= Promoted;
+    }
+  }
+
+  // Check that neither this loop nor its parent has had LCSSA broken. LICM is
+  // specifically moving instructions across the loop boundary and so it is
+  // especially in need of sanity checking here.
+  assert(L->isLCSSAForm(*DT) && "Loop not left in LCSSA form after LICM!");
+  assert((!L->getParentLoop() || L->getParentLoop()->isLCSSAForm(*DT)) &&
+         "Parent loop not left in LCSSA form after LICM!");
+
+  // If this loop is nested inside of another one, save the alias information
+  // for when we process the outer loop.
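+  // (Sketch of the intent: LICM visits nested loops inside-out, so the alias
+  // sets computed here can seed collectAliasInfoForLoop when the parent loop
+  // is processed, instead of recomputing them from scratch.)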
+  if (L->getParentLoop() && !DeleteAST)
+    LoopToAliasSetMap[L] = CurAST;
+  else
+    delete CurAST;
+
+  if (Changed && SE)
+    SE->forgetLoopDispositions(L);
+  return Changed;
+}
+
+/// Walk the specified region of the CFG (defined by all blocks dominated by
+/// the specified block, and that are in the current loop) in reverse depth
+/// first order w.r.t. the DominatorTree.  This allows us to visit uses before
+/// definitions, allowing us to sink a loop body in one pass without iteration.
+///
+bool llvm::sinkRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,
+                      DominatorTree *DT, TargetLibraryInfo *TLI,
+                      TargetTransformInfo *TTI, Loop *CurLoop,
+                      AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo,
+                      OptimizationRemarkEmitter *ORE) {
+
+  // Verify inputs.
+  assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr &&
+         CurLoop != nullptr && CurAST != nullptr && SafetyInfo != nullptr &&
+         "Unexpected input to sinkRegion");
+
+  // We want to visit children before parents. We will enqueue all the parents
+  // before their children in the worklist and process the worklist in reverse
+  // order.
+  SmallVector<DomTreeNode *, 16> Worklist = collectChildrenInLoop(N, CurLoop);
+
+  bool Changed = false;
+  for (DomTreeNode *DTN : reverse(Worklist)) {
+    BasicBlock *BB = DTN->getBlock();
+    // Only need to process the contents of this block if it is not part of a
+    // subloop (which would already have been processed).
+    if (inSubLoop(BB, CurLoop, LI))
+      continue;
+
+    for (BasicBlock::iterator II = BB->end(); II != BB->begin();) {
+      Instruction &I = *--II;
+
+      // If the instruction is dead, we would otherwise try to sink it because
+      // it isn't used in the loop; instead, just delete it.
+      if (isInstructionTriviallyDead(&I, TLI)) {
+        DEBUG(dbgs() << "LICM deleting dead inst: " << I << '\n');
+        ++II;
+        CurAST->deleteValue(&I);
+        I.eraseFromParent();
+        Changed = true;
+        continue;
+      }
+
+      // Check to see if we can sink this instruction to the exit blocks
+      // of the loop.  We can do this if all the users of the instruction are
+      // outside of the loop.  In this case, it doesn't even matter if the
+      // operands of the instruction are loop invariant.
+      //
+      bool FreeInLoop = false;
+      if (isNotUsedOrFreeInLoop(I, CurLoop, SafetyInfo, TTI, FreeInLoop) &&
+          canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, SafetyInfo, ORE)) {
+        if (sink(I, LI, DT, CurLoop, SafetyInfo, ORE, FreeInLoop)) {
+          if (!FreeInLoop) {
+            ++II;
+            CurAST->deleteValue(&I);
+            I.eraseFromParent();
+          }
+          Changed = true;
+        }
+      }
+    }
+  }
+  return Changed;
+}
+
+/// Walk the specified region of the CFG (defined by all blocks dominated by
+/// the specified block, and that are in the current loop) in depth first
+/// order w.r.t. the DominatorTree.  This allows us to visit definitions before
+/// uses, allowing us to hoist a loop body in one pass without iteration.
+///
+bool llvm::hoistRegion(DomTreeNode *N, AliasAnalysis *AA, LoopInfo *LI,
+                       DominatorTree *DT, TargetLibraryInfo *TLI, Loop *CurLoop,
+                       AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo,
+                       OptimizationRemarkEmitter *ORE) {
+  // Verify inputs.
+  assert(N != nullptr && AA != nullptr && LI != nullptr && DT != nullptr &&
+         CurLoop != nullptr && CurAST != nullptr && SafetyInfo != nullptr &&
+         "Unexpected input to hoistRegion");
+
+  // We want to visit parents before children. We will enqueue all the parents
+  // before their children in the worklist and process the worklist in order.
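+  // Roughly (an illustrative sketch): for a dominator subtree A -> {B, C} and
+  // B -> {D}, the worklist might be [A, B, D, C]; hoistRegion walks it forward
+  // (parents first), while sinkRegion above walks the same kind of list in
+  // reverse (children first).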
+  SmallVector<DomTreeNode *, 16> Worklist = collectChildrenInLoop(N, CurLoop);
+
+  bool Changed = false;
+  for (DomTreeNode *DTN : Worklist) {
+    BasicBlock *BB = DTN->getBlock();
+    // Only need to process the contents of this block if it is not part of a
+    // subloop (which would already have been processed).
+    if (!inSubLoop(BB, CurLoop, LI))
+      for (BasicBlock::iterator II = BB->begin(), E = BB->end(); II != E;) {
+        Instruction &I = *II++;
+        // Try constant folding this instruction.  If all the operands are
+        // constants, it is technically hoistable, but it would be better to
+        // just fold it.
+        if (Constant *C = ConstantFoldInstruction(
+                &I, I.getModule()->getDataLayout(), TLI)) {
+          DEBUG(dbgs() << "LICM folding inst: " << I << "  --> " << *C << '\n');
+          CurAST->copyValue(&I, C);
+          I.replaceAllUsesWith(C);
+          if (isInstructionTriviallyDead(&I, TLI)) {
+            CurAST->deleteValue(&I);
+            I.eraseFromParent();
+          }
+          Changed = true;
+          continue;
+        }
+
+        // Attempt to move floating point division out of the loop by
+        // converting it to a reciprocal multiplication.
+        if (I.getOpcode() == Instruction::FDiv &&
+            CurLoop->isLoopInvariant(I.getOperand(1)) &&
+            I.hasAllowReciprocal()) {
+          auto Divisor = I.getOperand(1);
+          auto One = llvm::ConstantFP::get(Divisor->getType(), 1.0);
+          auto ReciprocalDivisor = BinaryOperator::CreateFDiv(One, Divisor);
+          ReciprocalDivisor->setFastMathFlags(I.getFastMathFlags());
+          ReciprocalDivisor->insertBefore(&I);
+
+          auto Product =
+              BinaryOperator::CreateFMul(I.getOperand(0), ReciprocalDivisor);
+          Product->setFastMathFlags(I.getFastMathFlags());
+          Product->insertAfter(&I);
+          I.replaceAllUsesWith(Product);
+          I.eraseFromParent();
+
+          hoist(*ReciprocalDivisor, DT, CurLoop, SafetyInfo, ORE);
+          Changed = true;
+          continue;
+        }
+
+        // Try hoisting the instruction out to the preheader.  We can only do
+        // this if all of the operands of the instruction are loop invariant and
+        // if it is safe to hoist the instruction.
+        //
+        if (CurLoop->hasLoopInvariantOperands(&I) &&
+            canSinkOrHoistInst(I, AA, DT, CurLoop, CurAST, SafetyInfo, ORE) &&
+            isSafeToExecuteUnconditionally(
+                I, DT, CurLoop, SafetyInfo, ORE,
+                CurLoop->getLoopPreheader()->getTerminator()))
+          Changed |= hoist(I, DT, CurLoop, SafetyInfo, ORE);
+      }
+  }
+
+  return Changed;
+}
+
+/// Computes loop safety information, and checks the loop body & header for
+/// the possibility of throwing an exception.
+///
+void llvm::computeLoopSafetyInfo(LoopSafetyInfo *SafetyInfo, Loop *CurLoop) {
+  assert(CurLoop != nullptr && "CurLoop can't be null");
+  BasicBlock *Header = CurLoop->getHeader();
+  // Setting default safety values.
+  SafetyInfo->MayThrow = false;
+  SafetyInfo->HeaderMayThrow = false;
+  // Iterate over the header and compute safety info.
+  for (BasicBlock::iterator I = Header->begin(), E = Header->end();
+       (I != E) && !SafetyInfo->HeaderMayThrow; ++I)
+    SafetyInfo->HeaderMayThrow |=
+        !isGuaranteedToTransferExecutionToSuccessor(&*I);
+
+  SafetyInfo->MayThrow = SafetyInfo->HeaderMayThrow;
+  // Iterate over loop instructions and compute safety info.
+  // Skip the header as it has been computed and stored in HeaderMayThrow.
+  // The first block in loopinfo.Blocks is guaranteed to be the header.
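+  // For example (a sketch in pseudo-C):
+  //   for (...) { f();        /* f may throw */
+  //               v = *p; }
+  // the call sets MayThrow, so the load of *p is not guaranteed to execute
+  // and may only be hoisted if it is otherwise safe to speculate.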
+  assert(Header == *CurLoop->getBlocks().begin() &&
+         "First block must be header");
+  for (Loop::block_iterator BB = std::next(CurLoop->block_begin()),
+                            BBE = CurLoop->block_end();
+       (BB != BBE) && !SafetyInfo->MayThrow; ++BB)
+    for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end();
+         (I != E) && !SafetyInfo->MayThrow; ++I)
+      SafetyInfo->MayThrow |= !isGuaranteedToTransferExecutionToSuccessor(&*I);
+
+  // Compute funclet colors if we might sink/hoist in a function with a funclet
+  // personality routine.
+  Function *Fn = CurLoop->getHeader()->getParent();
+  if (Fn->hasPersonalityFn())
+    if (Constant *PersonalityFn = Fn->getPersonalityFn())
+      if (isFuncletEHPersonality(classifyEHPersonality(PersonalityFn)))
+        SafetyInfo->BlockColors = colorEHFunclets(*Fn);
+}
+
+// Return true if LI is invariant within the scope of the loop. LI is invariant
+// if CurLoop is dominated by an invariant.start representing the same memory
+// location and size as the memory location LI loads from, and the
+// invariant.start has no uses.
+static bool isLoadInvariantInLoop(LoadInst *LI, DominatorTree *DT,
+                                  Loop *CurLoop) {
+  Value *Addr = LI->getOperand(0);
+  const DataLayout &DL = LI->getModule()->getDataLayout();
+  const uint32_t LocSizeInBits = DL.getTypeSizeInBits(
+      cast<PointerType>(Addr->getType())->getElementType());
+
+  // If the type is i8 addrspace(x)*, we know this is the type of the
+  // llvm.invariant.start operand.
+  auto *PtrInt8Ty = PointerType::get(Type::getInt8Ty(LI->getContext()),
+                                     LI->getPointerAddressSpace());
+  unsigned BitcastsVisited = 0;
+  // Look through bitcasts until we reach the i8* type (this is the
+  // invariant.start operand's type).
+  while (Addr->getType() != PtrInt8Ty) {
+    auto *BC = dyn_cast<BitCastInst>(Addr);
+    // Avoid traversing a high number of bitcast uses.
+    if (++BitcastsVisited > MaxNumUsesTraversed || !BC)
+      return false;
+    Addr = BC->getOperand(0);
+  }
+
+  unsigned UsesVisited = 0;
+  // Traverse all uses of the load operand value, to see if invariant.start is
+  // one of the uses, and whether it dominates the load instruction.
+  for (auto *U : Addr->users()) {
+    // Give up on load operands with a high number of users.
+    if (++UsesVisited > MaxNumUsesTraversed)
+      return false;
+    IntrinsicInst *II = dyn_cast<IntrinsicInst>(U);
+    // If there are escaping uses of the invariant.start instruction, the load
+    // may be non-invariant.
+    if (!II || II->getIntrinsicID() != Intrinsic::invariant_start ||
+        !II->use_empty())
+      continue;
+    unsigned InvariantSizeInBits =
+        cast<ConstantInt>(II->getArgOperand(0))->getSExtValue() * 8;
+    // Confirm the invariant.start location size contains the load operand size
+    // in bits. Also, the invariant.start should dominate the load, and we
+    // should not hoist the load out of a loop that contains this dominating
+    // invariant.start.
+    if (LocSizeInBits <= InvariantSizeInBits &&
+        DT->properlyDominates(II->getParent(), CurLoop->getHeader()))
+      return true;
+  }
+
+  return false;
+}
+
+bool llvm::canSinkOrHoistInst(Instruction &I, AAResults *AA, DominatorTree *DT,
+                              Loop *CurLoop, AliasSetTracker *CurAST,
+                              LoopSafetyInfo *SafetyInfo,
+                              OptimizationRemarkEmitter *ORE) {
+  // SafetyInfo is nullptr if we are checking for sinking from preheader to
+  // loop body.
+  const bool SinkingToLoopBody = !SafetyInfo;
+  // Loads have extra constraints we have to verify before we can hoist them.
+  if (LoadInst *LI = dyn_cast<LoadInst>(&I)) {
+    if (!LI->isUnordered())
+      return false; // Don't sink/hoist volatile or ordered atomic loads!
+
+    // Loads from constant memory are always safe to move, even if they end up
+    // in the same alias set as something that ends up being modified.
+    if (AA->pointsToConstantMemory(LI->getOperand(0)))
+      return true;
+    if (LI->getMetadata(LLVMContext::MD_invariant_load))
+      return true;
+
+    if (LI->isAtomic() && SinkingToLoopBody)
+      return false; // Don't sink unordered atomic loads to loop body.
+
+    // This checks for an invariant.start dominating the load.
+    if (isLoadInvariantInLoop(LI, DT, CurLoop))
+      return true;
+
+    // Don't hoist loads which have may-aliased stores in the loop.
+    uint64_t Size = 0;
+    if (LI->getType()->isSized())
+      Size = I.getModule()->getDataLayout().getTypeStoreSize(LI->getType());
+
+    AAMDNodes AAInfo;
+    LI->getAAMetadata(AAInfo);
+
+    bool Invalidated =
+        pointerInvalidatedByLoop(LI->getOperand(0), Size, AAInfo, CurAST);
+    // Check loop-invariant address because this may also be a sinkable load
+    // whose address is not necessarily loop-invariant.
+    if (ORE && Invalidated && CurLoop->isLoopInvariant(LI->getPointerOperand()))
+      ORE->emit([&]() {
+        return OptimizationRemarkMissed(
+                   DEBUG_TYPE, "LoadWithLoopInvariantAddressInvalidated", LI)
+               << "failed to move load with loop-invariant address "
+                  "because the loop may invalidate its value";
+      });
+
+    return !Invalidated;
+  } else if (CallInst *CI = dyn_cast<CallInst>(&I)) {
+    // Don't sink or hoist dbg info; it's legal, but not useful.
+    if (isa<DbgInfoIntrinsic>(I))
+      return false;
+
+    // Don't sink calls which can throw.
+    if (CI->mayThrow())
+      return false;
+
+    // Handle simple cases by querying alias analysis.
+    FunctionModRefBehavior Behavior = AA->getModRefBehavior(CI);
+    if (Behavior == FMRB_DoesNotAccessMemory)
+      return true;
+    if (AliasAnalysis::onlyReadsMemory(Behavior)) {
+      // A readonly argmemonly function only reads from memory pointed to by
+      // its arguments with arbitrary offsets.  If we can prove there are no
+      // writes to this memory in the loop, we can hoist or sink.
+      if (AliasAnalysis::onlyAccessesArgPointees(Behavior)) {
+        for (Value *Op : CI->arg_operands())
+          if (Op->getType()->isPointerTy() &&
+              pointerInvalidatedByLoop(Op, MemoryLocation::UnknownSize,
+                                       AAMDNodes(), CurAST))
+            return false;
+        return true;
+      }
+      // If this call only reads from memory and there are no writes to memory
+      // in the loop, we can hoist or sink the call as appropriate.
+      bool FoundMod = false;
+      for (AliasSet &AS : *CurAST) {
+        if (!AS.isForwardingAliasSet() && AS.isMod()) {
+          FoundMod = true;
+          break;
+        }
+      }
+      if (!FoundMod)
+        return true;
+    }
+
+    // FIXME: This should use mod/ref information to see if we can hoist or
+    // sink the call.
+
+    return false;
+  }
+
+  // Only these instructions are hoistable/sinkable.
+  if (!isa<BinaryOperator>(I) && !isa<CastInst>(I) && !isa<SelectInst>(I) &&
+      !isa<GetElementPtrInst>(I) && !isa<CmpInst>(I) &&
+      !isa<InsertElementInst>(I) && !isa<ExtractElementInst>(I) &&
+      !isa<ShuffleVectorInst>(I) && !isa<ExtractValueInst>(I) &&
+      !isa<InsertValueInst>(I))
+    return false;
+
+  // If we are checking for sinking from preheader to loop body it will be
+  // always safe as there is no speculative execution.
+  if (SinkingToLoopBody)
+    return true;
+
+  // TODO: Plumb the context instruction through to make hoisting and sinking
+  // more powerful. Hoisting of loads already works due to the special casing
+  // above.
+  return isSafeToExecuteUnconditionally(I, DT, CurLoop, SafetyInfo, nullptr);
+}
+
+/// Returns true if a PHINode is trivially replaceable with an Instruction.
+/// This is true when all incoming values are that instruction.
+/// This pattern occurs most often with LCSSA PHI nodes.
+///
+static bool isTriviallyReplacablePHI(const PHINode &PN, const Instruction &I) {
+  for (const Value *IncValue : PN.incoming_values())
+    if (IncValue != &I)
+      return false;
+
+  return true;
+}
+
+/// Return true if the instruction is free in the loop.
+static bool isFreeInLoop(const Instruction &I, const Loop *CurLoop,
+                         const TargetTransformInfo *TTI) {
+
+  if (const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(&I)) {
+    if (TTI->getUserCost(GEP) != TargetTransformInfo::TCC_Free)
+      return false;
+    // For a GEP, we cannot simply use getUserCost because currently it
+    // optimistically assumes that a GEP will fold into the addressing mode
+    // regardless of its users.
+    const BasicBlock *BB = GEP->getParent();
+    for (const User *U : GEP->users()) {
+      const Instruction *UI = cast<Instruction>(U);
+      if (CurLoop->contains(UI) &&
+          (BB != UI->getParent() ||
+           (!isa<StoreInst>(UI) && !isa<LoadInst>(UI))))
+        return false;
+    }
+    return true;
+  } else
+    return TTI->getUserCost(&I) == TargetTransformInfo::TCC_Free;
+}
+
+/// Return true if the only users of this instruction are outside of
+/// the loop. If this is true, we can sink the instruction to the exit
+/// blocks of the loop.
+///
+/// We also return true if the instruction could be folded away in lowering.
+/// (e.g., a GEP can be folded into a load as an addressing mode in the loop).
+static bool isNotUsedOrFreeInLoop(const Instruction &I, const Loop *CurLoop,
+                                  const LoopSafetyInfo *SafetyInfo,
+                                  TargetTransformInfo *TTI, bool &FreeInLoop) {
+  const auto &BlockColors = SafetyInfo->BlockColors;
+  bool IsFree = isFreeInLoop(I, CurLoop, TTI);
+  for (const User *U : I.users()) {
+    const Instruction *UI = cast<Instruction>(U);
+    if (const PHINode *PN = dyn_cast<PHINode>(UI)) {
+      const BasicBlock *BB = PN->getParent();
+      // We cannot sink uses in catchswitches.
+      if (isa<CatchSwitchInst>(BB->getTerminator()))
+        return false;
+
+      // We need to sink a callsite to a unique funclet.  Avoid sinking if the
+      // phi use is too muddled.
+      if (isa<CallInst>(I))
+        if (!BlockColors.empty() &&
+            BlockColors.find(const_cast<BasicBlock *>(BB))->second.size() != 1)
+          return false;
+    }
+
+    if (CurLoop->contains(UI)) {
+      if (IsFree) {
+        FreeInLoop = true;
+        continue;
+      }
+      return false;
+    }
+  }
+  return true;
+}
+
+static Instruction *
+CloneInstructionInExitBlock(Instruction &I, BasicBlock &ExitBlock, PHINode &PN,
+                            const LoopInfo *LI,
+                            const LoopSafetyInfo *SafetyInfo) {
+  Instruction *New;
+  if (auto *CI = dyn_cast<CallInst>(&I)) {
+    const auto &BlockColors = SafetyInfo->BlockColors;
+
+    // Sinking call-sites needs to be handled differently from sinking other
+    // instructions.  The cloned call-site needs a funclet bundle operand
+    // appropriate for its location in the CFG.
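+    // E.g. (hypothetical IR): a call sunk into an exit block colored by a
+    // cleanup funclet is given the bundle
+    //   call void @f() [ "funclet"(token %cleanuppad) ]
+    // using the first EH pad of that block's color, as done below.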
+    SmallVector<OperandBundleDef, 1> OpBundles;
+    for (unsigned BundleIdx = 0, BundleEnd = CI->getNumOperandBundles();
+         BundleIdx != BundleEnd; ++BundleIdx) {
+      OperandBundleUse Bundle = CI->getOperandBundleAt(BundleIdx);
+      if (Bundle.getTagID() == LLVMContext::OB_funclet)
+        continue;
+
+      OpBundles.emplace_back(Bundle);
+    }
+
+    if (!BlockColors.empty()) {
+      const ColorVector &CV = BlockColors.find(&ExitBlock)->second;
+      assert(CV.size() == 1 && "non-unique color for exit block!");
+      BasicBlock *BBColor = CV.front();
+      Instruction *EHPad = BBColor->getFirstNonPHI();
+      if (EHPad->isEHPad())
+        OpBundles.emplace_back("funclet", EHPad);
+    }
+
+    New = CallInst::Create(CI, OpBundles);
+  } else {
+    New = I.clone();
+  }
+
+  ExitBlock.getInstList().insert(ExitBlock.getFirstInsertionPt(), New);
+  if (!I.getName().empty())
+    New->setName(I.getName() + ".le");
+
+  // Build LCSSA PHI nodes for any in-loop operands. Note that this is
+  // particularly cheap because we can rip off the PHI node that we're
+  // replacing for the number and blocks of the predecessors.
+  // OPT: If this shows up in a profile, we can instead finish sinking all
+  // invariant instructions, and then walk their operands to re-establish
+  // LCSSA. That will eliminate creating PHI nodes just to nuke them when
+  // sinking bottom-up.
+  for (User::op_iterator OI = New->op_begin(), OE = New->op_end(); OI != OE;
+       ++OI)
+    if (Instruction *OInst = dyn_cast<Instruction>(*OI))
+      if (Loop *OLoop = LI->getLoopFor(OInst->getParent()))
+        if (!OLoop->contains(&PN)) {
+          PHINode *OpPN =
+              PHINode::Create(OInst->getType(), PN.getNumIncomingValues(),
+                              OInst->getName() + ".lcssa", &ExitBlock.front());
+          for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i)
+            OpPN->addIncoming(OInst, PN.getIncomingBlock(i));
+          *OI = OpPN;
+        }
+  return New;
+}
+
+static Instruction *sinkThroughTriviallyReplacablePHI(
+    PHINode *TPN, Instruction *I, LoopInfo *LI,
+    SmallDenseMap<BasicBlock *, Instruction *, 32> &SunkCopies,
+    const LoopSafetyInfo *SafetyInfo, const Loop *CurLoop) {
+  assert(isTriviallyReplacablePHI(*TPN, *I) &&
+         "Expect only trivially replaceable PHI");
+  BasicBlock *ExitBlock = TPN->getParent();
+  Instruction *New;
+  auto It = SunkCopies.find(ExitBlock);
+  if (It != SunkCopies.end())
+    New = It->second;
+  else
+    New = SunkCopies[ExitBlock] =
+        CloneInstructionInExitBlock(*I, *ExitBlock, *TPN, LI, SafetyInfo);
+  return New;
+}
+
+static bool canSplitPredecessors(PHINode *PN, LoopSafetyInfo *SafetyInfo) {
+  BasicBlock *BB = PN->getParent();
+  if (!BB->canSplitPredecessors())
+    return false;
+  // It's not impossible to split EHPad blocks, but if BlockColors already
+  // exist it requires updating BlockColors for all offspring blocks
+  // accordingly. By skipping such corner cases, we can keep updating
+  // BlockColors after splitting a predecessor fairly simple.
+  if (!SafetyInfo->BlockColors.empty() && BB->getFirstNonPHI()->isEHPad())
+    return false;
+  for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) {
+    BasicBlock *BBPred = *PI;
+    if (isa<IndirectBrInst>(BBPred->getTerminator()))
+      return false;
+  }
+  return true;
+}
+
+static void splitPredecessorsOfLoopExit(PHINode *PN, DominatorTree *DT,
+                                        LoopInfo *LI, const Loop *CurLoop,
+                                        LoopSafetyInfo *SafetyInfo) {
+#ifndef NDEBUG
+  SmallVector<BasicBlock *, 32> ExitBlocks;
+  CurLoop->getUniqueExitBlocks(ExitBlocks);
+  SmallPtrSet<BasicBlock *, 32> ExitBlockSet(ExitBlocks.begin(),
+                                             ExitBlocks.end());
+#endif
+  BasicBlock *ExitBB = PN->getParent();
+  assert(ExitBlockSet.count(ExitBB) && "Expect the PHI is in an exit block.");
+
+  // Split predecessors of the loop exit so that instructions in the loop are
+  // exposed to exit blocks through trivially replaceable PHIs, while keeping
+  // the loop in the canonical form where each predecessor of each exit block
+  // is contained within the loop. For example, this will convert the loop
+  // below from
+  //
+  // LB1:
+  //   %v1 =
+  //   br %LE, %LB2
+  // LB2:
+  //   %v2 =
+  //   br %LE, %LB1
+  // LE:
+  //   %p = phi [%v1, %LB1], [%v2, %LB2] <-- non-trivially replaceable
+  //
+  // to
+  //
+  // LB1:
+  //   %v1 =
+  //   br %LE.split, %LB2
+  // LB2:
+  //   %v2 =
+  //   br %LE.split2, %LB1
+  // LE.split:
+  //   %p1 = phi [%v1, %LB1] <-- trivially replaceable
+  //   br %LE
+  // LE.split2:
+  //   %p2 = phi [%v2, %LB2] <-- trivially replaceable
+  //   br %LE
+  // LE:
+  //   %p = phi [%p1, %LE.split], [%p2, %LE.split2]
+  //
+  auto &BlockColors = SafetyInfo->BlockColors;
+  SmallSetVector<BasicBlock *, 8> PredBBs(pred_begin(ExitBB), pred_end(ExitBB));
+  while (!PredBBs.empty()) {
+    BasicBlock *PredBB = *PredBBs.begin();
+    assert(CurLoop->contains(PredBB) &&
+           "Expect all predecessors are in the loop");
+    if (PN->getBasicBlockIndex(PredBB) >= 0) {
+      BasicBlock *NewPred = SplitBlockPredecessors(
+          ExitBB, PredBB, ".split.loop.exit", DT, LI, true);
+      // Since we do not allow splitting EH-blocks with BlockColors in
+      // canSplitPredecessors(), we can simply assign the predecessor's color
+      // to the new block.
+      if (!BlockColors.empty())
+        BlockColors[NewPred] = BlockColors[PredBB];
+    }
+    PredBBs.remove(PredBB);
+  }
+}
+
+/// When an instruction is found to only be used outside of the loop, this
+/// function moves it to the exit blocks and patches up SSA form as needed.
+/// This method is guaranteed to remove the original instruction from its
+/// position, and may either delete it or move it to outside of the loop.
+///
+static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT,
+                 const Loop *CurLoop, LoopSafetyInfo *SafetyInfo,
+                 OptimizationRemarkEmitter *ORE, bool FreeInLoop) {
+  DEBUG(dbgs() << "LICM sinking instruction: " << I << "\n");
+  ORE->emit([&]() {
+    return OptimizationRemark(DEBUG_TYPE, "InstSunk", &I)
+           << "sinking " << ore::NV("Inst", &I);
+  });
+  bool Changed = false;
+  if (isa<LoadInst>(I))
+    ++NumMovedLoads;
+  else if (isa<CallInst>(I))
+    ++NumMovedCalls;
+  ++NumSunk;
+
+  // Iterate over users to be ready for actual sinking.  Replace users in
+  // unreachable blocks with undef and make all user PHIs trivially
+  // replaceable.
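+  // E.g. (sketch): an LCSSA user such as
+  //   %i.lcssa = phi i32 [ %i, %loop.exiting ]
+  // is trivially replaceable, since its only incoming value is %i itself;
+  // after splitting, every exit PHI has this shape and can simply be RAUW'd
+  // with a clone of %i.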
+/// When an instruction is found to only be used outside of the loop, this
+/// function moves it to the exit blocks and patches up SSA form as needed.
+/// This method is guaranteed to remove the original instruction from its
+/// position, and may either delete it or move it to outside of the loop.
+///
+static bool sink(Instruction &I, LoopInfo *LI, DominatorTree *DT,
+                 const Loop *CurLoop, LoopSafetyInfo *SafetyInfo,
+                 OptimizationRemarkEmitter *ORE, bool FreeInLoop) {
+  DEBUG(dbgs() << "LICM sinking instruction: " << I << "\n");
+  ORE->emit([&]() {
+    return OptimizationRemark(DEBUG_TYPE, "InstSunk", &I)
+           << "sinking " << ore::NV("Inst", &I);
+  });
+  bool Changed = false;
+  if (isa<LoadInst>(I))
+    ++NumMovedLoads;
+  else if (isa<CallInst>(I))
+    ++NumMovedCalls;
+  ++NumSunk;
+
+  // Iterate over users to be ready for actual sinking. Replace users via
+  // unreachable blocks with undef and make all user PHIs trivially replacable.
+  SmallPtrSet<Instruction *, 8> VisitedUsers;
+  for (Value::user_iterator UI = I.user_begin(), UE = I.user_end(); UI != UE;) {
+    auto *User = cast<Instruction>(*UI);
+    Use &U = UI.getUse();
+    ++UI;
+
+    if (VisitedUsers.count(User) || CurLoop->contains(User))
+      continue;
+
+    if (!DT->isReachableFromEntry(User->getParent())) {
+      U = UndefValue::get(I.getType());
+      Changed = true;
+      continue;
+    }
+
+    // The user must be a PHI node.
+    PHINode *PN = cast<PHINode>(User);
+
+    // Surprisingly, instructions can be used outside of loops without any
+    // exits. This can only happen in PHI nodes if the incoming block is
+    // unreachable.
+    BasicBlock *BB = PN->getIncomingBlock(U);
+    if (!DT->isReachableFromEntry(BB)) {
+      U = UndefValue::get(I.getType());
+      Changed = true;
+      continue;
+    }
+
+    VisitedUsers.insert(PN);
+    if (isTriviallyReplacablePHI(*PN, I))
+      continue;
+
+    if (!canSplitPredecessors(PN, SafetyInfo))
+      return Changed;
+
+    // Split predecessors of the PHI so that we can make users trivially
+    // replacable.
+    splitPredecessorsOfLoopExit(PN, DT, LI, CurLoop, SafetyInfo);
+
+    // Should rebuild the iterators, as they may be invalidated by
+    // splitPredecessorsOfLoopExit().
+    UI = I.user_begin();
+    UE = I.user_end();
+  }
+
+  if (VisitedUsers.empty())
+    return Changed;
+
+#ifndef NDEBUG
+  SmallVector<BasicBlock *, 32> ExitBlocks;
+  CurLoop->getUniqueExitBlocks(ExitBlocks);
+  SmallPtrSet<BasicBlock *, 32> ExitBlockSet(ExitBlocks.begin(),
+                                             ExitBlocks.end());
+#endif
+
+  // Clones of this instruction. Don't create more than one per exit block!
+  SmallDenseMap<BasicBlock *, Instruction *, 32> SunkCopies;
+
+  // If this instruction is only used outside of the loop, then all users are
+  // PHI nodes in exit blocks due to LCSSA form. Just RAUW them with clones of
+  // the instruction.
+  SmallSetVector<User*, 8> Users(I.user_begin(), I.user_end());
+  for (auto *UI : Users) {
+    auto *User = cast<Instruction>(UI);
+
+    if (CurLoop->contains(User))
+      continue;
+
+    PHINode *PN = cast<PHINode>(User);
+    assert(ExitBlockSet.count(PN->getParent()) &&
+           "The LCSSA PHI is not in an exit block!");
+    // The PHI must be trivially replacable.
+    Instruction *New = sinkThroughTriviallyReplacablePHI(PN, &I, LI, SunkCopies,
+                                                         SafetyInfo, CurLoop);
+    PN->replaceAllUsesWith(New);
+    PN->eraseFromParent();
+    Changed = true;
+  }
+  return Changed;
+}
+
+/// When an instruction that only uses loop invariant operands and is safe to
+/// hoist is found, this function is called to do the dirty work.
+///
+static bool hoist(Instruction &I, const DominatorTree *DT, const Loop *CurLoop,
+                  const LoopSafetyInfo *SafetyInfo,
+                  OptimizationRemarkEmitter *ORE) {
+  auto *Preheader = CurLoop->getLoopPreheader();
+  DEBUG(dbgs() << "LICM hoisting to " << Preheader->getName() << ": " << I
+               << "\n");
+  ORE->emit([&]() {
+    return OptimizationRemark(DEBUG_TYPE, "Hoisted", &I) << "hoisting "
+                                                         << ore::NV("Inst", &I);
+  });
+
+  // Metadata can be dependent on conditions we are hoisting above.
+  // Conservatively strip all metadata on the instruction unless we were
+  // guaranteed to execute I if we entered the loop, in which case the metadata
+  // is valid in the loop preheader.
+  if (I.hasMetadataOtherThanDebugLoc() &&
+      // The check on hasMetadataOtherThanDebugLoc is to prevent us from burning
+      // time in isGuaranteedToExecute if we don't actually have anything to
+      // drop. It is a compile time optimization, not required for correctness.
+      !isGuaranteedToExecute(I, DT, CurLoop, SafetyInfo))
+    I.dropUnknownNonDebugMetadata();
+
+  // Move the new node to the Preheader, before its terminator.
+  I.moveBefore(Preheader->getTerminator());
+
+  // Do not retain debug locations when we are moving instructions to different
+  // basic blocks, because we want to avoid jumpy line tables. Calls, however,
+  // need to retain their debug locs because they may be inlined.
+  // FIXME: How do we retain source locations without causing poor debugging
+  // behavior?
+  if (!isa<CallInst>(I))
+    I.setDebugLoc(DebugLoc());
+
+  if (isa<LoadInst>(I))
+    ++NumMovedLoads;
+  else if (isa<CallInst>(I))
+    ++NumMovedCalls;
+  ++NumHoisted;
+  return true;
+}
+
+/// Only sink or hoist an instruction if it is not a trapping instruction,
+/// if it is known not to trap when moved to the preheader, or if it is a
+/// trapping instruction that is guaranteed to execute.
+static bool isSafeToExecuteUnconditionally(Instruction &Inst,
+                                           const DominatorTree *DT,
+                                           const Loop *CurLoop,
+                                           const LoopSafetyInfo *SafetyInfo,
+                                           OptimizationRemarkEmitter *ORE,
+                                           const Instruction *CtxI) {
+  if (isSafeToSpeculativelyExecute(&Inst, CtxI, DT))
+    return true;
+
+  bool GuaranteedToExecute =
+      isGuaranteedToExecute(Inst, DT, CurLoop, SafetyInfo);
+
+  if (!GuaranteedToExecute) {
+    auto *LI = dyn_cast<LoadInst>(&Inst);
+    if (LI && CurLoop->isLoopInvariant(LI->getPointerOperand()))
+      ORE->emit([&]() {
+        return OptimizationRemarkMissed(
+                   DEBUG_TYPE, "LoadWithLoopInvariantAddressCondExecuted", LI)
+               << "failed to hoist load with loop-invariant address "
+                  "because load is conditionally executed";
+      });
+  }
+
+  return GuaranteedToExecute;
+}
+
+namespace {
+class LoopPromoter : public LoadAndStorePromoter {
+  Value *SomePtr; // Designated pointer to store to.
+  const SmallSetVector<Value *, 8> &PointerMustAliases;
+  SmallVectorImpl<BasicBlock *> &LoopExitBlocks;
+  SmallVectorImpl<Instruction *> &LoopInsertPts;
+  PredIteratorCache &PredCache;
+  AliasSetTracker &AST;
+  LoopInfo &LI;
+  DebugLoc DL;
+  int Alignment;
+  bool UnorderedAtomic;
+  AAMDNodes AATags;
+
+  Value *maybeInsertLCSSAPHI(Value *V, BasicBlock *BB) const {
+    if (Instruction *I = dyn_cast<Instruction>(V))
+      if (Loop *L = LI.getLoopFor(I->getParent()))
+        if (!L->contains(BB)) {
+          // We need to create an LCSSA PHI node for the incoming value and
+          // store that.
+          PHINode *PN = PHINode::Create(I->getType(), PredCache.size(BB),
+                                        I->getName() + ".lcssa", &BB->front());
+          for (BasicBlock *Pred : PredCache.get(BB))
+            PN->addIncoming(I, Pred);
+          return PN;
+        }
+    return V;
+  }
+
+public:
+  LoopPromoter(Value *SP, ArrayRef<const Instruction *> Insts, SSAUpdater &S,
+               const SmallSetVector<Value *, 8> &PMA,
+               SmallVectorImpl<BasicBlock *> &LEB,
+               SmallVectorImpl<Instruction *> &LIP, PredIteratorCache &PIC,
+               AliasSetTracker &ast, LoopInfo &li, DebugLoc dl, int alignment,
+               bool UnorderedAtomic, const AAMDNodes &AATags)
+      : LoadAndStorePromoter(Insts, S), SomePtr(SP), PointerMustAliases(PMA),
+        LoopExitBlocks(LEB), LoopInsertPts(LIP), PredCache(PIC), AST(ast),
+        LI(li), DL(std::move(dl)), Alignment(alignment),
+        UnorderedAtomic(UnorderedAtomic), AATags(AATags) {}
+
+  bool isInstInList(Instruction *I,
+                    const SmallVectorImpl<Instruction *> &) const override {
+    Value *Ptr;
+    if (LoadInst *LI = dyn_cast<LoadInst>(I))
+      Ptr = LI->getOperand(0);
+    else
+      Ptr = cast<StoreInst>(I)->getPointerOperand();
+    return PointerMustAliases.count(Ptr);
+  }
+
+  void doExtraRewritesBeforeFinalDeletion() const override {
+    // Insert stores in the loop exit blocks. Each exit block gets a store of
+    // the live-out value that feeds it. Since we've already told the SSA
+    // updater about the defs in the loop and the preheader definition, it is
+    // all set and we can start using it.
+    for (unsigned i = 0, e = LoopExitBlocks.size(); i != e; ++i) {
+      BasicBlock *ExitBlock = LoopExitBlocks[i];
+      Value *LiveInValue = SSA.GetValueInMiddleOfBlock(ExitBlock);
+      LiveInValue = maybeInsertLCSSAPHI(LiveInValue, ExitBlock);
+      Value *Ptr = maybeInsertLCSSAPHI(SomePtr, ExitBlock);
+      Instruction *InsertPos = LoopInsertPts[i];
+      StoreInst *NewSI = new StoreInst(LiveInValue, Ptr, InsertPos);
+      if (UnorderedAtomic)
+        NewSI->setOrdering(AtomicOrdering::Unordered);
+      NewSI->setAlignment(Alignment);
+      NewSI->setDebugLoc(DL);
+      if (AATags)
+        NewSI->setAAMetadata(AATags);
+    }
+  }
+
+  void replaceLoadWithValue(LoadInst *LI, Value *V) const override {
+    // Update alias analysis.
+    AST.copyValue(LI, V);
+  }
+  void instructionDeleted(Instruction *I) const override { AST.deleteValue(I); }
+};
+
+
+/// Return true iff we can prove that a caller of this function cannot inspect
+/// the contents of the provided object in a well-defined program.
+bool isKnownNonEscaping(Value *Object, const TargetLibraryInfo *TLI) {
+  if (isa<AllocaInst>(Object))
+    // Since the alloca goes out of scope, we know the caller can't retain a
+    // reference to it and be well defined. Thus, we don't need to check for
+    // capture.
+    return true;
+
+  // For all other objects we need to know that the caller can't possibly
+  // have gotten a reference to the object. There are two components of
+  // that:
+  //   1) The object can't escape from this function. This is what
+  //      PointerMayBeCaptured checks.
+  //   2) The object can't have been captured at definition site. For this, we
+  //      need to know the return value is noalias. At the moment, we use a
+  //      weaker condition and handle only AllocLikeFunctions (which are
+  //      known to be noalias). TODO
+  return isAllocLikeFn(Object, TLI) &&
+         !PointerMayBeCaptured(Object, true, true);
+}
+
+} // namespace
+
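Before the main promotion routine, here is a sketch of its end-to-end effect (hypothetical C-level illustration; the preheader load named with the `.promoted` suffix and the exit-block stores are created further below):

// Promotion conceptually rewrites
//
//   for (...) { tmp = *p; tmp2 = tmp + 1; *p = tmp2; }
//
// into
//
//   t = *p;                    // load inserted in the preheader
//   for (...) { t = t + 1; }   // the loop now operates on a register
//   *p = t;                    // store inserted in each exit block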
+/// Try to promote memory values to scalars by sinking stores out of the
+/// loop and moving loads to before the loop. We do this by looping over
+/// the stores in the loop, looking for stores to Must pointers which are
+/// loop invariant.
+///
+bool llvm::promoteLoopAccessesToScalars(
+    const SmallSetVector<Value *, 8> &PointerMustAliases,
+    SmallVectorImpl<BasicBlock *> &ExitBlocks,
+    SmallVectorImpl<Instruction *> &InsertPts, PredIteratorCache &PIC,
+    LoopInfo *LI, DominatorTree *DT, const TargetLibraryInfo *TLI,
+    Loop *CurLoop, AliasSetTracker *CurAST, LoopSafetyInfo *SafetyInfo,
+    OptimizationRemarkEmitter *ORE) {
+  // Verify inputs.
+  assert(LI != nullptr && DT != nullptr && CurLoop != nullptr &&
+         CurAST != nullptr && SafetyInfo != nullptr &&
+         "Unexpected Input to promoteLoopAccessesToScalars");
+
+  Value *SomePtr = *PointerMustAliases.begin();
+  BasicBlock *Preheader = CurLoop->getLoopPreheader();
+
+  // It isn't safe to promote a load/store from the loop if the load/store is
+  // conditional. For example, turning:
+  //
+  //    for () { if (c) *P += 1; }
+  //
+  // into:
+  //
+  //    tmp = *P;  for () { if (c) tmp +=1; } *P = tmp;
+  //
+  // is not safe, because *P may only be valid to access if 'c' is true.
+  //
+  // The safety property divides into two parts:
+  // p1) The memory may not be dereferenceable on entry to the loop. In this
+  //     case, we can't insert the required load in the preheader.
+  // p2) The memory model does not allow us to insert a store along any dynamic
+  //     path which did not originally have one.
+  //
+  // If at least one store is guaranteed to execute, both properties are
+  // satisfied, and promotion is legal.
+  //
+  // This, however, is not a necessary condition. Even if no store/load is
+  // guaranteed to execute, we can still establish these properties.
+  // We can establish (p1) by proving that hoisting the load into the preheader
+  // is safe (i.e. proving dereferenceability on all paths through the loop). We
+  // can use any access within the alias set to prove dereferenceability,
+  // since they're all must alias.
+  //
+  // There are two ways to establish (p2):
+  // a) Prove the location is thread-local. In this case the memory model
+  //    requirement does not apply, and stores are safe to insert.
+  // b) Prove a store dominates every exit block. In this case, if an exit
+  //    block is reached, the original dynamic path would have taken us through
+  //    the store, so inserting a store into the exit block is safe. Note that
+  //    this is different from the store being guaranteed to execute. For
+  //    instance, if an exception is thrown on the first iteration of the loop,
+  //    the original store is never executed, but the exit blocks are not
+  //    executed either.
+
+  bool DereferenceableInPH = false;
+  bool SafeToInsertStore = false;
+
+  SmallVector<Instruction *, 64> LoopUses;
+
+  // We start with an alignment of one and try to find instructions that allow
+  // us to prove better alignment.
+  unsigned Alignment = 1;
+  // Keep track of which types of access we see.
+  bool SawUnorderedAtomic = false;
+  bool SawNotAtomic = false;
+  AAMDNodes AATags;
+
+  const DataLayout &MDL = Preheader->getModule()->getDataLayout();
+
+  bool IsKnownThreadLocalObject = false;
+  if (SafetyInfo->MayThrow) {
+    // If a loop can throw, we have to insert a store along each unwind edge.
+    // That said, we can't actually make the unwind edge explicit. Therefore,
+    // we have to prove that the store is dead along the unwind edge. We do
+    // this by proving that the caller can't have a reference to the object
+    // after return and thus can't possibly load from the object.
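// (An illustration of why introducing a store is observable, using a
// hypothetical single-writer program:
//
//   for (...) { if (never_true) *p = 1; }
//
// never writes *p, so rewriting it to
//
//   tmp = *p; for (...) { ... } *p = tmp;
//
// creates a write that another thread could legitimately observe. This is
// exactly property (p2) above: a store may only be inserted if the object is
// provably thread-local or a store already dominated every exit.)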
+    Value *Object = GetUnderlyingObject(SomePtr, MDL);
+    if (!isKnownNonEscaping(Object, TLI))
+      return false;
+    // Subtlety: Alloca's aren't visible to callers, but *are* potentially
+    // visible to other threads if captured and used during their lifetimes.
+    IsKnownThreadLocalObject = !isa<AllocaInst>(Object);
+  }
+
+  // Check that all of the pointers in the alias set have the same type. We
+  // cannot (yet) promote a memory location that is loaded and stored in
+  // different sizes. While we are at it, collect alignment and AA info.
+  for (Value *ASIV : PointerMustAliases) {
+    // Check that all of the pointers in the alias set have the same type. We
+    // cannot (yet) promote a memory location that is loaded and stored in
+    // different sizes.
+    if (SomePtr->getType() != ASIV->getType())
+      return false;
+
+    for (User *U : ASIV->users()) {
+      // Ignore instructions that are outside the loop.
+      Instruction *UI = dyn_cast<Instruction>(U);
+      if (!UI || !CurLoop->contains(UI))
+        continue;
+
+      // If there is a non-load/store instruction in the loop, we can't promote
+      // it.
+      if (LoadInst *Load = dyn_cast<LoadInst>(UI)) {
+        assert(!Load->isVolatile() && "AST broken");
+        if (!Load->isUnordered())
+          return false;
+
+        SawUnorderedAtomic |= Load->isAtomic();
+        SawNotAtomic |= !Load->isAtomic();
+
+        if (!DereferenceableInPH)
+          DereferenceableInPH = isSafeToExecuteUnconditionally(
+              *Load, DT, CurLoop, SafetyInfo, ORE, Preheader->getTerminator());
+      } else if (const StoreInst *Store = dyn_cast<StoreInst>(UI)) {
+        // Stores *of* the pointer are not interesting, only stores *to* the
+        // pointer.
+        if (UI->getOperand(1) != ASIV)
+          continue;
+        assert(!Store->isVolatile() && "AST broken");
+        if (!Store->isUnordered())
+          return false;
+
+        SawUnorderedAtomic |= Store->isAtomic();
+        SawNotAtomic |= !Store->isAtomic();
+
+        // If the store is guaranteed to execute, both properties are satisfied.
+        // We may want to check if a store is guaranteed to execute even if we
+        // already know that promotion is safe, since it may have higher
+        // alignment than any other guaranteed stores, in which case we can
+        // raise the alignment on the promoted store.
+        unsigned InstAlignment = Store->getAlignment();
+        if (!InstAlignment)
+          InstAlignment =
+              MDL.getABITypeAlignment(Store->getValueOperand()->getType());
+
+        if (!DereferenceableInPH || !SafeToInsertStore ||
+            (InstAlignment > Alignment)) {
+          if (isGuaranteedToExecute(*UI, DT, CurLoop, SafetyInfo)) {
+            DereferenceableInPH = true;
+            SafeToInsertStore = true;
+            Alignment = std::max(Alignment, InstAlignment);
+          }
+        }
+
+        // If a store dominates all exit blocks, it is safe to sink.
+        // As explained above, if an exit block was executed, a dominating
+        // store must have been executed at least once, so we are not
+        // introducing stores on paths that did not have them.
+        // Note that this only looks at explicit exit blocks. If we ever
+        // start sinking stores into unwind edges (see above), this will break.
+        if (!SafeToInsertStore)
+          SafeToInsertStore = llvm::all_of(ExitBlocks, [&](BasicBlock *Exit) {
+            return DT->dominates(Store->getParent(), Exit);
+          });
+
+        // If the store is not guaranteed to execute, we may still get
+        // deref info through it.
+        if (!DereferenceableInPH) {
+          DereferenceableInPH = isDereferenceableAndAlignedPointer(
+              Store->getPointerOperand(), Store->getAlignment(), MDL,
+              Preheader->getTerminator(), DT);
+        }
+      } else
+        return false; // Not a load or store.
+
+      // Merge the AA tags.
+      if (LoopUses.empty()) {
+        // On the first load/store, just take its AA tags.
+        UI->getAAMetadata(AATags);
+      } else if (AATags) {
+        UI->getAAMetadata(AATags, /* Merge = */ true);
+      }
+
+      LoopUses.push_back(UI);
+    }
+  }
+
+  // If we found both an unordered atomic instruction and a non-atomic memory
+  // access, bail. We can't blindly promote non-atomic to atomic since we
+  // might not be able to lower the result. We can't downgrade since that
+  // would violate the memory model. Also, align 0 is an error for atomics.
+  if (SawUnorderedAtomic && SawNotAtomic)
+    return false;
+
+  // If we couldn't prove we can hoist the load, bail.
+  if (!DereferenceableInPH)
+    return false;
+
+  // We know we can hoist the load, but don't have a guaranteed store.
+  // Check whether the location is thread-local. If it is, then we can insert
+  // stores along paths which originally didn't have them without violating the
+  // memory model.
+  if (!SafeToInsertStore) {
+    if (IsKnownThreadLocalObject)
+      SafeToInsertStore = true;
+    else {
+      Value *Object = GetUnderlyingObject(SomePtr, MDL);
+      SafeToInsertStore =
+          (isAllocLikeFn(Object, TLI) || isa<AllocaInst>(Object)) &&
+          !PointerMayBeCaptured(Object, true, true);
+    }
+  }
+
+  // If we've still failed to prove we can sink the store, give up.
+  if (!SafeToInsertStore)
+    return false;
+
+  // Otherwise, this is safe to promote, let's do it!
+  DEBUG(dbgs() << "LICM: Promoting value stored to in loop: " << *SomePtr
+               << '\n');
+  ORE->emit([&]() {
+    return OptimizationRemark(DEBUG_TYPE, "PromoteLoopAccessesToScalar",
+                              LoopUses[0])
+           << "Moving accesses to memory location out of the loop";
+  });
+  ++NumPromoted;
+
+  // Grab a debug location for the inserted loads/stores; given that the
+  // inserted loads/stores have little relation to the original loads/stores,
+  // this code just arbitrarily picks a location from one, since any debug
+  // location is better than none.
+  DebugLoc DL = LoopUses[0]->getDebugLoc();
+
+  // We use the SSAUpdater interface to insert phi nodes as required.
+  SmallVector<PHINode *, 16> NewPHIs;
+  SSAUpdater SSA(&NewPHIs);
+  LoopPromoter Promoter(SomePtr, LoopUses, SSA, PointerMustAliases, ExitBlocks,
+                        InsertPts, PIC, *CurAST, *LI, DL, Alignment,
+                        SawUnorderedAtomic, AATags);
+
+  // Set up the preheader to have a definition of the value. It is the live-out
+  // value from the preheader that uses in the loop will use.
+  LoadInst *PreheaderLoad = new LoadInst(
+      SomePtr, SomePtr->getName() + ".promoted", Preheader->getTerminator());
+  if (SawUnorderedAtomic)
+    PreheaderLoad->setOrdering(AtomicOrdering::Unordered);
+  PreheaderLoad->setAlignment(Alignment);
+  PreheaderLoad->setDebugLoc(DL);
+  if (AATags)
+    PreheaderLoad->setAAMetadata(AATags);
+  SSA.AddAvailableValue(Preheader, PreheaderLoad);
+
+  // Rewrite all the loads in the loop and remember all the definitions from
+  // stores in the loop.
+  Promoter.run(LoopUses);
+
+  // If the SSAUpdater didn't use the load in the preheader, just zap it now.
+  if (PreheaderLoad->use_empty())
+    PreheaderLoad->eraseFromParent();
+
+  return true;
+}
+
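The SSAUpdater pattern the promoter builds on, shown in isolation (a minimal sketch with hypothetical values; LoadAndStorePromoter drives this same interface internally):

#include "llvm/ADT/SmallVector.h"
#include "llvm/Transforms/Utils/SSAUpdater.h"

static void rewireWithSSAUpdater(Type *Ty, BasicBlock *Preheader,
                                 Value *InitV, BasicBlock *ExitBB) {
  SmallVector<PHINode *, 16> NewPHIs;
  SSAUpdater SSA(&NewPHIs);
  SSA.Initialize(Ty, "promoted");
  // Register the definition that reaches the loop from the preheader, plus
  // (not shown) one AddAvailableValue per in-loop store of the location.
  SSA.AddAvailableValue(Preheader, InitV);
  // Ask for the live value in any block; SSAUpdater inserts whatever PHIs
  // are needed to merge the registered definitions on demand.
  Value *LiveOut = SSA.GetValueInMiddleOfBlock(ExitBB);
  (void)LiveOut;
}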
+/// Returns an owning pointer to an alias set which incorporates aliasing info
+/// from L and all subloops of L.
+/// FIXME: In new pass manager, there is no helper function to handle loop
+/// analysis such as cloneBasicBlockAnalysis, so the AST needs to be recomputed
+/// from scratch for every loop. Hook up with the helper functions when
+/// available in the new pass manager to avoid redundant computation.
+AliasSetTracker *
+LoopInvariantCodeMotion::collectAliasInfoForLoop(Loop *L, LoopInfo *LI,
+                                                 AliasAnalysis *AA) {
+  AliasSetTracker *CurAST = nullptr;
+  SmallVector<Loop *, 4> RecomputeLoops;
+  for (Loop *InnerL : L->getSubLoops()) {
+    auto MapI = LoopToAliasSetMap.find(InnerL);
+    // If the AST for this inner loop is missing it may have been merged into
+    // some other loop's AST and then that loop unrolled, and so we need to
+    // recompute it.
+    if (MapI == LoopToAliasSetMap.end()) {
+      RecomputeLoops.push_back(InnerL);
+      continue;
+    }
+    AliasSetTracker *InnerAST = MapI->second;
+
+    if (CurAST != nullptr) {
+      // What if InnerLoop was modified by other passes?
+      CurAST->add(*InnerAST);
+
+      // Once we've incorporated the inner loop's AST into ours, we don't need
+      // the subloop's anymore.
+      delete InnerAST;
+    } else {
+      CurAST = InnerAST;
+    }
+    LoopToAliasSetMap.erase(MapI);
+  }
+  if (CurAST == nullptr)
+    CurAST = new AliasSetTracker(*AA);
+
+  auto mergeLoop = [&](Loop *L) {
+    // Loop over the body of this loop, looking for calls, invokes, and stores.
+    for (BasicBlock *BB : L->blocks())
+      CurAST->add(*BB); // Incorporate the specified basic block
+  };
+
+  // Add everything from the sub loops that are no longer directly available.
+  for (Loop *InnerL : RecomputeLoops)
+    mergeLoop(InnerL);
+
+  // And merge in this loop.
+  mergeLoop(L);
+
+  return CurAST;
+}
+
+/// Simple analysis hook. Clone alias set info.
+///
+void LegacyLICMPass::cloneBasicBlockAnalysis(BasicBlock *From, BasicBlock *To,
+                                             Loop *L) {
+  AliasSetTracker *AST = LICM.getLoopToAliasSetMap().lookup(L);
+  if (!AST)
+    return;
+
+  AST->copyValue(From, To);
+}
+
+/// Simple analysis hook. Delete value V from alias set.
+///
+void LegacyLICMPass::deleteAnalysisValue(Value *V, Loop *L) {
+  AliasSetTracker *AST = LICM.getLoopToAliasSetMap().lookup(L);
+  if (!AST)
+    return;
+
+  AST->deleteValue(V);
+}
+
+/// Simple analysis hook. Delete value L from alias set map.
+///
+void LegacyLICMPass::deleteAnalysisLoop(Loop *L) {
+  AliasSetTracker *AST = LICM.getLoopToAliasSetMap().lookup(L);
+  if (!AST)
+    return;
+
+  delete AST;
+  LICM.getLoopToAliasSetMap().erase(L);
+}
+
+/// Return true if the body of this loop may store into the memory
+/// location pointed to by V.
+///
+static bool pointerInvalidatedByLoop(Value *V, uint64_t Size,
+                                     const AAMDNodes &AAInfo,
+                                     AliasSetTracker *CurAST) {
+  // Check to see if any of the basic blocks in CurLoop invalidate *V.
+  return CurAST->getAliasSetForPointer(V, Size, AAInfo).isMod();
+}
+
+/// Little predicate that returns true if the specified basic block is in
+/// a subloop of the current one, not the current one itself.
+///
+static bool inSubLoop(BasicBlock *BB, Loop *CurLoop, LoopInfo *LI) {
+  assert(CurLoop->contains(BB) && "Only valid if BB is IN the loop");
+  return LI->getLoopFor(BB) != CurLoop;
+}
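A minimal sketch of the AliasSetTracker queries the functions above rely on (hypothetical driver; both calls appear verbatim in collectAliasInfoForLoop and pointerInvalidatedByLoop):

static bool loopMayWriteTo(Loop *L, AliasAnalysis &AA, Value *Ptr,
                           uint64_t Size, const AAMDNodes &AAInfo) {
  AliasSetTracker AST(AA);
  for (BasicBlock *BB : L->blocks())
    AST.add(*BB); // fold each block's memory effects into the alias sets
  // isMod() is true if anything tracked may modify the pointed-to location.
  return AST.getAliasSetForPointer(Ptr, Size, AAInfo).isMod();
}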
diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp
new file mode 100644
index 000000000000..a64c99117d64
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/LoopAccessAnalysisPrinter.cpp
@@ -0,0 +1,25 @@
+//===- LoopAccessAnalysisPrinter.cpp - Loop Access Analysis Printer --------==//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LoopAccessAnalysisPrinter.h"
+#include "llvm/Analysis/LoopAccessAnalysis.h"
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-accesses"
+
+PreservedAnalyses
+LoopAccessInfoPrinterPass::run(Loop &L, LoopAnalysisManager &AM,
+                               LoopStandardAnalysisResults &AR, LPMUpdater &) {
+  Function &F = *L.getHeader()->getParent();
+  auto &LAI = AM.getResult<LoopAccessAnalysis>(L, AR);
+  OS << "Loop access info in function '" << F.getName() << "':\n";
+  OS.indent(2) << L.getHeader()->getName() << ":\n";
+  LAI.print(OS, 4);
+  return PreservedAnalyses::all();
+}
diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
new file mode 100644
index 000000000000..7f7c6de76450
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/LoopDataPrefetch.cpp
@@ -0,0 +1,336 @@
+//===-------- LoopDataPrefetch.cpp - Loop Data Prefetching Pass -----------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a Loop Data Prefetching Pass.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LoopDataPrefetch.h"
+
+#define DEBUG_TYPE "loop-data-prefetch"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/ValueMapper.h"
+using namespace llvm;
+
+// By default, we limit this to creating 16 PHIs (which is a little over half
+// of the allocatable register set).
+static cl::opt<bool>
+PrefetchWrites("loop-prefetch-writes", cl::Hidden, cl::init(false),
+               cl::desc("Prefetch write addresses"));
+
+static cl::opt<unsigned>
+    PrefetchDistance("prefetch-distance",
+                     cl::desc("Number of instructions to prefetch ahead"),
+                     cl::Hidden);
+
+static cl::opt<unsigned>
+    MinPrefetchStride("min-prefetch-stride",
+                      cl::desc("Min stride to add prefetches"), cl::Hidden);
+
+static cl::opt<unsigned> MaxPrefetchIterationsAhead(
+    "max-prefetch-iters-ahead",
+    cl::desc("Max number of iterations to prefetch ahead"), cl::Hidden);
+
+STATISTIC(NumPrefetches, "Number of prefetches inserted");
+
+namespace {
+
+/// Loop prefetch implementation class.
+class LoopDataPrefetch {
+public:
+  LoopDataPrefetch(AssumptionCache *AC, LoopInfo *LI, ScalarEvolution *SE,
+                   const TargetTransformInfo *TTI,
+                   OptimizationRemarkEmitter *ORE)
+      : AC(AC), LI(LI), SE(SE), TTI(TTI), ORE(ORE) {}
+
+  bool run();
+
+private:
+  bool runOnLoop(Loop *L);
+
+  /// \brief Check if the stride of the accesses is large enough to
+  /// warrant a prefetch.
+ bool isStrideLargeEnough(const SCEVAddRecExpr *AR); + + unsigned getMinPrefetchStride() { + if (MinPrefetchStride.getNumOccurrences() > 0) + return MinPrefetchStride; + return TTI->getMinPrefetchStride(); + } + + unsigned getPrefetchDistance() { + if (PrefetchDistance.getNumOccurrences() > 0) + return PrefetchDistance; + return TTI->getPrefetchDistance(); + } + + unsigned getMaxPrefetchIterationsAhead() { + if (MaxPrefetchIterationsAhead.getNumOccurrences() > 0) + return MaxPrefetchIterationsAhead; + return TTI->getMaxPrefetchIterationsAhead(); + } + + AssumptionCache *AC; + LoopInfo *LI; + ScalarEvolution *SE; + const TargetTransformInfo *TTI; + OptimizationRemarkEmitter *ORE; +}; + +/// Legacy class for inserting loop data prefetches. +class LoopDataPrefetchLegacyPass : public FunctionPass { +public: + static char ID; // Pass ID, replacement for typeid + LoopDataPrefetchLegacyPass() : FunctionPass(ID) { + initializeLoopDataPrefetchLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.addPreserved<ScalarEvolutionWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + } + + bool runOnFunction(Function &F) override; + }; +} + +char LoopDataPrefetchLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopDataPrefetchLegacyPass, "loop-data-prefetch", + "Loop Data Prefetch", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_END(LoopDataPrefetchLegacyPass, "loop-data-prefetch", + "Loop Data Prefetch", false, false) + +FunctionPass *llvm::createLoopDataPrefetchPass() { + return new LoopDataPrefetchLegacyPass(); +} + +bool LoopDataPrefetch::isStrideLargeEnough(const SCEVAddRecExpr *AR) { + unsigned TargetMinStride = getMinPrefetchStride(); + // No need to check if any stride goes. + if (TargetMinStride <= 1) + return true; + + const auto *ConstStride = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE)); + // If MinStride is set, don't prefetch unless we can ensure that stride is + // larger. 
+ if (!ConstStride) + return false; + + unsigned AbsStride = std::abs(ConstStride->getAPInt().getSExtValue()); + return TargetMinStride <= AbsStride; +} + +PreservedAnalyses LoopDataPrefetchPass::run(Function &F, + FunctionAnalysisManager &AM) { + LoopInfo *LI = &AM.getResult<LoopAnalysis>(F); + ScalarEvolution *SE = &AM.getResult<ScalarEvolutionAnalysis>(F); + AssumptionCache *AC = &AM.getResult<AssumptionAnalysis>(F); + OptimizationRemarkEmitter *ORE = + &AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + const TargetTransformInfo *TTI = &AM.getResult<TargetIRAnalysis>(F); + + LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE); + bool Changed = LDP.run(); + + if (Changed) { + PreservedAnalyses PA; + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<LoopAnalysis>(); + return PA; + } + + return PreservedAnalyses::all(); +} + +bool LoopDataPrefetchLegacyPass::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + + LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + AssumptionCache *AC = + &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + OptimizationRemarkEmitter *ORE = + &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); + const TargetTransformInfo *TTI = + &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + + LoopDataPrefetch LDP(AC, LI, SE, TTI, ORE); + return LDP.run(); +} + +bool LoopDataPrefetch::run() { + // If PrefetchDistance is not set, don't run the pass. This gives an + // opportunity for targets to run this pass for selected subtargets only + // (whose TTI sets PrefetchDistance). + if (getPrefetchDistance() == 0) + return false; + assert(TTI->getCacheLineSize() && "Cache line size is not set for target"); + + bool MadeChange = false; + + for (Loop *I : *LI) + for (auto L = df_begin(I), LE = df_end(I); L != LE; ++L) + MadeChange |= runOnLoop(*L); + + return MadeChange; +} + +bool LoopDataPrefetch::runOnLoop(Loop *L) { + bool MadeChange = false; + + // Only prefetch in the inner-most loop + if (!L->empty()) + return MadeChange; + + SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(L, AC, EphValues); + + // Calculate the number of iterations ahead to prefetch + CodeMetrics Metrics; + for (const auto BB : L->blocks()) { + // If the loop already has prefetches, then assume that the user knows + // what they are doing and don't add any more. 
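// (A worked example of the distance computation further below, with
// hypothetical numbers: if the target reports getPrefetchDistance() == 300
// instructions and the loop body has LoopSize == 25 instructions, then
//
//   ItersAhead = 300 / 25 = 12,
//
// and each strided access AddRec gets a prefetch of the address
//
//   NextLSCEV = AddRec + 12 * Step,
//
// i.e. the cache line the access will touch twelve iterations from now,
// subject to the getMaxPrefetchIterationsAhead() cap.)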
+    for (auto &I : *BB)
+      if (CallInst *CI = dyn_cast<CallInst>(&I))
+        if (Function *F = CI->getCalledFunction())
+          if (F->getIntrinsicID() == Intrinsic::prefetch)
+            return MadeChange;
+
+    Metrics.analyzeBasicBlock(BB, *TTI, EphValues);
+  }
+  unsigned LoopSize = Metrics.NumInsts;
+  if (!LoopSize)
+    LoopSize = 1;
+
+  unsigned ItersAhead = getPrefetchDistance() / LoopSize;
+  if (!ItersAhead)
+    ItersAhead = 1;
+
+  if (ItersAhead > getMaxPrefetchIterationsAhead())
+    return MadeChange;
+
+  DEBUG(dbgs() << "Prefetching " << ItersAhead
+               << " iterations ahead (loop size: " << LoopSize << ") in "
+               << L->getHeader()->getParent()->getName() << ": " << *L);
+
+  SmallVector<std::pair<Instruction *, const SCEVAddRecExpr *>, 16> PrefLoads;
+  for (const auto BB : L->blocks()) {
+    for (auto &I : *BB) {
+      Value *PtrValue;
+      Instruction *MemI;
+
+      if (LoadInst *LMemI = dyn_cast<LoadInst>(&I)) {
+        MemI = LMemI;
+        PtrValue = LMemI->getPointerOperand();
+      } else if (StoreInst *SMemI = dyn_cast<StoreInst>(&I)) {
+        if (!PrefetchWrites) continue;
+        MemI = SMemI;
+        PtrValue = SMemI->getPointerOperand();
+      } else continue;
+
+      unsigned PtrAddrSpace = PtrValue->getType()->getPointerAddressSpace();
+      if (PtrAddrSpace)
+        continue;
+
+      if (L->isLoopInvariant(PtrValue))
+        continue;
+
+      const SCEV *LSCEV = SE->getSCEV(PtrValue);
+      const SCEVAddRecExpr *LSCEVAddRec = dyn_cast<SCEVAddRecExpr>(LSCEV);
+      if (!LSCEVAddRec)
+        continue;
+
+      // Check if the stride of the accesses is large enough to warrant a
+      // prefetch.
+      if (!isStrideLargeEnough(LSCEVAddRec))
+        continue;
+
+      // We don't want to double prefetch individual cache lines. If this load
+      // is known to be within one cache line of some other load that has
+      // already been prefetched, then don't prefetch this one as well.
+      bool DupPref = false;
+      for (const auto &PrefLoad : PrefLoads) {
+        const SCEV *PtrDiff = SE->getMinusSCEV(LSCEVAddRec, PrefLoad.second);
+        if (const SCEVConstant *ConstPtrDiff =
+                dyn_cast<SCEVConstant>(PtrDiff)) {
+          int64_t PD = std::abs(ConstPtrDiff->getValue()->getSExtValue());
+          if (PD < (int64_t) TTI->getCacheLineSize()) {
+            DupPref = true;
+            break;
+          }
+        }
+      }
+      if (DupPref)
+        continue;
+
+      const SCEV *NextLSCEV = SE->getAddExpr(LSCEVAddRec, SE->getMulExpr(
+        SE->getConstant(LSCEVAddRec->getType(), ItersAhead),
+        LSCEVAddRec->getStepRecurrence(*SE)));
+      if (!isSafeToExpand(NextLSCEV, *SE))
+        continue;
+
+      PrefLoads.push_back(std::make_pair(MemI, LSCEVAddRec));
+
+      Type *I8Ptr = Type::getInt8PtrTy(BB->getContext(), PtrAddrSpace);
+      SCEVExpander SCEVE(*SE, I.getModule()->getDataLayout(), "prefaddr");
+      Value *PrefPtrValue = SCEVE.expandCodeFor(NextLSCEV, I8Ptr, MemI);
+
+      IRBuilder<> Builder(MemI);
+      Module *M = BB->getParent()->getParent();
+      Type *I32 = Type::getInt32Ty(BB->getContext());
+      Value *PrefetchFunc = Intrinsic::getDeclaration(M, Intrinsic::prefetch);
+      Builder.CreateCall(
+          PrefetchFunc,
+          {PrefPtrValue,
+           ConstantInt::get(I32, MemI->mayReadFromMemory() ?
0 : 1), + ConstantInt::get(I32, 3), ConstantInt::get(I32, 1)}); + ++NumPrefetches; + DEBUG(dbgs() << " Access: " << *PtrValue << ", SCEV: " << *LSCEV + << "\n"); + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "Prefetched", MemI) + << "prefetched memory access"; + }); + + MadeChange = true; + } + } + + return MadeChange; +} + diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopDeletion.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopDeletion.cpp new file mode 100644 index 000000000000..15cd1086f209 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopDeletion.cpp @@ -0,0 +1,267 @@ +//===- LoopDeletion.cpp - Dead Loop Deletion Pass ---------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Dead Loop Deletion Pass. This pass is responsible +// for eliminating loops with non-infinite computable trip counts that have no +// side effects or volatile instructions, and do not contribute to the +// computation of the function's return value. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopDeletion.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +using namespace llvm; + +#define DEBUG_TYPE "loop-delete" + +STATISTIC(NumDeleted, "Number of loops deleted"); + +enum class LoopDeletionResult { + Unmodified, + Modified, + Deleted, +}; + +/// Determines if a loop is dead. +/// +/// This assumes that we've already checked for unique exit and exiting blocks, +/// and that the code is in LCSSA form. +static bool isLoopDead(Loop *L, ScalarEvolution &SE, + SmallVectorImpl<BasicBlock *> &ExitingBlocks, + BasicBlock *ExitBlock, bool &Changed, + BasicBlock *Preheader) { + // Make sure that all PHI entries coming from the loop are loop invariant. + // Because the code is in LCSSA form, any values used outside of the loop + // must pass through a PHI in the exit block, meaning that this check is + // sufficient to guarantee that no loop-variant values are used outside + // of the loop. + bool AllEntriesInvariant = true; + bool AllOutgoingValuesSame = true; + for (PHINode &P : ExitBlock->phis()) { + Value *incoming = P.getIncomingValueForBlock(ExitingBlocks[0]); + + // Make sure all exiting blocks produce the same incoming value for the exit + // block. If there are different incoming values for different exiting + // blocks, then it is impossible to statically determine which value should + // be used. + AllOutgoingValuesSame = + all_of(makeArrayRef(ExitingBlocks).slice(1), [&](BasicBlock *BB) { + return incoming == P.getIncomingValueForBlock(BB); + }); + + if (!AllOutgoingValuesSame) + break; + + if (Instruction *I = dyn_cast<Instruction>(incoming)) + if (!L->makeLoopInvariant(I, Changed, Preheader->getTerminator())) { + AllEntriesInvariant = false; + break; + } + } + + if (Changed) + SE.forgetLoopDispositions(L); + + if (!AllEntriesInvariant || !AllOutgoingValuesSame) + return false; + + // Make sure that no instructions in the block have potential side-effects. 
+ // This includes instructions that could write to memory, and loads that are + // marked volatile. + for (auto &I : L->blocks()) + if (any_of(*I, [](Instruction &I) { return I.mayHaveSideEffects(); })) + return false; + return true; +} + +/// This function returns true if there is no viable path from the +/// entry block to the header of \p L. Right now, it only does +/// a local search to save compile time. +static bool isLoopNeverExecuted(Loop *L) { + using namespace PatternMatch; + + auto *Preheader = L->getLoopPreheader(); + // TODO: We can relax this constraint, since we just need a loop + // predecessor. + assert(Preheader && "Needs preheader!"); + + if (Preheader == &Preheader->getParent()->getEntryBlock()) + return false; + // All predecessors of the preheader should have a constant conditional + // branch, with the loop's preheader as not-taken. + for (auto *Pred: predecessors(Preheader)) { + BasicBlock *Taken, *NotTaken; + ConstantInt *Cond; + if (!match(Pred->getTerminator(), + m_Br(m_ConstantInt(Cond), Taken, NotTaken))) + return false; + if (!Cond->getZExtValue()) + std::swap(Taken, NotTaken); + if (Taken == Preheader) + return false; + } + assert(!pred_empty(Preheader) && + "Preheader should have predecessors at this point!"); + // All the predecessors have the loop preheader as not-taken target. + return true; +} + +/// Remove a loop if it is dead. +/// +/// A loop is considered dead if it does not impact the observable behavior of +/// the program other than finite running time. This never removes a loop that +/// might be infinite (unless it is never executed), as doing so could change +/// the halting/non-halting nature of a program. +/// +/// This entire process relies pretty heavily on LoopSimplify form and LCSSA in +/// order to make various safety checks work. +/// +/// \returns true if any changes were made. This may mutate the loop even if it +/// is unable to delete it due to hoisting trivially loop invariant +/// instructions out of the loop. +static LoopDeletionResult deleteLoopIfDead(Loop *L, DominatorTree &DT, + ScalarEvolution &SE, LoopInfo &LI) { + assert(L->isLCSSAForm(DT) && "Expected LCSSA!"); + + // We can only remove the loop if there is a preheader that we can branch from + // after removing it. Also, if LoopSimplify form is not available, stay out + // of trouble. + BasicBlock *Preheader = L->getLoopPreheader(); + if (!Preheader || !L->hasDedicatedExits()) { + DEBUG(dbgs() + << "Deletion requires Loop with preheader and dedicated exits.\n"); + return LoopDeletionResult::Unmodified; + } + // We can't remove loops that contain subloops. If the subloops were dead, + // they would already have been removed in earlier executions of this pass. + if (L->begin() != L->end()) { + DEBUG(dbgs() << "Loop contains subloops.\n"); + return LoopDeletionResult::Unmodified; + } + + + BasicBlock *ExitBlock = L->getUniqueExitBlock(); + + if (ExitBlock && isLoopNeverExecuted(L)) { + DEBUG(dbgs() << "Loop is proven to never execute, delete it!"); + // Set incoming value to undef for phi nodes in the exit block. + for (PHINode &P : ExitBlock->phis()) { + std::fill(P.incoming_values().begin(), P.incoming_values().end(), + UndefValue::get(P.getType())); + } + deleteDeadLoop(L, &DT, &SE, &LI); + ++NumDeleted; + return LoopDeletionResult::Deleted; + } + + // The remaining checks below are for a loop being dead because all statements + // in the loop are invariant. 
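// (Two hypothetical examples of the boundary being drawn here:
//
//   for (int i = 0; i < n; ++i) { int t = a + i; }
//
// is deletable: the exit PHIs are loop invariant, nothing in the body
// mayHaveSideEffects(), and SCEV can bound the trip count. By contrast,
//
//   while (x != 1) x = (x & 1) ? 3 * x + 1 : x / 2;
//
// has no computable max backedge-taken count, so it is kept even though it
// is side-effect free: deleting it could change halting behavior.)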
+ SmallVector<BasicBlock *, 4> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + + // We require that the loop only have a single exit block. Otherwise, we'd + // be in the situation of needing to be able to solve statically which exit + // block will be branched to, or trying to preserve the branching logic in + // a loop invariant manner. + if (!ExitBlock) { + DEBUG(dbgs() << "Deletion requires single exit block\n"); + return LoopDeletionResult::Unmodified; + } + // Finally, we have to check that the loop really is dead. + bool Changed = false; + if (!isLoopDead(L, SE, ExitingBlocks, ExitBlock, Changed, Preheader)) { + DEBUG(dbgs() << "Loop is not invariant, cannot delete.\n"); + return Changed ? LoopDeletionResult::Modified + : LoopDeletionResult::Unmodified; + } + + // Don't remove loops for which we can't solve the trip count. + // They could be infinite, in which case we'd be changing program behavior. + const SCEV *S = SE.getMaxBackedgeTakenCount(L); + if (isa<SCEVCouldNotCompute>(S)) { + DEBUG(dbgs() << "Could not compute SCEV MaxBackedgeTakenCount.\n"); + return Changed ? LoopDeletionResult::Modified + : LoopDeletionResult::Unmodified; + } + + DEBUG(dbgs() << "Loop is invariant, delete it!"); + deleteDeadLoop(L, &DT, &SE, &LI); + ++NumDeleted; + + return LoopDeletionResult::Deleted; +} + +PreservedAnalyses LoopDeletionPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &Updater) { + + DEBUG(dbgs() << "Analyzing Loop for deletion: "); + DEBUG(L.dump()); + std::string LoopName = L.getName(); + auto Result = deleteLoopIfDead(&L, AR.DT, AR.SE, AR.LI); + if (Result == LoopDeletionResult::Unmodified) + return PreservedAnalyses::all(); + + if (Result == LoopDeletionResult::Deleted) + Updater.markLoopAsDeleted(L, LoopName); + + return getLoopPassPreservedAnalyses(); +} + +namespace { +class LoopDeletionLegacyPass : public LoopPass { +public: + static char ID; // Pass ID, replacement for typeid + LoopDeletionLegacyPass() : LoopPass(ID) { + initializeLoopDeletionLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + // Possibly eliminate loop L if it is dead. 
+ bool runOnLoop(Loop *L, LPPassManager &) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + getLoopAnalysisUsage(AU); + } +}; +} + +char LoopDeletionLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopDeletionLegacyPass, "loop-deletion", + "Delete dead loops", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_END(LoopDeletionLegacyPass, "loop-deletion", + "Delete dead loops", false, false) + +Pass *llvm::createLoopDeletionPass() { return new LoopDeletionLegacyPass(); } + +bool LoopDeletionLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { + if (skipLoop(L)) + return false; + DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + + DEBUG(dbgs() << "Analyzing Loop for deletion: "); + DEBUG(L->dump()); + + LoopDeletionResult Result = deleteLoopIfDead(L, DT, SE, LI); + + if (Result == LoopDeletionResult::Deleted) + LPM.markLoopAsDeleted(*L); + + return Result != LoopDeletionResult::Unmodified; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopDistribute.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopDistribute.cpp new file mode 100644 index 000000000000..0d7e3db901cb --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopDistribute.cpp @@ -0,0 +1,1025 @@ +//===- LoopDistribute.cpp - Loop Distribution Pass ------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Loop Distribution Pass. Its main focus is to +// distribute loops that cannot be vectorized due to dependence cycles. It +// tries to isolate the offending dependences into a new loop allowing +// vectorization of the remaining parts. +// +// For dependence analysis, the pass uses the LoopVectorizer's +// LoopAccessAnalysis. Because this analysis presumes no change in the order of +// memory operations, special care is taken to preserve the lexical order of +// these operations. +// +// Similarly to the Vectorizer, the pass also supports loop versioning to +// run-time disambiguate potentially overlapping arrays. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopDistribute.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/EquivalenceClasses.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopAccessAnalysis.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/LoopVersioning.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include <cassert> +#include <functional> +#include <list> +#include <tuple> +#include <utility> + +using namespace llvm; + +#define LDIST_NAME "loop-distribute" +#define DEBUG_TYPE LDIST_NAME + +static cl::opt<bool> + LDistVerify("loop-distribute-verify", cl::Hidden, + cl::desc("Turn on DominatorTree and LoopInfo verification " + "after Loop Distribution"), + cl::init(false)); + +static cl::opt<bool> DistributeNonIfConvertible( + "loop-distribute-non-if-convertible", cl::Hidden, + cl::desc("Whether to distribute into a loop that may not be " + "if-convertible by the loop vectorizer"), + cl::init(false)); + +static cl::opt<unsigned> DistributeSCEVCheckThreshold( + "loop-distribute-scev-check-threshold", cl::init(8), cl::Hidden, + cl::desc("The maximum number of SCEV checks allowed for Loop " + "Distribution")); + +static cl::opt<unsigned> PragmaDistributeSCEVCheckThreshold( + "loop-distribute-scev-check-threshold-with-pragma", cl::init(128), + cl::Hidden, + cl::desc( + "The maximum number of SCEV checks allowed for Loop " + "Distribution for loop marked with #pragma loop distribute(enable)")); + +static cl::opt<bool> EnableLoopDistribute( + "enable-loop-distribute", cl::Hidden, + cl::desc("Enable the new, experimental LoopDistribution Pass"), + cl::init(false)); + +STATISTIC(NumLoopsDistributed, "Number of loops distributed"); + +namespace { + +/// \brief Maintains the set of instructions of the loop for a partition before +/// cloning. After cloning, it hosts the new loop. 
+class InstPartition {
+  using InstructionSet = SmallPtrSet<Instruction *, 8>;
+
+public:
+  InstPartition(Instruction *I, Loop *L, bool DepCycle = false)
+      : DepCycle(DepCycle), OrigLoop(L) {
+    Set.insert(I);
+  }
+
+  /// \brief Returns whether this partition contains a dependence cycle.
+  bool hasDepCycle() const { return DepCycle; }
+
+  /// \brief Adds an instruction to this partition.
+  void add(Instruction *I) { Set.insert(I); }
+
+  /// \brief Collection accessors.
+  InstructionSet::iterator begin() { return Set.begin(); }
+  InstructionSet::iterator end() { return Set.end(); }
+  InstructionSet::const_iterator begin() const { return Set.begin(); }
+  InstructionSet::const_iterator end() const { return Set.end(); }
+  bool empty() const { return Set.empty(); }
+
+  /// \brief Moves this partition into \p Other. This partition becomes empty
+  /// after this.
+  void moveTo(InstPartition &Other) {
+    Other.Set.insert(Set.begin(), Set.end());
+    Set.clear();
+    Other.DepCycle |= DepCycle;
+  }
+
+  /// \brief Populates the partition with a transitive closure of all the
+  /// instructions that the seeded instructions depend on.
+  void populateUsedSet() {
+    // FIXME: We currently don't use control-dependence but simply include all
+    // blocks (possibly empty at the end) and let simplifycfg mostly clean this
+    // up.
+    for (auto *B : OrigLoop->getBlocks())
+      Set.insert(B->getTerminator());
+
+    // Follow the use-def chains to form a transitive closure of all the
+    // instructions that the originally seeded instructions depend on.
+    SmallVector<Instruction *, 8> Worklist(Set.begin(), Set.end());
+    while (!Worklist.empty()) {
+      Instruction *I = Worklist.pop_back_val();
+      // Insert instructions from the loop that we depend on.
+      for (Value *V : I->operand_values()) {
+        auto *I = dyn_cast<Instruction>(V);
+        if (I && OrigLoop->contains(I->getParent()) && Set.insert(I).second)
+          Worklist.push_back(I);
+      }
+    }
+  }
+
+  /// \brief Clones the original loop.
+  ///
+  /// Updates LoopInfo and DominatorTree using the information that block \p
+  /// LoopDomBB dominates the loop.
+  Loop *cloneLoopWithPreheader(BasicBlock *InsertBefore, BasicBlock *LoopDomBB,
+                               unsigned Index, LoopInfo *LI,
+                               DominatorTree *DT) {
+    ClonedLoop = ::cloneLoopWithPreheader(InsertBefore, LoopDomBB, OrigLoop,
+                                          VMap, Twine(".ldist") + Twine(Index),
+                                          LI, DT, ClonedLoopBlocks);
+    return ClonedLoop;
+  }
+
+  /// \brief The cloned loop. If this partition is mapped to the original loop,
+  /// this is null.
+  const Loop *getClonedLoop() const { return ClonedLoop; }
+
+  /// \brief Returns the loop where this partition ends up after distribution.
+  /// If this partition is mapped to the original loop then use the block from
+  /// the loop.
+  const Loop *getDistributedLoop() const {
+    return ClonedLoop ? ClonedLoop : OrigLoop;
+  }
+
+  /// \brief The VMap that is populated by cloning and then used in
+  /// remapInstructions() to remap the cloned instructions.
+  ValueToValueMapTy &getVMap() { return VMap; }
+
+  /// \brief Remaps the cloned instructions using VMap.
+  void remapInstructions() {
+    remapInstructionsInBlocks(ClonedLoopBlocks, VMap);
+  }
+
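// (Stepping back from the partition plumbing: the overall effect of
// distribution, as a hypothetical C example in the spirit of the file
// header. A dependence cycle through a[] blocks vectorization of the whole
// loop, so the cycle is isolated into its own loop, cloned with the
// ".ldist" suffix used above:
//
//   for (i = 1; i < n; i++) {       for (i = 1; i < n; i++)
//     a[i] = a[i - 1] + b[i];  =>     a[i] = a[i - 1] + b[i];  // cycle
//     c[i] = d[i] * 2;              for (i = 1; i < n; i++)
//   }                                 c[i] = d[i] * 2;  // vectorizable
// )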
+  /// \brief Based on the set of instructions selected for this partition,
+  /// removes the unnecessary ones.
+  void removeUnusedInsts() {
+    SmallVector<Instruction *, 8> Unused;
+
+    for (auto *Block : OrigLoop->getBlocks())
+      for (auto &Inst : *Block)
+        if (!Set.count(&Inst)) {
+          Instruction *NewInst = &Inst;
+          if (!VMap.empty())
+            NewInst = cast<Instruction>(VMap[NewInst]);
+
+          assert(!isa<BranchInst>(NewInst) &&
+                 "Branches are marked used early on");
+          Unused.push_back(NewInst);
+        }
+
+    // Delete the instructions backwards, as it has a reduced likelihood of
+    // having to update as many def-use and use-def chains.
+    for (auto *Inst : reverse(Unused)) {
+      if (!Inst->use_empty())
+        Inst->replaceAllUsesWith(UndefValue::get(Inst->getType()));
+      Inst->eraseFromParent();
+    }
+  }
+
+  void print() const {
+    if (DepCycle)
+      dbgs() << "  (cycle)\n";
+    for (auto *I : Set)
+      // Prefix with the block name.
+      dbgs() << "  " << I->getParent()->getName() << ":" << *I << "\n";
+  }
+
+  void printBlocks() const {
+    for (auto *BB : getDistributedLoop()->getBlocks())
+      dbgs() << *BB;
+  }
+
+private:
+  /// \brief Instructions from OrigLoop selected for this partition.
+  InstructionSet Set;
+
+  /// \brief Whether this partition contains a dependence cycle.
+  bool DepCycle;
+
+  /// \brief The original loop.
+  Loop *OrigLoop;
+
+  /// \brief The cloned loop. If this partition is mapped to the original loop,
+  /// this is null.
+  Loop *ClonedLoop = nullptr;
+
+  /// \brief The blocks of ClonedLoop including the preheader. If this
+  /// partition is mapped to the original loop, this is empty.
+  SmallVector<BasicBlock *, 8> ClonedLoopBlocks;
+
+  /// \brief This gets populated once the set of instructions has been
+  /// finalized. If this partition is mapped to the original loop, it is not
+  /// set.
+  ValueToValueMapTy VMap;
+};
+
+/// \brief Holds the set of Partitions. It populates them, merges them and then
+/// clones the loops.
+class InstPartitionContainer {
+  using InstToPartitionIdT = DenseMap<Instruction *, int>;
+
+public:
+  InstPartitionContainer(Loop *L, LoopInfo *LI, DominatorTree *DT)
+      : L(L), LI(LI), DT(DT) {}
+
+  /// \brief Returns the number of partitions.
+  unsigned getSize() const { return PartitionContainer.size(); }
+
+  /// \brief Adds \p Inst into the current partition if that is marked to
+  /// contain cycles. Otherwise start a new partition for it.
+  void addToCyclicPartition(Instruction *Inst) {
+    // If the current partition is non-cyclic, start a new one.
+    if (PartitionContainer.empty() || !PartitionContainer.back().hasDepCycle())
+      PartitionContainer.emplace_back(Inst, L, /*DepCycle=*/true);
+    else
+      PartitionContainer.back().add(Inst);
+  }
+
+  /// \brief Adds \p Inst into a partition that is not marked to contain
+  /// dependence cycles.
+  ///
+  // Initially we isolate memory instructions into as many partitions as
+  // possible, then later we may merge them back together.
+  void addToNewNonCyclicPartition(Instruction *Inst) {
+    PartitionContainer.emplace_back(Inst, L);
+  }
+
+  /// \brief Merges adjacent non-cyclic partitions.
+  ///
+  /// The idea is that we currently only want to isolate the non-vectorizable
+  /// partition. We could later allow more distribution among these partitions
+  /// too.
+  void mergeAdjacentNonCyclic() {
+    mergeAdjacentPartitionsIf(
+        [](const InstPartition *P) { return !P->hasDepCycle(); });
+  }
+
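// (mergeToAvoidDuplicatedLoads below leans on llvm::EquivalenceClasses<T>;
// a minimal sketch of that API, with hypothetical values:
//
//   EquivalenceClasses<int> EC;
//   EC.unionSets(1, 2);                  // class {1, 2}
//   EC.unionSets(2, 3);                  // class {1, 2, 3}
//   for (auto I = EC.begin(), E = EC.end(); I != E; ++I)
//     if (I->isLeader())
//       for (auto M = EC.member_begin(I), ME = EC.member_end(); M != ME; ++M)
//         ;                              // *M visits each member once
//
// unionSets merges the classes containing its two arguments; iterating
// leaders and then member ranges enumerates each class exactly once.)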
+  /// \brief If a partition contains only conditional stores, we won't vectorize
+  /// it. Try to merge it with a previous cyclic partition.
+  void mergeNonIfConvertible() {
+    mergeAdjacentPartitionsIf([&](const InstPartition *Partition) {
+      if (Partition->hasDepCycle())
+        return true;
+
+      // Now, check if all stores are conditional in this partition.
+      bool seenStore = false;
+
+      for (auto *Inst : *Partition)
+        if (isa<StoreInst>(Inst)) {
+          seenStore = true;
+          if (!LoopAccessInfo::blockNeedsPredication(Inst->getParent(), L, DT))
+            return false;
+        }
+      return seenStore;
+    });
+  }
+
+  /// \brief Merges the partitions according to various heuristics.
+  void mergeBeforePopulating() {
+    mergeAdjacentNonCyclic();
+    if (!DistributeNonIfConvertible)
+      mergeNonIfConvertible();
+  }
+
+  /// \brief Merges partitions in order to ensure that no loads are duplicated.
+  ///
+  /// We can't duplicate loads because that could potentially reorder them.
+  /// LoopAccessAnalysis provides dependency information with the context that
+  /// the order of memory operations is preserved.
+  ///
+  /// Returns true if any partitions were merged.
+  bool mergeToAvoidDuplicatedLoads() {
+    using LoadToPartitionT = DenseMap<Instruction *, InstPartition *>;
+    using ToBeMergedT = EquivalenceClasses<InstPartition *>;
+
+    LoadToPartitionT LoadToPartition;
+    ToBeMergedT ToBeMerged;
+
+    // Step through the partitions and create equivalences between partitions
+    // that contain the same load. Also put partitions in between them in the
+    // same equivalence class to avoid reordering of memory operations.
+    for (PartitionContainerT::iterator I = PartitionContainer.begin(),
+                                       E = PartitionContainer.end();
+         I != E; ++I) {
+      auto *PartI = &*I;
+
+      // If a load occurs in two partitions PartI and PartJ, merge all
+      // partitions (PartI, PartJ] into PartI.
+      for (Instruction *Inst : *PartI)
+        if (isa<LoadInst>(Inst)) {
+          bool NewElt;
+          LoadToPartitionT::iterator LoadToPart;
+
+          std::tie(LoadToPart, NewElt) =
+              LoadToPartition.insert(std::make_pair(Inst, PartI));
+          if (!NewElt) {
+            DEBUG(dbgs() << "Merging partitions due to this load in multiple "
+                         << "partitions: " << PartI << ", "
+                         << LoadToPart->second << "\n" << *Inst << "\n");
+
+            auto PartJ = I;
+            do {
+              --PartJ;
+              ToBeMerged.unionSets(PartI, &*PartJ);
+            } while (&*PartJ != LoadToPart->second);
+          }
+        }
+    }
+    if (ToBeMerged.empty())
+      return false;
+
+    // Merge the members of an equivalence class into its class leader. This
+    // makes the members empty.
+    for (ToBeMergedT::iterator I = ToBeMerged.begin(), E = ToBeMerged.end();
+         I != E; ++I) {
+      if (!I->isLeader())
+        continue;
+
+      auto PartI = I->getData();
+      for (auto PartJ : make_range(std::next(ToBeMerged.member_begin(I)),
+                                   ToBeMerged.member_end())) {
+        PartJ->moveTo(*PartI);
+      }
+    }
+
+    // Remove the empty partitions.
+    PartitionContainer.remove_if(
+        [](const InstPartition &P) { return P.empty(); });
+
+    return true;
+  }
+
+  /// \brief Sets up the mapping from instructions to partitions. If the
+  /// instruction is duplicated across multiple partitions, set the entry to -1.
+  void setupPartitionIdOnInstructions() {
+    int PartitionID = 0;
+    for (const auto &Partition : PartitionContainer) {
+      for (Instruction *Inst : Partition) {
+        bool NewElt;
+        InstToPartitionIdT::iterator Iter;
+
+        std::tie(Iter, NewElt) =
+            InstToPartitionId.insert(std::make_pair(Inst, PartitionID));
+        if (!NewElt)
+          Iter->second = -1;
+      }
+      ++PartitionID;
+    }
+  }
+
+  /// \brief Populates the partition with everything that the seeding
+  /// instructions require.
+ void populateUsedSet() { + for (auto &P : PartitionContainer) + P.populateUsedSet(); + } + + /// \brief This performs the main chunk of the work of cloning the loops for + /// the partitions. + void cloneLoops() { + BasicBlock *OrigPH = L->getLoopPreheader(); + // At this point the predecessor of the preheader is either the memcheck + // block or the top part of the original preheader. + BasicBlock *Pred = OrigPH->getSinglePredecessor(); + assert(Pred && "Preheader does not have a single predecessor"); + BasicBlock *ExitBlock = L->getExitBlock(); + assert(ExitBlock && "No single exit block"); + Loop *NewLoop; + + assert(!PartitionContainer.empty() && "at least two partitions expected"); + // We're cloning the preheader along with the loop so we already made sure + // it was empty. + assert(&*OrigPH->begin() == OrigPH->getTerminator() && + "preheader not empty"); + + // Create a loop for each partition except the last. Clone the original + // loop before PH along with adding a preheader for the cloned loop. Then + // update PH to point to the newly added preheader. + BasicBlock *TopPH = OrigPH; + unsigned Index = getSize() - 1; + for (auto I = std::next(PartitionContainer.rbegin()), + E = PartitionContainer.rend(); + I != E; ++I, --Index, TopPH = NewLoop->getLoopPreheader()) { + auto *Part = &*I; + + NewLoop = Part->cloneLoopWithPreheader(TopPH, Pred, Index, LI, DT); + + Part->getVMap()[ExitBlock] = TopPH; + Part->remapInstructions(); + } + Pred->getTerminator()->replaceUsesOfWith(OrigPH, TopPH); + + // Now go in forward order and update the immediate dominator for the + // preheaders with the exiting block of the previous loop. Dominance + // within the loop is updated in cloneLoopWithPreheader. + for (auto Curr = PartitionContainer.cbegin(), + Next = std::next(PartitionContainer.cbegin()), + E = PartitionContainer.cend(); + Next != E; ++Curr, ++Next) + DT->changeImmediateDominator( + Next->getDistributedLoop()->getLoopPreheader(), + Curr->getDistributedLoop()->getExitingBlock()); + } + + /// \brief Removes the dead instructions from the cloned loops. + void removeUnusedInsts() { + for (auto &Partition : PartitionContainer) + Partition.removeUnusedInsts(); + } + + /// \brief For each memory pointer, it computes the partitionId the pointer is + /// used in. + /// + /// This returns an array of int where the I-th entry corresponds to I-th + /// entry in LAI.getRuntimePointerCheck(). If the pointer is used in multiple + /// partitions its entry is set to -1. + SmallVector<int, 8> + computePartitionSetForPointers(const LoopAccessInfo &LAI) { + const RuntimePointerChecking *RtPtrCheck = LAI.getRuntimePointerChecking(); + + unsigned N = RtPtrCheck->Pointers.size(); + SmallVector<int, 8> PtrToPartitions(N); + for (unsigned I = 0; I < N; ++I) { + Value *Ptr = RtPtrCheck->Pointers[I].PointerValue; + auto Instructions = + LAI.getInstructionsForAccess(Ptr, RtPtrCheck->Pointers[I].IsWritePtr); + + int &Partition = PtrToPartitions[I]; + // First set it to uninitialized. + Partition = -2; + for (Instruction *Inst : Instructions) { + // Note that this could be -1 if Inst is duplicated across multiple + // partitions. + int ThisPartition = this->InstToPartitionId[Inst]; + if (Partition == -2) + Partition = ThisPartition; + // -1 means belonging to multiple partitions. 
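+        // To summarize, Partition is three-state at this point:
+        //   -2: no accessing instruction seen yet for this pointer,
+        //   -1: the pointer is used in more than one partition,
+        //  >=0: the id of the only partition seen so far.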
+ else if (Partition == -1) + break; + else if (Partition != (int)ThisPartition) + Partition = -1; + } + assert(Partition != -2 && "Pointer not belonging to any partition"); + } + + return PtrToPartitions; + } + + void print(raw_ostream &OS) const { + unsigned Index = 0; + for (const auto &P : PartitionContainer) { + OS << "Partition " << Index++ << " (" << &P << "):\n"; + P.print(); + } + } + + void dump() const { print(dbgs()); } + +#ifndef NDEBUG + friend raw_ostream &operator<<(raw_ostream &OS, + const InstPartitionContainer &Partitions) { + Partitions.print(OS); + return OS; + } +#endif + + void printBlocks() const { + unsigned Index = 0; + for (const auto &P : PartitionContainer) { + dbgs() << "\nPartition " << Index++ << " (" << &P << "):\n"; + P.printBlocks(); + } + } + +private: + using PartitionContainerT = std::list<InstPartition>; + + /// \brief List of partitions. + PartitionContainerT PartitionContainer; + + /// \brief Mapping from Instruction to partition Id. If the instruction + /// belongs to multiple partitions the entry contains -1. + InstToPartitionIdT InstToPartitionId; + + Loop *L; + LoopInfo *LI; + DominatorTree *DT; + + /// \brief The control structure to merge adjacent partitions if both satisfy + /// the \p Predicate. + template <class UnaryPredicate> + void mergeAdjacentPartitionsIf(UnaryPredicate Predicate) { + InstPartition *PrevMatch = nullptr; + for (auto I = PartitionContainer.begin(); I != PartitionContainer.end();) { + auto DoesMatch = Predicate(&*I); + if (PrevMatch == nullptr && DoesMatch) { + PrevMatch = &*I; + ++I; + } else if (PrevMatch != nullptr && DoesMatch) { + I->moveTo(*PrevMatch); + I = PartitionContainer.erase(I); + } else { + PrevMatch = nullptr; + ++I; + } + } + } +}; + +/// \brief For each memory instruction, this class maintains difference of the +/// number of unsafe dependences that start out from this instruction minus +/// those that end here. +/// +/// By traversing the memory instructions in program order and accumulating this +/// number, we know whether any unsafe dependence crosses over a program point. +class MemoryInstructionDependences { + using Dependence = MemoryDepChecker::Dependence; + +public: + struct Entry { + Instruction *Inst; + unsigned NumUnsafeDependencesStartOrEnd = 0; + + Entry(Instruction *Inst) : Inst(Inst) {} + }; + + using AccessesType = SmallVector<Entry, 8>; + + AccessesType::const_iterator begin() const { return Accesses.begin(); } + AccessesType::const_iterator end() const { return Accesses.end(); } + + MemoryInstructionDependences( + const SmallVectorImpl<Instruction *> &Instructions, + const SmallVectorImpl<Dependence> &Dependences) { + Accesses.append(Instructions.begin(), Instructions.end()); + + DEBUG(dbgs() << "Backward dependences:\n"); + for (auto &Dep : Dependences) + if (Dep.isPossiblyBackward()) { + // Note that the designations source and destination follow the program + // order, i.e. source is always first. (The direction is given by the + // DepType.) + ++Accesses[Dep.Source].NumUnsafeDependencesStartOrEnd; + --Accesses[Dep.Destination].NumUnsafeDependencesStartOrEnd; + + DEBUG(Dep.print(dbgs(), 2, Instructions)); + } + } + +private: + AccessesType Accesses; +}; + +/// \brief The actual class performing the per-loop work. 
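+///
+/// The overall flow, a sketch of processLoop below: run LoopAccessAnalysis,
+/// seed partitions from the memory instructions, merge them according to the
+/// heuristics, version the loop under the remaining run-time checks, clone
+/// one loop per partition, and finally prune each clone down to the
+/// instructions of its partition.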
+class LoopDistributeForLoop { +public: + LoopDistributeForLoop(Loop *L, Function *F, LoopInfo *LI, DominatorTree *DT, + ScalarEvolution *SE, OptimizationRemarkEmitter *ORE) + : L(L), F(F), LI(LI), DT(DT), SE(SE), ORE(ORE) { + setForced(); + } + + /// \brief Try to distribute an inner-most loop. + bool processLoop(std::function<const LoopAccessInfo &(Loop &)> &GetLAA) { + assert(L->empty() && "Only process inner loops."); + + DEBUG(dbgs() << "\nLDist: In \"" << L->getHeader()->getParent()->getName() + << "\" checking " << *L << "\n"); + + if (!L->getExitBlock()) + return fail("MultipleExitBlocks", "multiple exit blocks"); + if (!L->isLoopSimplifyForm()) + return fail("NotLoopSimplifyForm", + "loop is not in loop-simplify form"); + + BasicBlock *PH = L->getLoopPreheader(); + + // LAA will check that we only have a single exiting block. + LAI = &GetLAA(*L); + + // Currently, we only distribute to isolate the part of the loop with + // dependence cycles to enable partial vectorization. + if (LAI->canVectorizeMemory()) + return fail("MemOpsCanBeVectorized", + "memory operations are safe for vectorization"); + + auto *Dependences = LAI->getDepChecker().getDependences(); + if (!Dependences || Dependences->empty()) + return fail("NoUnsafeDeps", "no unsafe dependences to isolate"); + + InstPartitionContainer Partitions(L, LI, DT); + + // First, go through each memory operation and assign them to consecutive + // partitions (the order of partitions follows program order). Put those + // with unsafe dependences into "cyclic" partition otherwise put each store + // in its own "non-cyclic" partition (we'll merge these later). + // + // Note that a memory operation (e.g. Load2 below) at a program point that + // has an unsafe dependence (Store3->Load1) spanning over it must be + // included in the same cyclic partition as the dependent operations. This + // is to preserve the original program order after distribution. E.g.: + // + // NumUnsafeDependencesStartOrEnd NumUnsafeDependencesActive + // Load1 -. 1 0->1 + // Load2 | /Unsafe/ 0 1 + // Store3 -' -1 1->0 + // Load4 0 0 + // + // NumUnsafeDependencesActive > 0 indicates this situation and in this case + // we just keep assigning to the same cyclic partition until + // NumUnsafeDependencesActive reaches 0. + const MemoryDepChecker &DepChecker = LAI->getDepChecker(); + MemoryInstructionDependences MID(DepChecker.getMemoryInstructions(), + *Dependences); + + int NumUnsafeDependencesActive = 0; + for (auto &InstDep : MID) { + Instruction *I = InstDep.Inst; + // We update NumUnsafeDependencesActive post-instruction, catch the + // start of a dependence directly via NumUnsafeDependencesStartOrEnd. + if (NumUnsafeDependencesActive || + InstDep.NumUnsafeDependencesStartOrEnd > 0) + Partitions.addToCyclicPartition(I); + else + Partitions.addToNewNonCyclicPartition(I); + NumUnsafeDependencesActive += InstDep.NumUnsafeDependencesStartOrEnd; + assert(NumUnsafeDependencesActive >= 0 && + "Negative number of dependences active"); + } + + // Add partitions for values used outside. These partitions can be out of + // order from the original program order. This is OK because if the + // partition uses a load we will merge this partition with the original + // partition of the load that we set up in the previous loop (see + // mergeToAvoidDuplicatedLoads). 
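+    // For example (a sketch): a running sum whose final value is read after
+    // the loop is seeded into its own partition here; because it uses the
+    // loads it accumulates, it typically gets merged back into the partition
+    // of those loads below.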
+ auto DefsUsedOutside = findDefsUsedOutsideOfLoop(L); + for (auto *Inst : DefsUsedOutside) + Partitions.addToNewNonCyclicPartition(Inst); + + DEBUG(dbgs() << "Seeded partitions:\n" << Partitions); + if (Partitions.getSize() < 2) + return fail("CantIsolateUnsafeDeps", + "cannot isolate unsafe dependencies"); + + // Run the merge heuristics: Merge non-cyclic adjacent partitions since we + // should be able to vectorize these together. + Partitions.mergeBeforePopulating(); + DEBUG(dbgs() << "\nMerged partitions:\n" << Partitions); + if (Partitions.getSize() < 2) + return fail("CantIsolateUnsafeDeps", + "cannot isolate unsafe dependencies"); + + // Now, populate the partitions with non-memory operations. + Partitions.populateUsedSet(); + DEBUG(dbgs() << "\nPopulated partitions:\n" << Partitions); + + // In order to preserve original lexical order for loads, keep them in the + // partition that we set up in the MemoryInstructionDependences loop. + if (Partitions.mergeToAvoidDuplicatedLoads()) { + DEBUG(dbgs() << "\nPartitions merged to ensure unique loads:\n" + << Partitions); + if (Partitions.getSize() < 2) + return fail("CantIsolateUnsafeDeps", + "cannot isolate unsafe dependencies"); + } + + // Don't distribute the loop if we need too many SCEV run-time checks. + const SCEVUnionPredicate &Pred = LAI->getPSE().getUnionPredicate(); + if (Pred.getComplexity() > (IsForced.getValueOr(false) + ? PragmaDistributeSCEVCheckThreshold + : DistributeSCEVCheckThreshold)) + return fail("TooManySCEVRuntimeChecks", + "too many SCEV run-time checks needed.\n"); + + DEBUG(dbgs() << "\nDistributing loop: " << *L << "\n"); + // We're done forming the partitions set up the reverse mapping from + // instructions to partitions. + Partitions.setupPartitionIdOnInstructions(); + + // To keep things simple have an empty preheader before we version or clone + // the loop. (Also split if this has no predecessor, i.e. entry, because we + // rely on PH having a predecessor.) + if (!PH->getSinglePredecessor() || &*PH->begin() != PH->getTerminator()) + SplitBlock(PH, PH->getTerminator(), DT, LI); + + // If we need run-time checks, version the loop now. + auto PtrToPartition = Partitions.computePartitionSetForPointers(*LAI); + const auto *RtPtrChecking = LAI->getRuntimePointerChecking(); + const auto &AllChecks = RtPtrChecking->getChecks(); + auto Checks = includeOnlyCrossPartitionChecks(AllChecks, PtrToPartition, + RtPtrChecking); + + if (!Pred.isAlwaysTrue() || !Checks.empty()) { + DEBUG(dbgs() << "\nPointers:\n"); + DEBUG(LAI->getRuntimePointerChecking()->printChecks(dbgs(), Checks)); + LoopVersioning LVer(*LAI, L, LI, DT, SE, false); + LVer.setAliasChecks(std::move(Checks)); + LVer.setSCEVChecks(LAI->getPSE().getUnionPredicate()); + LVer.versionLoop(DefsUsedOutside); + LVer.annotateLoopWithNoAlias(); + } + + // Create identical copies of the original loop for each partition and hook + // them up sequentially. + Partitions.cloneLoops(); + + // Now, we remove the instruction from each loop that don't belong to that + // partition. + Partitions.removeUnusedInsts(); + DEBUG(dbgs() << "\nAfter removing unused Instrs:\n"); + DEBUG(Partitions.printBlocks()); + + if (LDistVerify) { + LI->verify(*DT); + DT->verifyDomTree(); + } + + ++NumLoopsDistributed; + // Report the success. + ORE->emit([&]() { + return OptimizationRemark(LDIST_NAME, "Distribute", L->getStartLoc(), + L->getHeader()) + << "distributed loop"; + }); + return true; + } + + /// \brief Provide diagnostics then \return with false. 
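+  ///
+  /// For example, processLoop above reports
+  /// fail("MultipleExitBlocks", "multiple exit blocks") for a loop without a
+  /// unique exit block.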
+ bool fail(StringRef RemarkName, StringRef Message) { + LLVMContext &Ctx = F->getContext(); + bool Forced = isForced().getValueOr(false); + + DEBUG(dbgs() << "Skipping; " << Message << "\n"); + + // With Rpass-missed report that distribution failed. + ORE->emit([&]() { + return OptimizationRemarkMissed(LDIST_NAME, "NotDistributed", + L->getStartLoc(), L->getHeader()) + << "loop not distributed: use -Rpass-analysis=loop-distribute for " + "more " + "info"; + }); + + // With Rpass-analysis report why. This is on by default if distribution + // was requested explicitly. + ORE->emit(OptimizationRemarkAnalysis( + Forced ? OptimizationRemarkAnalysis::AlwaysPrint : LDIST_NAME, + RemarkName, L->getStartLoc(), L->getHeader()) + << "loop not distributed: " << Message); + + // Also issue a warning if distribution was requested explicitly but it + // failed. + if (Forced) + Ctx.diagnose(DiagnosticInfoOptimizationFailure( + *F, L->getStartLoc(), "loop not distributed: failed " + "explicitly specified loop distribution")); + + return false; + } + + /// \brief Return if distribution forced to be enabled/disabled for the loop. + /// + /// If the optional has a value, it indicates whether distribution was forced + /// to be enabled (true) or disabled (false). If the optional has no value + /// distribution was not forced either way. + const Optional<bool> &isForced() const { return IsForced; } + +private: + /// \brief Filter out checks between pointers from the same partition. + /// + /// \p PtrToPartition contains the partition number for pointers. Partition + /// number -1 means that the pointer is used in multiple partitions. In this + /// case we can't safely omit the check. + SmallVector<RuntimePointerChecking::PointerCheck, 4> + includeOnlyCrossPartitionChecks( + const SmallVectorImpl<RuntimePointerChecking::PointerCheck> &AllChecks, + const SmallVectorImpl<int> &PtrToPartition, + const RuntimePointerChecking *RtPtrChecking) { + SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks; + + copy_if(AllChecks, std::back_inserter(Checks), + [&](const RuntimePointerChecking::PointerCheck &Check) { + for (unsigned PtrIdx1 : Check.first->Members) + for (unsigned PtrIdx2 : Check.second->Members) + // Only include this check if there is a pair of pointers + // that require checking and the pointers fall into + // separate partitions. + // + // (Note that we already know at this point that the two + // pointer groups need checking but it doesn't follow + // that each pair of pointers within the two groups need + // checking as well. + // + // In other words we don't want to include a check just + // because there is a pair of pointers between the two + // pointer groups that require checks and a different + // pair whose pointers fall into different partitions.) + if (RtPtrChecking->needsChecking(PtrIdx1, PtrIdx2) && + !RuntimePointerChecking::arePointersInSamePartition( + PtrToPartition, PtrIdx1, PtrIdx2)) + return true; + return false; + }); + + return Checks; + } + + /// \brief Check whether the loop metadata is forcing distribution to be + /// enabled/disabled. + void setForced() { + Optional<const MDOperand *> Value = + findStringMetadataForLoop(L, "llvm.loop.distribute.enable"); + if (!Value) + return; + + const MDOperand *Op = *Value; + assert(Op && mdconst::hasa<ConstantInt>(*Op) && "invalid metadata"); + IsForced = mdconst::extract<ConstantInt>(*Op)->getZExtValue(); + } + + Loop *L; + Function *F; + + // Analyses used. 
+ LoopInfo *LI; + const LoopAccessInfo *LAI = nullptr; + DominatorTree *DT; + ScalarEvolution *SE; + OptimizationRemarkEmitter *ORE; + + /// \brief Indicates whether distribution is forced to be enabled/disabled for + /// the loop. + /// + /// If the optional has a value, it indicates whether distribution was forced + /// to be enabled (true) or disabled (false). If the optional has no value + /// distribution was not forced either way. + Optional<bool> IsForced; +}; + +} // end anonymous namespace + +/// Shared implementation between new and old PMs. +static bool runImpl(Function &F, LoopInfo *LI, DominatorTree *DT, + ScalarEvolution *SE, OptimizationRemarkEmitter *ORE, + std::function<const LoopAccessInfo &(Loop &)> &GetLAA) { + // Build up a worklist of inner-loops to vectorize. This is necessary as the + // act of distributing a loop creates new loops and can invalidate iterators + // across the loops. + SmallVector<Loop *, 8> Worklist; + + for (Loop *TopLevelLoop : *LI) + for (Loop *L : depth_first(TopLevelLoop)) + // We only handle inner-most loops. + if (L->empty()) + Worklist.push_back(L); + + // Now walk the identified inner loops. + bool Changed = false; + for (Loop *L : Worklist) { + LoopDistributeForLoop LDL(L, &F, LI, DT, SE, ORE); + + // If distribution was forced for the specific loop to be + // enabled/disabled, follow that. Otherwise use the global flag. + if (LDL.isForced().getValueOr(EnableLoopDistribute)) + Changed |= LDL.processLoop(GetLAA); + } + + // Process each loop nest in the function. + return Changed; +} + +namespace { + +/// \brief The pass class. +class LoopDistributeLegacy : public FunctionPass { +public: + static char ID; + + LoopDistributeLegacy() : FunctionPass(ID) { + // The default is set by the caller. + initializeLoopDistributeLegacyPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto *LAA = &getAnalysis<LoopAccessLegacyAnalysis>(); + auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto *ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE(); + std::function<const LoopAccessInfo &(Loop &)> GetLAA = + [&](Loop &L) -> const LoopAccessInfo & { return LAA->getInfo(&L); }; + + return runImpl(F, LI, DT, SE, ORE, GetLAA); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + AU.addRequired<LoopAccessLegacyAnalysis>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addRequired<OptimizationRemarkEmitterWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } +}; + +} // end anonymous namespace + +PreservedAnalyses LoopDistributePass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); + auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F); + + // We don't directly need these analyses but they're required for loop + // analyses so provide them below. 
+ auto &AA = AM.getResult<AAManager>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + + auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); + std::function<const LoopAccessInfo &(Loop &)> GetLAA = + [&](Loop &L) -> const LoopAccessInfo & { + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, SE, TLI, TTI, nullptr}; + return LAM.getResult<LoopAccessAnalysis>(L, AR); + }; + + bool Changed = runImpl(F, &LI, &DT, &SE, &ORE, GetLAA); + if (!Changed) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<LoopAnalysis>(); + PA.preserve<DominatorTreeAnalysis>(); + PA.preserve<GlobalsAA>(); + return PA; +} + +char LoopDistributeLegacy::ID; + +static const char ldist_name[] = "Loop Distribution"; + +INITIALIZE_PASS_BEGIN(LoopDistributeLegacy, LDIST_NAME, ldist_name, false, + false) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass) +INITIALIZE_PASS_END(LoopDistributeLegacy, LDIST_NAME, ldist_name, false, false) + +FunctionPass *llvm::createLoopDistributePass() { return new LoopDistributeLegacy(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp new file mode 100644 index 000000000000..21551f0a0825 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopIdiomRecognize.cpp @@ -0,0 +1,1728 @@ +//===- LoopIdiomRecognize.cpp - Loop idiom recognition --------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass implements an idiom recognizer that transforms simple loops into a +// non-loop form. In cases that this kicks in, it can be a significant +// performance win. +// +// If compiling for code size we avoid idiom recognition if the resulting +// code could be larger than the code for the original loop. One way this could +// happen is if the loop is not removable after idiom recognition due to the +// presence of non-idiom instructions. The initial implementation of the +// heuristics applies to idioms in multi-block loops. +// +//===----------------------------------------------------------------------===// +// +// TODO List: +// +// Future loop memory idioms to recognize: +// memcmp, memmove, strlen, etc. +// Future floating point idioms to recognize in -ffast-math mode: +// fpowi +// Future integer operation idioms to recognize: +// ctpop, ctlz, cttz +// +// Beware that isel's default lowering for ctpop is highly inefficient for +// i64 and larger types when i64 is legal and the value has few bits set. It +// would be good to enhance isel to emit a loop for ctpop in this case. +// +// This could recognize common matrix multiplies and dot product idioms and +// replace them with calls to BLAS (if linked in??). 
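+//
+// For orientation, the canonical shapes currently recognized are, roughly:
+//
+//   for (i = 0; i < n; ++i) a[i] = v;     // becomes a memset (or a
+//                                         // memset_pattern16 for non-splat v)
+//   for (i = 0; i < n; ++i) a[i] = b[i];  // becomes a memcpy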
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LoopIdiomRecognize.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/LoopAccessAnalysis.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/MemoryLocation.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DebugLoc.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BuildLibCalls.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-idiom"
+
+STATISTIC(NumMemSet, "Number of memset's formed from loop stores");
+STATISTIC(NumMemCpy, "Number of memcpy's formed from loop load+stores");
+
+static cl::opt<bool> UseLIRCodeSizeHeurs(
+    "use-lir-code-size-heurs",
+    cl::desc("Use loop idiom recognition code size heuristics when compiling "
+             "with -Os/-Oz"),
+    cl::init(true), cl::Hidden);
+
+namespace {
+
+class LoopIdiomRecognize {
+  Loop *CurLoop = nullptr;
+  AliasAnalysis *AA;
+  DominatorTree *DT;
+  LoopInfo *LI;
+  ScalarEvolution *SE;
+  TargetLibraryInfo *TLI;
+  const TargetTransformInfo *TTI;
+  const DataLayout *DL;
+  bool ApplyCodeSizeHeuristics;
+
+public:
+  explicit LoopIdiomRecognize(AliasAnalysis *AA, DominatorTree *DT,
+                              LoopInfo *LI, ScalarEvolution *SE,
+                              TargetLibraryInfo *TLI,
+                              const TargetTransformInfo *TTI,
+                              const DataLayout *DL)
+      : AA(AA), DT(DT), LI(LI), SE(SE), TLI(TLI), TTI(TTI), DL(DL) {}
+
+  bool runOnLoop(Loop *L);
+
+private:
+  using StoreList = SmallVector<StoreInst *, 8>;
+  using StoreListMap = MapVector<Value *, StoreList>;
+
+  StoreListMap StoreRefsForMemset;
+  StoreListMap StoreRefsForMemsetPattern;
+  StoreList StoreRefsForMemcpy;
+  bool HasMemset;
+  bool HasMemsetPattern;
+  bool HasMemcpy;
+
+  /// Return code for isLegalStore().
+  enum LegalStoreKind {
+    None = 0,
+    Memset,
+    MemsetPattern,
+    Memcpy,
+    UnorderedAtomicMemcpy,
+    DontUse // Dummy retval never to
be used. Allows catching errors in retval + // handling. + }; + + /// \name Countable Loop Idiom Handling + /// @{ + + bool runOnCountableLoop(); + bool runOnLoopBlock(BasicBlock *BB, const SCEV *BECount, + SmallVectorImpl<BasicBlock *> &ExitBlocks); + + void collectStores(BasicBlock *BB); + LegalStoreKind isLegalStore(StoreInst *SI); + bool processLoopStores(SmallVectorImpl<StoreInst *> &SL, const SCEV *BECount, + bool ForMemset); + bool processLoopMemSet(MemSetInst *MSI, const SCEV *BECount); + + bool processLoopStridedStore(Value *DestPtr, unsigned StoreSize, + unsigned StoreAlignment, Value *StoredVal, + Instruction *TheStore, + SmallPtrSetImpl<Instruction *> &Stores, + const SCEVAddRecExpr *Ev, const SCEV *BECount, + bool NegStride, bool IsLoopMemset = false); + bool processLoopStoreOfLoopLoad(StoreInst *SI, const SCEV *BECount); + bool avoidLIRForMultiBlockLoop(bool IsMemset = false, + bool IsLoopMemset = false); + + /// @} + /// \name Noncountable Loop Idiom Handling + /// @{ + + bool runOnNoncountableLoop(); + + bool recognizePopcount(); + void transformLoopToPopcount(BasicBlock *PreCondBB, Instruction *CntInst, + PHINode *CntPhi, Value *Var); + bool recognizeAndInsertCTLZ(); + void transformLoopToCountable(BasicBlock *PreCondBB, Instruction *CntInst, + PHINode *CntPhi, Value *Var, const DebugLoc DL, + bool ZeroCheck, bool IsCntPhiUsedOutsideLoop); + + /// @} +}; + +class LoopIdiomRecognizeLegacyPass : public LoopPass { +public: + static char ID; + + explicit LoopIdiomRecognizeLegacyPass() : LoopPass(ID) { + initializeLoopIdiomRecognizeLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + + AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); + DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + const TargetTransformInfo *TTI = + &getAnalysis<TargetTransformInfoWrapperPass>().getTTI( + *L->getHeader()->getParent()); + const DataLayout *DL = &L->getHeader()->getModule()->getDataLayout(); + + LoopIdiomRecognize LIR(AA, DT, LI, SE, TLI, TTI, DL); + return LIR.runOnLoop(L); + } + + /// This transformation requires natural loop information & requires that + /// loop preheaders be inserted into the CFG. 
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    getLoopAnalysisUsage(AU);
+  }
+};
+
+} // end anonymous namespace
+
+char LoopIdiomRecognizeLegacyPass::ID = 0;
+
+PreservedAnalyses LoopIdiomRecognizePass::run(Loop &L, LoopAnalysisManager &AM,
+                                              LoopStandardAnalysisResults &AR,
+                                              LPMUpdater &) {
+  const auto *DL = &L.getHeader()->getModule()->getDataLayout();
+
+  LoopIdiomRecognize LIR(&AR.AA, &AR.DT, &AR.LI, &AR.SE, &AR.TLI, &AR.TTI, DL);
+  if (!LIR.runOnLoop(&L))
+    return PreservedAnalyses::all();
+
+  return getLoopPassPreservedAnalyses();
+}
+
+INITIALIZE_PASS_BEGIN(LoopIdiomRecognizeLegacyPass, "loop-idiom",
+                      "Recognize loop idioms", false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(LoopIdiomRecognizeLegacyPass, "loop-idiom",
+                    "Recognize loop idioms", false, false)
+
+Pass *llvm::createLoopIdiomPass() { return new LoopIdiomRecognizeLegacyPass(); }
+
+static void deleteDeadInstruction(Instruction *I) {
+  I->replaceAllUsesWith(UndefValue::get(I->getType()));
+  I->eraseFromParent();
+}
+
+//===----------------------------------------------------------------------===//
+//
+//          Implementation of LoopIdiomRecognize
+//
+//===----------------------------------------------------------------------===//
+
+bool LoopIdiomRecognize::runOnLoop(Loop *L) {
+  CurLoop = L;
+  // If the loop could not be converted to canonical form, it must have an
+  // indirectbr in it; just give up.
+  if (!L->getLoopPreheader())
+    return false;
+
+  // Disable loop idiom recognition if the function's name is a common idiom.
+  StringRef Name = L->getHeader()->getParent()->getName();
+  if (Name == "memset" || Name == "memcpy")
+    return false;
+
+  // Determine if code size heuristics need to be applied.
+  ApplyCodeSizeHeuristics =
+      L->getHeader()->getParent()->optForSize() && UseLIRCodeSizeHeurs;
+
+  HasMemset = TLI->has(LibFunc_memset);
+  HasMemsetPattern = TLI->has(LibFunc_memset_pattern16);
+  HasMemcpy = TLI->has(LibFunc_memcpy);
+
+  if (HasMemset || HasMemsetPattern || HasMemcpy)
+    if (SE->hasLoopInvariantBackedgeTakenCount(L))
+      return runOnCountableLoop();
+
+  return runOnNoncountableLoop();
+}
+
+bool LoopIdiomRecognize::runOnCountableLoop() {
+  const SCEV *BECount = SE->getBackedgeTakenCount(CurLoop);
+  assert(!isa<SCEVCouldNotCompute>(BECount) &&
+         "runOnCountableLoop() called on a loop without a predictable "
+         "backedge-taken count");
+
+  // If this loop executes exactly one time, then it should be peeled, not
+  // optimized by this pass.
+  if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount))
+    if (BECst->getAPInt() == 0)
+      return false;
+
+  SmallVector<BasicBlock *, 8> ExitBlocks;
+  CurLoop->getUniqueExitBlocks(ExitBlocks);
+
+  DEBUG(dbgs() << "loop-idiom Scanning: F["
+               << CurLoop->getHeader()->getParent()->getName() << "] Loop %"
+               << CurLoop->getHeader()->getName() << "\n");
+
+  bool MadeChange = false;
+
+  // The following transforms hoist stores/memsets into the loop pre-header.
+  // Give up if the loop has instructions that may throw.
+  LoopSafetyInfo SafetyInfo;
+  computeLoopSafetyInfo(&SafetyInfo, CurLoop);
+  if (SafetyInfo.MayThrow)
+    return MadeChange;
+
+  // Scan all the blocks in the loop that are not in subloops.
+  for (auto *BB : CurLoop->getBlocks()) {
+    // Ignore blocks in subloops.
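+    // (LoopInfo::getLoopFor returns the innermost loop containing BB, so the
+    // test below holds exactly when BB is not inside a subloop.)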
+ if (LI->getLoopFor(BB) != CurLoop) + continue; + + MadeChange |= runOnLoopBlock(BB, BECount, ExitBlocks); + } + return MadeChange; +} + +static APInt getStoreStride(const SCEVAddRecExpr *StoreEv) { + const SCEVConstant *ConstStride = cast<SCEVConstant>(StoreEv->getOperand(1)); + return ConstStride->getAPInt(); +} + +/// getMemSetPatternValue - If a strided store of the specified value is safe to +/// turn into a memset_pattern16, return a ConstantArray of 16 bytes that should +/// be passed in. Otherwise, return null. +/// +/// Note that we don't ever attempt to use memset_pattern8 or 4, because these +/// just replicate their input array and then pass on to memset_pattern16. +static Constant *getMemSetPatternValue(Value *V, const DataLayout *DL) { + // If the value isn't a constant, we can't promote it to being in a constant + // array. We could theoretically do a store to an alloca or something, but + // that doesn't seem worthwhile. + Constant *C = dyn_cast<Constant>(V); + if (!C) + return nullptr; + + // Only handle simple values that are a power of two bytes in size. + uint64_t Size = DL->getTypeSizeInBits(V->getType()); + if (Size == 0 || (Size & 7) || (Size & (Size - 1))) + return nullptr; + + // Don't care enough about darwin/ppc to implement this. + if (DL->isBigEndian()) + return nullptr; + + // Convert to size in bytes. + Size /= 8; + + // TODO: If CI is larger than 16-bytes, we can try slicing it in half to see + // if the top and bottom are the same (e.g. for vectors and large integers). + if (Size > 16) + return nullptr; + + // If the constant is exactly 16 bytes, just use it. + if (Size == 16) + return C; + + // Otherwise, we'll use an array of the constants. + unsigned ArraySize = 16 / Size; + ArrayType *AT = ArrayType::get(V->getType(), ArraySize); + return ConstantArray::get(AT, std::vector<Constant *>(ArraySize, C)); +} + +LoopIdiomRecognize::LegalStoreKind +LoopIdiomRecognize::isLegalStore(StoreInst *SI) { + // Don't touch volatile stores. + if (SI->isVolatile()) + return LegalStoreKind::None; + // We only want simple or unordered-atomic stores. + if (!SI->isUnordered()) + return LegalStoreKind::None; + + // Don't convert stores of non-integral pointer types to memsets (which stores + // integers). + if (DL->isNonIntegralPointerType(SI->getValueOperand()->getType())) + return LegalStoreKind::None; + + // Avoid merging nontemporal stores. + if (SI->getMetadata(LLVMContext::MD_nontemporal)) + return LegalStoreKind::None; + + Value *StoredVal = SI->getValueOperand(); + Value *StorePtr = SI->getPointerOperand(); + + // Reject stores that are so large that they overflow an unsigned. + uint64_t SizeInBits = DL->getTypeSizeInBits(StoredVal->getType()); + if ((SizeInBits & 7) || (SizeInBits >> 32) != 0) + return LegalStoreKind::None; + + // See if the pointer expression is an AddRec like {base,+,1} on the current + // loop, which indicates a strided store. If we have something else, it's a + // random store we can't handle. + const SCEVAddRecExpr *StoreEv = + dyn_cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); + if (!StoreEv || StoreEv->getLoop() != CurLoop || !StoreEv->isAffine()) + return LegalStoreKind::None; + + // Check to see if we have a constant stride. + if (!isa<SCEVConstant>(StoreEv->getOperand(1))) + return LegalStoreKind::None; + + // See if the store can be turned into a memset. + + // If the stored value is a byte-wise value (like i32 -1), then it may be + // turned into a memset of i8 -1, assuming that all the consecutive bytes + // are stored. 
A store of i32 0x01020304 can never be turned into a memset, + // but it can be turned into memset_pattern if the target supports it. + Value *SplatValue = isBytewiseValue(StoredVal); + Constant *PatternValue = nullptr; + + // Note: memset and memset_pattern on unordered-atomic is yet not supported + bool UnorderedAtomic = SI->isUnordered() && !SI->isSimple(); + + // If we're allowed to form a memset, and the stored value would be + // acceptable for memset, use it. + if (!UnorderedAtomic && HasMemset && SplatValue && + // Verify that the stored value is loop invariant. If not, we can't + // promote the memset. + CurLoop->isLoopInvariant(SplatValue)) { + // It looks like we can use SplatValue. + return LegalStoreKind::Memset; + } else if (!UnorderedAtomic && HasMemsetPattern && + // Don't create memset_pattern16s with address spaces. + StorePtr->getType()->getPointerAddressSpace() == 0 && + (PatternValue = getMemSetPatternValue(StoredVal, DL))) { + // It looks like we can use PatternValue! + return LegalStoreKind::MemsetPattern; + } + + // Otherwise, see if the store can be turned into a memcpy. + if (HasMemcpy) { + // Check to see if the stride matches the size of the store. If so, then we + // know that every byte is touched in the loop. + APInt Stride = getStoreStride(StoreEv); + unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType()); + if (StoreSize != Stride && StoreSize != -Stride) + return LegalStoreKind::None; + + // The store must be feeding a non-volatile load. + LoadInst *LI = dyn_cast<LoadInst>(SI->getValueOperand()); + + // Only allow non-volatile loads + if (!LI || LI->isVolatile()) + return LegalStoreKind::None; + // Only allow simple or unordered-atomic loads + if (!LI->isUnordered()) + return LegalStoreKind::None; + + // See if the pointer expression is an AddRec like {base,+,1} on the current + // loop, which indicates a strided load. If we have something else, it's a + // random load we can't handle. + const SCEVAddRecExpr *LoadEv = + dyn_cast<SCEVAddRecExpr>(SE->getSCEV(LI->getPointerOperand())); + if (!LoadEv || LoadEv->getLoop() != CurLoop || !LoadEv->isAffine()) + return LegalStoreKind::None; + + // The store and load must share the same stride. + if (StoreEv->getOperand(1) != LoadEv->getOperand(1)) + return LegalStoreKind::None; + + // Success. This store can be converted into a memcpy. + UnorderedAtomic = UnorderedAtomic || LI->isAtomic(); + return UnorderedAtomic ? LegalStoreKind::UnorderedAtomicMemcpy + : LegalStoreKind::Memcpy; + } + // This store can't be transformed into a memset/memcpy. + return LegalStoreKind::None; +} + +void LoopIdiomRecognize::collectStores(BasicBlock *BB) { + StoreRefsForMemset.clear(); + StoreRefsForMemsetPattern.clear(); + StoreRefsForMemcpy.clear(); + for (Instruction &I : *BB) { + StoreInst *SI = dyn_cast<StoreInst>(&I); + if (!SI) + continue; + + // Make sure this is a strided store with a constant stride. + switch (isLegalStore(SI)) { + case LegalStoreKind::None: + // Nothing to do + break; + case LegalStoreKind::Memset: { + // Find the base pointer. + Value *Ptr = GetUnderlyingObject(SI->getPointerOperand(), *DL); + StoreRefsForMemset[Ptr].push_back(SI); + } break; + case LegalStoreKind::MemsetPattern: { + // Find the base pointer. 
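+      // As in the Memset case above: GetUnderlyingObject strips GEPs and
+      // casts so that stores into the same underlying object are grouped
+      // under one base pointer.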
+ Value *Ptr = GetUnderlyingObject(SI->getPointerOperand(), *DL); + StoreRefsForMemsetPattern[Ptr].push_back(SI); + } break; + case LegalStoreKind::Memcpy: + case LegalStoreKind::UnorderedAtomicMemcpy: + StoreRefsForMemcpy.push_back(SI); + break; + default: + assert(false && "unhandled return value"); + break; + } + } +} + +/// runOnLoopBlock - Process the specified block, which lives in a counted loop +/// with the specified backedge count. This block is known to be in the current +/// loop and not in any subloops. +bool LoopIdiomRecognize::runOnLoopBlock( + BasicBlock *BB, const SCEV *BECount, + SmallVectorImpl<BasicBlock *> &ExitBlocks) { + // We can only promote stores in this block if they are unconditionally + // executed in the loop. For a block to be unconditionally executed, it has + // to dominate all the exit blocks of the loop. Verify this now. + for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) + if (!DT->dominates(BB, ExitBlocks[i])) + return false; + + bool MadeChange = false; + // Look for store instructions, which may be optimized to memset/memcpy. + collectStores(BB); + + // Look for a single store or sets of stores with a common base, which can be + // optimized into a memset (memset_pattern). The latter most commonly happens + // with structs and handunrolled loops. + for (auto &SL : StoreRefsForMemset) + MadeChange |= processLoopStores(SL.second, BECount, true); + + for (auto &SL : StoreRefsForMemsetPattern) + MadeChange |= processLoopStores(SL.second, BECount, false); + + // Optimize the store into a memcpy, if it feeds an similarly strided load. + for (auto &SI : StoreRefsForMemcpy) + MadeChange |= processLoopStoreOfLoopLoad(SI, BECount); + + for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E;) { + Instruction *Inst = &*I++; + // Look for memset instructions, which may be optimized to a larger memset. + if (MemSetInst *MSI = dyn_cast<MemSetInst>(Inst)) { + WeakTrackingVH InstPtr(&*I); + if (!processLoopMemSet(MSI, BECount)) + continue; + MadeChange = true; + + // If processing the memset invalidated our iterator, start over from the + // top of the block. + if (!InstPtr) + I = BB->begin(); + continue; + } + } + + return MadeChange; +} + +/// processLoopStores - See if this store(s) can be promoted to a memset. +bool LoopIdiomRecognize::processLoopStores(SmallVectorImpl<StoreInst *> &SL, + const SCEV *BECount, + bool ForMemset) { + // Try to find consecutive stores that can be transformed into memsets. + SetVector<StoreInst *> Heads, Tails; + SmallDenseMap<StoreInst *, StoreInst *> ConsecutiveChain; + + // Do a quadratic search on all of the given stores and find + // all of the pairs of stores that follow each other. + SmallVector<unsigned, 16> IndexQueue; + for (unsigned i = 0, e = SL.size(); i < e; ++i) { + assert(SL[i]->isSimple() && "Expected only non-volatile stores."); + + Value *FirstStoredVal = SL[i]->getValueOperand(); + Value *FirstStorePtr = SL[i]->getPointerOperand(); + const SCEVAddRecExpr *FirstStoreEv = + cast<SCEVAddRecExpr>(SE->getSCEV(FirstStorePtr)); + APInt FirstStride = getStoreStride(FirstStoreEv); + unsigned FirstStoreSize = DL->getTypeStoreSize(SL[i]->getValueOperand()->getType()); + + // See if we can optimize just this store in isolation. 
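+    // (A stride whose magnitude equals the store size means this store by
+    // itself already covers memory contiguously, so it can head a chain
+    // without needing a partner.)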
+ if (FirstStride == FirstStoreSize || -FirstStride == FirstStoreSize) { + Heads.insert(SL[i]); + continue; + } + + Value *FirstSplatValue = nullptr; + Constant *FirstPatternValue = nullptr; + + if (ForMemset) + FirstSplatValue = isBytewiseValue(FirstStoredVal); + else + FirstPatternValue = getMemSetPatternValue(FirstStoredVal, DL); + + assert((FirstSplatValue || FirstPatternValue) && + "Expected either splat value or pattern value."); + + IndexQueue.clear(); + // If a store has multiple consecutive store candidates, search Stores + // array according to the sequence: from i+1 to e, then from i-1 to 0. + // This is because usually pairing with immediate succeeding or preceding + // candidate create the best chance to find memset opportunity. + unsigned j = 0; + for (j = i + 1; j < e; ++j) + IndexQueue.push_back(j); + for (j = i; j > 0; --j) + IndexQueue.push_back(j - 1); + + for (auto &k : IndexQueue) { + assert(SL[k]->isSimple() && "Expected only non-volatile stores."); + Value *SecondStorePtr = SL[k]->getPointerOperand(); + const SCEVAddRecExpr *SecondStoreEv = + cast<SCEVAddRecExpr>(SE->getSCEV(SecondStorePtr)); + APInt SecondStride = getStoreStride(SecondStoreEv); + + if (FirstStride != SecondStride) + continue; + + Value *SecondStoredVal = SL[k]->getValueOperand(); + Value *SecondSplatValue = nullptr; + Constant *SecondPatternValue = nullptr; + + if (ForMemset) + SecondSplatValue = isBytewiseValue(SecondStoredVal); + else + SecondPatternValue = getMemSetPatternValue(SecondStoredVal, DL); + + assert((SecondSplatValue || SecondPatternValue) && + "Expected either splat value or pattern value."); + + if (isConsecutiveAccess(SL[i], SL[k], *DL, *SE, false)) { + if (ForMemset) { + if (FirstSplatValue != SecondSplatValue) + continue; + } else { + if (FirstPatternValue != SecondPatternValue) + continue; + } + Tails.insert(SL[k]); + Heads.insert(SL[i]); + ConsecutiveChain[SL[i]] = SL[k]; + break; + } + } + } + + // We may run into multiple chains that merge into a single chain. We mark the + // stores that we transformed so that we don't visit the same store twice. + SmallPtrSet<Value *, 16> TransformedStores; + bool Changed = false; + + // For stores that start but don't end a link in the chain: + for (SetVector<StoreInst *>::iterator it = Heads.begin(), e = Heads.end(); + it != e; ++it) { + if (Tails.count(*it)) + continue; + + // We found a store instr that starts a chain. Now follow the chain and try + // to transform it. + SmallPtrSet<Instruction *, 8> AdjacentStores; + StoreInst *I = *it; + + StoreInst *HeadStore = I; + unsigned StoreSize = 0; + + // Collect the chain into a list. + while (Tails.count(I) || Heads.count(I)) { + if (TransformedStores.count(I)) + break; + AdjacentStores.insert(I); + + StoreSize += DL->getTypeStoreSize(I->getValueOperand()->getType()); + // Move to the next value in the chain. + I = ConsecutiveChain[I]; + } + + Value *StoredVal = HeadStore->getValueOperand(); + Value *StorePtr = HeadStore->getPointerOperand(); + const SCEVAddRecExpr *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); + APInt Stride = getStoreStride(StoreEv); + + // Check to see if the stride matches the size of the stores. If so, then + // we know that every byte is touched in the loop. 
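+    // For example (a sketch): a chain of two adjacent i32 stores per
+    // iteration gives StoreSize == 8, so the chain tiles memory only when
+    // the common stride is +8 or -8.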
+ if (StoreSize != Stride && StoreSize != -Stride) + continue; + + bool NegStride = StoreSize == -Stride; + + if (processLoopStridedStore(StorePtr, StoreSize, HeadStore->getAlignment(), + StoredVal, HeadStore, AdjacentStores, StoreEv, + BECount, NegStride)) { + TransformedStores.insert(AdjacentStores.begin(), AdjacentStores.end()); + Changed = true; + } + } + + return Changed; +} + +/// processLoopMemSet - See if this memset can be promoted to a large memset. +bool LoopIdiomRecognize::processLoopMemSet(MemSetInst *MSI, + const SCEV *BECount) { + // We can only handle non-volatile memsets with a constant size. + if (MSI->isVolatile() || !isa<ConstantInt>(MSI->getLength())) + return false; + + // If we're not allowed to hack on memset, we fail. + if (!HasMemset) + return false; + + Value *Pointer = MSI->getDest(); + + // See if the pointer expression is an AddRec like {base,+,1} on the current + // loop, which indicates a strided store. If we have something else, it's a + // random store we can't handle. + const SCEVAddRecExpr *Ev = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(Pointer)); + if (!Ev || Ev->getLoop() != CurLoop || !Ev->isAffine()) + return false; + + // Reject memsets that are so large that they overflow an unsigned. + uint64_t SizeInBytes = cast<ConstantInt>(MSI->getLength())->getZExtValue(); + if ((SizeInBytes >> 32) != 0) + return false; + + // Check to see if the stride matches the size of the memset. If so, then we + // know that every byte is touched in the loop. + const SCEVConstant *ConstStride = dyn_cast<SCEVConstant>(Ev->getOperand(1)); + if (!ConstStride) + return false; + + APInt Stride = ConstStride->getAPInt(); + if (SizeInBytes != Stride && SizeInBytes != -Stride) + return false; + + // Verify that the memset value is loop invariant. If not, we can't promote + // the memset. + Value *SplatValue = MSI->getValue(); + if (!SplatValue || !CurLoop->isLoopInvariant(SplatValue)) + return false; + + SmallPtrSet<Instruction *, 1> MSIs; + MSIs.insert(MSI); + bool NegStride = SizeInBytes == -Stride; + return processLoopStridedStore(Pointer, (unsigned)SizeInBytes, + MSI->getAlignment(), SplatValue, MSI, MSIs, Ev, + BECount, NegStride, /*IsLoopMemset=*/true); +} + +/// mayLoopAccessLocation - Return true if the specified loop might access the +/// specified pointer location, which is a loop-strided access. The 'Access' +/// argument specifies what the verboten forms of access are (read or write). +static bool +mayLoopAccessLocation(Value *Ptr, ModRefInfo Access, Loop *L, + const SCEV *BECount, unsigned StoreSize, + AliasAnalysis &AA, + SmallPtrSetImpl<Instruction *> &IgnoredStores) { + // Get the location that may be stored across the loop. Since the access is + // strided positively through memory, we say that the modified location starts + // at the pointer and has infinite size. + uint64_t AccessSize = MemoryLocation::UnknownSize; + + // If the loop iterates a fixed number of times, we can refine the access size + // to be exactly the size of the memset, which is (BECount+1)*StoreSize + if (const SCEVConstant *BECst = dyn_cast<SCEVConstant>(BECount)) + AccessSize = (BECst->getValue()->getZExtValue() + 1) * StoreSize; + + // TODO: For this to be really effective, we have to dive into the pointer + // operand in the store. Store to &A[i] of 100 will always return may alias + // with store of &A[100], we need to StoreLoc to be "A" with size of 100, + // which will then no-alias a store to &A[100]. 
+ MemoryLocation StoreLoc(Ptr, AccessSize); + + for (Loop::block_iterator BI = L->block_begin(), E = L->block_end(); BI != E; + ++BI) + for (Instruction &I : **BI) + if (IgnoredStores.count(&I) == 0 && + isModOrRefSet( + intersectModRef(AA.getModRefInfo(&I, StoreLoc), Access))) + return true; + + return false; +} + +// If we have a negative stride, Start refers to the end of the memory location +// we're trying to memset. Therefore, we need to recompute the base pointer, +// which is just Start - BECount*Size. +static const SCEV *getStartForNegStride(const SCEV *Start, const SCEV *BECount, + Type *IntPtr, unsigned StoreSize, + ScalarEvolution *SE) { + const SCEV *Index = SE->getTruncateOrZeroExtend(BECount, IntPtr); + if (StoreSize != 1) + Index = SE->getMulExpr(Index, SE->getConstant(IntPtr, StoreSize), + SCEV::FlagNUW); + return SE->getMinusSCEV(Start, Index); +} + +/// Compute the number of bytes as a SCEV from the backedge taken count. +/// +/// This also maps the SCEV into the provided type and tries to handle the +/// computation in a way that will fold cleanly. +static const SCEV *getNumBytes(const SCEV *BECount, Type *IntPtr, + unsigned StoreSize, Loop *CurLoop, + const DataLayout *DL, ScalarEvolution *SE) { + const SCEV *NumBytesS; + // The # stored bytes is (BECount+1)*Size. Expand the trip count out to + // pointer size if it isn't already. + // + // If we're going to need to zero extend the BE count, check if we can add + // one to it prior to zero extending without overflow. Provided this is safe, + // it allows better simplification of the +1. + if (DL->getTypeSizeInBits(BECount->getType()) < + DL->getTypeSizeInBits(IntPtr) && + SE->isLoopEntryGuardedByCond( + CurLoop, ICmpInst::ICMP_NE, BECount, + SE->getNegativeSCEV(SE->getOne(BECount->getType())))) { + NumBytesS = SE->getZeroExtendExpr( + SE->getAddExpr(BECount, SE->getOne(BECount->getType()), SCEV::FlagNUW), + IntPtr); + } else { + NumBytesS = SE->getAddExpr(SE->getTruncateOrZeroExtend(BECount, IntPtr), + SE->getOne(IntPtr), SCEV::FlagNUW); + } + + // And scale it based on the store size. + if (StoreSize != 1) { + NumBytesS = SE->getMulExpr(NumBytesS, SE->getConstant(IntPtr, StoreSize), + SCEV::FlagNUW); + } + return NumBytesS; +} + +/// processLoopStridedStore - We see a strided store of some value. If we can +/// transform this into a memset or memset_pattern in the loop preheader, do so. +bool LoopIdiomRecognize::processLoopStridedStore( + Value *DestPtr, unsigned StoreSize, unsigned StoreAlignment, + Value *StoredVal, Instruction *TheStore, + SmallPtrSetImpl<Instruction *> &Stores, const SCEVAddRecExpr *Ev, + const SCEV *BECount, bool NegStride, bool IsLoopMemset) { + Value *SplatValue = isBytewiseValue(StoredVal); + Constant *PatternValue = nullptr; + + if (!SplatValue) + PatternValue = getMemSetPatternValue(StoredVal, DL); + + assert((SplatValue || PatternValue) && + "Expected either splat value or pattern value."); + + // The trip count of the loop and the base pointer of the addrec SCEV is + // guaranteed to be loop invariant, which means that it should dominate the + // header. This allows us to insert code for it in the preheader. 
+ unsigned DestAS = DestPtr->getType()->getPointerAddressSpace(); + BasicBlock *Preheader = CurLoop->getLoopPreheader(); + IRBuilder<> Builder(Preheader->getTerminator()); + SCEVExpander Expander(*SE, *DL, "loop-idiom"); + + Type *DestInt8PtrTy = Builder.getInt8PtrTy(DestAS); + Type *IntPtr = Builder.getIntPtrTy(*DL, DestAS); + + const SCEV *Start = Ev->getStart(); + // Handle negative strided loops. + if (NegStride) + Start = getStartForNegStride(Start, BECount, IntPtr, StoreSize, SE); + + // TODO: ideally we should still be able to generate memset if SCEV expander + // is taught to generate the dependencies at the latest point. + if (!isSafeToExpand(Start, *SE)) + return false; + + // Okay, we have a strided store "p[i]" of a splattable value. We can turn + // this into a memset in the loop preheader now if we want. However, this + // would be unsafe to do if there is anything else in the loop that may read + // or write to the aliased location. Check for any overlap by generating the + // base pointer and checking the region. + Value *BasePtr = + Expander.expandCodeFor(Start, DestInt8PtrTy, Preheader->getTerminator()); + if (mayLoopAccessLocation(BasePtr, ModRefInfo::ModRef, CurLoop, BECount, + StoreSize, *AA, Stores)) { + Expander.clear(); + // If we generated new code for the base pointer, clean up. + RecursivelyDeleteTriviallyDeadInstructions(BasePtr, TLI); + return false; + } + + if (avoidLIRForMultiBlockLoop(/*IsMemset=*/true, IsLoopMemset)) + return false; + + // Okay, everything looks good, insert the memset. + + const SCEV *NumBytesS = + getNumBytes(BECount, IntPtr, StoreSize, CurLoop, DL, SE); + + // TODO: ideally we should still be able to generate memset if SCEV expander + // is taught to generate the dependencies at the latest point. + if (!isSafeToExpand(NumBytesS, *SE)) + return false; + + Value *NumBytes = + Expander.expandCodeFor(NumBytesS, IntPtr, Preheader->getTerminator()); + + CallInst *NewCall; + if (SplatValue) { + NewCall = + Builder.CreateMemSet(BasePtr, SplatValue, NumBytes, StoreAlignment); + } else { + // Everything is emitted in default address space + Type *Int8PtrTy = DestInt8PtrTy; + + Module *M = TheStore->getModule(); + Value *MSP = + M->getOrInsertFunction("memset_pattern16", Builder.getVoidTy(), + Int8PtrTy, Int8PtrTy, IntPtr); + inferLibFuncAttributes(*M->getFunction("memset_pattern16"), *TLI); + + // Otherwise we should form a memset_pattern16. PatternValue is known to be + // an constant array of 16-bytes. Plop the value into a mergable global. + GlobalVariable *GV = new GlobalVariable(*M, PatternValue->getType(), true, + GlobalValue::PrivateLinkage, + PatternValue, ".memset_pattern"); + GV->setUnnamedAddr(GlobalValue::UnnamedAddr::Global); // Ok to merge these. + GV->setAlignment(16); + Value *PatternPtr = ConstantExpr::getBitCast(GV, Int8PtrTy); + NewCall = Builder.CreateCall(MSP, {BasePtr, PatternPtr, NumBytes}); + } + + DEBUG(dbgs() << " Formed memset: " << *NewCall << "\n" + << " from store to: " << *Ev << " at: " << *TheStore << "\n"); + NewCall->setDebugLoc(TheStore->getDebugLoc()); + + // Okay, the memset has been formed. Zap the original store and anything that + // feeds into it. + for (auto *I : Stores) + deleteDeadInstruction(I); + ++NumMemSet; + return true; +} + +/// If the stored value is a strided load in the same loop with the same stride +/// this may be transformable into a memcpy. 
This kicks in for stuff like +/// for (i) A[i] = B[i]; +bool LoopIdiomRecognize::processLoopStoreOfLoopLoad(StoreInst *SI, + const SCEV *BECount) { + assert(SI->isUnordered() && "Expected only non-volatile non-ordered stores."); + + Value *StorePtr = SI->getPointerOperand(); + const SCEVAddRecExpr *StoreEv = cast<SCEVAddRecExpr>(SE->getSCEV(StorePtr)); + APInt Stride = getStoreStride(StoreEv); + unsigned StoreSize = DL->getTypeStoreSize(SI->getValueOperand()->getType()); + bool NegStride = StoreSize == -Stride; + + // The store must be feeding a non-volatile load. + LoadInst *LI = cast<LoadInst>(SI->getValueOperand()); + assert(LI->isUnordered() && "Expected only non-volatile non-ordered loads."); + + // See if the pointer expression is an AddRec like {base,+,1} on the current + // loop, which indicates a strided load. If we have something else, it's a + // random load we can't handle. + const SCEVAddRecExpr *LoadEv = + cast<SCEVAddRecExpr>(SE->getSCEV(LI->getPointerOperand())); + + // The trip count of the loop and the base pointer of the addrec SCEV is + // guaranteed to be loop invariant, which means that it should dominate the + // header. This allows us to insert code for it in the preheader. + BasicBlock *Preheader = CurLoop->getLoopPreheader(); + IRBuilder<> Builder(Preheader->getTerminator()); + SCEVExpander Expander(*SE, *DL, "loop-idiom"); + + const SCEV *StrStart = StoreEv->getStart(); + unsigned StrAS = SI->getPointerAddressSpace(); + Type *IntPtrTy = Builder.getIntPtrTy(*DL, StrAS); + + // Handle negative strided loops. + if (NegStride) + StrStart = getStartForNegStride(StrStart, BECount, IntPtrTy, StoreSize, SE); + + // Okay, we have a strided store "p[i]" of a loaded value. We can turn + // this into a memcpy in the loop preheader now if we want. However, this + // would be unsafe to do if there is anything else in the loop that may read + // or write the memory region we're storing to. This includes the load that + // feeds the stores. Check for an alias by generating the base address and + // checking everything. + Value *StoreBasePtr = Expander.expandCodeFor( + StrStart, Builder.getInt8PtrTy(StrAS), Preheader->getTerminator()); + + SmallPtrSet<Instruction *, 1> Stores; + Stores.insert(SI); + if (mayLoopAccessLocation(StoreBasePtr, ModRefInfo::ModRef, CurLoop, BECount, + StoreSize, *AA, Stores)) { + Expander.clear(); + // If we generated new code for the base pointer, clean up. + RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI); + return false; + } + + const SCEV *LdStart = LoadEv->getStart(); + unsigned LdAS = LI->getPointerAddressSpace(); + + // Handle negative strided loops. + if (NegStride) + LdStart = getStartForNegStride(LdStart, BECount, IntPtrTy, StoreSize, SE); + + // For a memcpy, we have to make sure that the input array is not being + // mutated by the loop. + Value *LoadBasePtr = Expander.expandCodeFor( + LdStart, Builder.getInt8PtrTy(LdAS), Preheader->getTerminator()); + + if (mayLoopAccessLocation(LoadBasePtr, ModRefInfo::Mod, CurLoop, BECount, + StoreSize, *AA, Stores)) { + Expander.clear(); + // If we generated new code for the base pointer, clean up. + RecursivelyDeleteTriviallyDeadInstructions(LoadBasePtr, TLI); + RecursivelyDeleteTriviallyDeadInstructions(StoreBasePtr, TLI); + return false; + } + + if (avoidLIRForMultiBlockLoop()) + return false; + + // Okay, everything is safe, we can transform this! 
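+  // E.g. (a sketch): for "for (i = 0; i < n; ++i) A[i] = B[i];" over 4-byte
+  // elements, BECount is n - 1 and NumBytes below expands to 4 * n.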
+
+  const SCEV *NumBytesS =
+      getNumBytes(BECount, IntPtrTy, StoreSize, CurLoop, DL, SE);
+
+  Value *NumBytes =
+      Expander.expandCodeFor(NumBytesS, IntPtrTy, Preheader->getTerminator());
+
+  unsigned Align = std::min(SI->getAlignment(), LI->getAlignment());
+  CallInst *NewCall = nullptr;
+  // Check whether to generate an unordered atomic memcpy:
+  //  If the load or store is atomic, then it must necessarily be unordered
+  //  by previous checks.
+  if (!SI->isAtomic() && !LI->isAtomic())
+    NewCall = Builder.CreateMemCpy(StoreBasePtr, LoadBasePtr, NumBytes, Align);
+  else {
+    // We cannot allow unaligned ops for unordered load/store, so reject
+    // anything where the alignment isn't at least the element size.
+    if (Align < StoreSize)
+      return false;
+
+    // If the element.atomic memcpy is not lowered into explicit
+    // loads/stores later, then it will be lowered into an element-size
+    // specific lib call.  If the lib call doesn't exist for our store size,
+    // then we shouldn't generate the memcpy.
+    if (StoreSize > TTI->getAtomicMemIntrinsicMaxElementSize())
+      return false;
+
+    // Create the call.
+    // Note that unordered atomic loads/stores are *required* by the spec to
+    // have an alignment but non-atomic loads/stores may not.
+    NewCall = Builder.CreateElementUnorderedAtomicMemCpy(
+        StoreBasePtr, SI->getAlignment(), LoadBasePtr, LI->getAlignment(),
+        NumBytes, StoreSize);
+  }
+  NewCall->setDebugLoc(SI->getDebugLoc());
+
+  DEBUG(dbgs() << "  Formed memcpy: " << *NewCall << "\n"
+               << "    from load ptr=" << *LoadEv << " at: " << *LI << "\n"
+               << "    from store ptr=" << *StoreEv << " at: " << *SI << "\n");
+
+  // Okay, the memcpy has been formed.  Zap the original store and anything
+  // that feeds into it.
+  deleteDeadInstruction(SI);
+  ++NumMemCpy;
+  return true;
+}
+
+// When compiling for code size we avoid idiom recognition for a multi-block
+// loop unless it is a loop_memset idiom or a memset/memcpy idiom in a nested
+// loop.
+//
+bool LoopIdiomRecognize::avoidLIRForMultiBlockLoop(bool IsMemset,
+                                                   bool IsLoopMemset) {
+  if (ApplyCodeSizeHeuristics && CurLoop->getNumBlocks() > 1) {
+    if (!CurLoop->getParentLoop() && (!IsMemset || !IsLoopMemset)) {
+      DEBUG(dbgs() << " " << CurLoop->getHeader()->getParent()->getName()
+                   << " : LIR " << (IsMemset ? "Memset" : "Memcpy")
+                   << " avoided: multi-block top-level loop\n");
+      return true;
+    }
+  }
+
+  return false;
+}
+
+bool LoopIdiomRecognize::runOnNoncountableLoop() {
+  return recognizePopcount() || recognizeAndInsertCTLZ();
+}
+
+/// Check if the given conditional branch is based on the comparison between
+/// a variable and zero, and if the variable is non-zero, the control yields to
+/// the loop entry.  If the branch matches the behavior, the variable involved
+/// in the comparison is returned.  This function will be called to see if the
+/// precondition and postcondition of the loop are in desirable form.
+static Value *matchCondition(BranchInst *BI, BasicBlock *LoopEntry) {
+  if (!BI || !BI->isConditional())
+    return nullptr;
+
+  ICmpInst *Cond = dyn_cast<ICmpInst>(BI->getCondition());
+  if (!Cond)
+    return nullptr;
+
+  ConstantInt *CmpZero = dyn_cast<ConstantInt>(Cond->getOperand(1));
+  if (!CmpZero || !CmpZero->isZero())
+    return nullptr;
+
+  ICmpInst::Predicate Pred = Cond->getPredicate();
+  if ((Pred == ICmpInst::ICMP_NE && BI->getSuccessor(0) == LoopEntry) ||
+      (Pred == ICmpInst::ICMP_EQ && BI->getSuccessor(1) == LoopEntry))
+    return Cond->getOperand(0);
+
+  return nullptr;
+}
+
+// Check if the recurrence variable `VarX` is in the right form to create
+// the idiom.  Returns the value coerced to a PHINode if so.
+static PHINode *getRecurrenceVar(Value *VarX, Instruction *DefX,
+                                 BasicBlock *LoopEntry) {
+  auto *PhiX = dyn_cast<PHINode>(VarX);
+  if (PhiX && PhiX->getParent() == LoopEntry &&
+      (PhiX->getOperand(0) == DefX || PhiX->getOperand(1) == DefX))
+    return PhiX;
+  return nullptr;
+}
+
+/// Return true iff the idiom is detected in the loop.
+///
+/// Additionally:
+/// 1) \p CntInst is set to the instruction counting the population bit.
+/// 2) \p CntPhi is set to the corresponding phi node.
+/// 3) \p Var is set to the value whose population bits are being counted.
+///
+/// The core idiom we are trying to detect is:
+/// \code
+///    if (x0 == 0)
+///      goto loop-exit // the precondition of the loop
+///    cnt0 = init-val;
+///    do {
+///       x1 = phi (x0, x2);
+///       cnt1 = phi(cnt0, cnt2);
+///
+///       cnt2 = cnt1 + 1;
+///        ...
+///       x2 = x1 & (x1 - 1);
+///        ...
+///    } while(x2 != 0);
+///
+/// loop-exit:
+/// \endcode
+static bool detectPopcountIdiom(Loop *CurLoop, BasicBlock *PreCondBB,
+                                Instruction *&CntInst, PHINode *&CntPhi,
+                                Value *&Var) {
+  // step 1: Check to see if the loop-back branch matches this pattern:
+  //    "if (a != 0) goto loop-entry".
+  BasicBlock *LoopEntry;
+  Instruction *DefX2, *CountInst;
+  Value *VarX1, *VarX0;
+  PHINode *PhiX, *CountPhi;
+
+  DefX2 = CountInst = nullptr;
+  VarX1 = VarX0 = nullptr;
+  PhiX = CountPhi = nullptr;
+  LoopEntry = *(CurLoop->block_begin());
+
+  // step 1: Check if the loop-back branch is in desirable form.
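+  // For illustration (an assumed IR shape, not taken from a test case), a
+  // desirable loop-back branch looks like
+  //   %cmp = icmp ne i32 %x2, 0
+  //   br i1 %cmp, label %loop-body, label %loop-exit
+  // where %loop-body is the loop entry; matchCondition then returns %x2.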
+  {
+    if (Value *T = matchCondition(
+            dyn_cast<BranchInst>(LoopEntry->getTerminator()), LoopEntry))
+      DefX2 = dyn_cast<Instruction>(T);
+    else
+      return false;
+  }
+
+  // step 2: detect instructions corresponding to "x2 = x1 & (x1 - 1)"
+  {
+    if (!DefX2 || DefX2->getOpcode() != Instruction::And)
+      return false;
+
+    BinaryOperator *SubOneOp;
+
+    if ((SubOneOp = dyn_cast<BinaryOperator>(DefX2->getOperand(0))))
+      VarX1 = DefX2->getOperand(1);
+    else {
+      VarX1 = DefX2->getOperand(0);
+      SubOneOp = dyn_cast<BinaryOperator>(DefX2->getOperand(1));
+    }
+    if (!SubOneOp)
+      return false;
+
+    Instruction *SubInst = cast<Instruction>(SubOneOp);
+    ConstantInt *Dec = dyn_cast<ConstantInt>(SubInst->getOperand(1));
+    if (!Dec ||
+        !((SubInst->getOpcode() == Instruction::Sub && Dec->isOne()) ||
+          (SubInst->getOpcode() == Instruction::Add &&
+           Dec->isMinusOne()))) {
+      return false;
+    }
+  }
+
+  // step 3: Check the recurrence of variable X
+  PhiX = getRecurrenceVar(VarX1, DefX2, LoopEntry);
+  if (!PhiX)
+    return false;
+
+  // step 4: Find the instruction which counts the population: cnt2 = cnt1 + 1
+  {
+    CountInst = nullptr;
+    for (BasicBlock::iterator Iter = LoopEntry->getFirstNonPHI()->getIterator(),
+                              IterE = LoopEntry->end();
+         Iter != IterE; Iter++) {
+      Instruction *Inst = &*Iter;
+      if (Inst->getOpcode() != Instruction::Add)
+        continue;
+
+      ConstantInt *Inc = dyn_cast<ConstantInt>(Inst->getOperand(1));
+      if (!Inc || !Inc->isOne())
+        continue;
+
+      PHINode *Phi = getRecurrenceVar(Inst->getOperand(0), Inst, LoopEntry);
+      if (!Phi)
+        continue;
+
+      // Check if the result of the instruction is live out of the loop.
+      bool LiveOutLoop = false;
+      for (User *U : Inst->users()) {
+        if ((cast<Instruction>(U))->getParent() != LoopEntry) {
+          LiveOutLoop = true;
+          break;
+        }
+      }
+
+      if (LiveOutLoop) {
+        CountInst = Inst;
+        CountPhi = Phi;
+        break;
+      }
+    }
+
+    if (!CountInst)
+      return false;
+  }
+
+  // step 5: check if the precondition is in this form:
+  //   "if (x != 0) goto loop-head; else goto somewhere-we-don't-care;"
+  {
+    auto *PreCondBr = dyn_cast<BranchInst>(PreCondBB->getTerminator());
+    Value *T = matchCondition(PreCondBr, CurLoop->getLoopPreheader());
+    if (T != PhiX->getOperand(0) && T != PhiX->getOperand(1))
+      return false;
+
+    CntInst = CountInst;
+    CntPhi = CountPhi;
+    Var = T;
+  }
+
+  return true;
+}
+
+/// Return true if the idiom is detected in the loop.
+///
+/// Additionally:
+/// 1) \p CntInst is set to the instruction Counting Leading Zeros (CTLZ),
+///    or nullptr if there is no such instruction.
+/// 2) \p CntPhi is set to the corresponding phi node,
+///    or nullptr if there is no such node.
+/// 3) \p Var is set to the value whose CTLZ could be used.
+/// 4) \p DefX is set to the instruction calculating the loop exit condition.
+///
+/// The core idiom we are trying to detect is:
+/// \code
+///    if (x0 == 0)
+///      goto loop-exit // the precondition of the loop
+///    cnt0 = init-val;
+///    do {
+///       x = phi (x0, x.next);   //PhiX
+///       cnt = phi(cnt0, cnt.next);
+///
+///       cnt.next = cnt + 1;
+///        ...
+///       x.next = x >> 1;   // DefX
+///        ...
+///    } while(x.next != 0);
+///
+/// loop-exit:
+/// \endcode
+static bool detectCTLZIdiom(Loop *CurLoop, PHINode *&PhiX,
+                            Instruction *&CntInst, PHINode *&CntPhi,
+                            Instruction *&DefX) {
+  BasicBlock *LoopEntry;
+  Value *VarX = nullptr;
+
+  DefX = nullptr;
+  PhiX = nullptr;
+  CntInst = nullptr;
+  CntPhi = nullptr;
+  LoopEntry = *(CurLoop->block_begin());
+
+  // step 1: Check if the loop-back branch is in desirable form.
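+  // For illustration (an assumed IR shape): the single-block loop is expected
+  // to end in
+  //   %tobool = icmp eq i32 %x.next, 0
+  //   br i1 %tobool, label %loop-exit, label %loop-body
+  // so matchCondition returns %x.next, which step 2 requires to be the ashr.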
+  if (Value *T = matchCondition(
+          dyn_cast<BranchInst>(LoopEntry->getTerminator()), LoopEntry))
+    DefX = dyn_cast<Instruction>(T);
+  else
+    return false;
+
+  // step 2: detect instructions corresponding to "x.next = x >> 1"
+  if (!DefX || DefX->getOpcode() != Instruction::AShr)
+    return false;
+  ConstantInt *Shft = dyn_cast<ConstantInt>(DefX->getOperand(1));
+  if (!Shft || !Shft->isOne())
+    return false;
+  VarX = DefX->getOperand(0);
+
+  // step 3: Check the recurrence of variable X
+  PhiX = getRecurrenceVar(VarX, DefX, LoopEntry);
+  if (!PhiX)
+    return false;
+
+  // step 4: Find the instruction which counts the CTLZ: cnt.next = cnt + 1
+  // TODO: We can skip this step.  If the loop trip count is known (CTLZ),
+  //       then all uses of "cnt.next" could be optimized to the trip count
+  //       plus "cnt0".  Currently it is not optimized.
+  //       This step could be used to detect the POPCNT instruction:
+  //       cnt.next = cnt + (x.next & 1)
+  for (BasicBlock::iterator Iter = LoopEntry->getFirstNonPHI()->getIterator(),
+                            IterE = LoopEntry->end();
+       Iter != IterE; Iter++) {
+    Instruction *Inst = &*Iter;
+    if (Inst->getOpcode() != Instruction::Add)
+      continue;
+
+    ConstantInt *Inc = dyn_cast<ConstantInt>(Inst->getOperand(1));
+    if (!Inc || !Inc->isOne())
+      continue;
+
+    PHINode *Phi = getRecurrenceVar(Inst->getOperand(0), Inst, LoopEntry);
+    if (!Phi)
+      continue;
+
+    CntInst = Inst;
+    CntPhi = Phi;
+    break;
+  }
+  if (!CntInst)
+    return false;
+
+  return true;
+}
+
+/// Recognize a CTLZ idiom in a non-countable loop and convert the loop
+/// to a countable one (with a CTLZ trip count).
+/// Returns true if CTLZ is inserted as the new trip count; otherwise, returns
+/// false.
+bool LoopIdiomRecognize::recognizeAndInsertCTLZ() {
+  // Give up if the loop has multiple blocks or multiple backedges.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
+    return false;
+
+  Instruction *CntInst, *DefX;
+  PHINode *CntPhi, *PhiX;
+  if (!detectCTLZIdiom(CurLoop, PhiX, CntInst, CntPhi, DefX))
+    return false;
+
+  bool IsCntPhiUsedOutsideLoop = false;
+  for (User *U : CntPhi->users())
+    if (!CurLoop->contains(dyn_cast<Instruction>(U))) {
+      IsCntPhiUsedOutsideLoop = true;
+      break;
+    }
+  bool IsCntInstUsedOutsideLoop = false;
+  for (User *U : CntInst->users())
+    if (!CurLoop->contains(dyn_cast<Instruction>(U))) {
+      IsCntInstUsedOutsideLoop = true;
+      break;
+    }
+  // If both CntInst and CntPhi are used outside the loop the profitability
+  // is questionable.
+  if (IsCntInstUsedOutsideLoop && IsCntPhiUsedOutsideLoop)
+    return false;
+
+  // For some CPUs the result of the CTLZ(X) intrinsic is undefined when X is
+  // 0.  If we cannot guarantee X != 0, we need to check this when expanding.
+  bool ZeroCheck = false;
+  // It is safe to assume the Preheader exists, as it was checked in the
+  // parent function runOnLoop.
+  BasicBlock *PH = CurLoop->getLoopPreheader();
+  Value *InitX = PhiX->getIncomingValueForBlock(PH);
+  // If we check X != 0 before entering the loop we don't need a zero
+  // check in the CTLZ intrinsic, but only if CntPhi is not used outside of
+  // the loop (if it is used we count CTLZ(X >> 1)).
+  if (!IsCntPhiUsedOutsideLoop)
+    if (BasicBlock *PreCondBB = PH->getSinglePredecessor())
+      if (BranchInst *PreCondBr =
+              dyn_cast<BranchInst>(PreCondBB->getTerminator())) {
+        if (matchCondition(PreCondBr, PH) == InitX)
+          ZeroCheck = true;
+      }
+
+  // Check if the CTLZ intrinsic is profitable.
Assume it is always profitable
+  // if we delete the loop (the loop has only 6 instructions):
+  //  %n.addr.0 = phi [ %n, %entry ], [ %shr, %while.cond ]
+  //  %i.0 = phi [ %i0, %entry ], [ %inc, %while.cond ]
+  //  %shr = ashr %n.addr.0, 1
+  //  %tobool = icmp eq %shr, 0
+  //  %inc = add nsw %i.0, 1
+  //  br i1 %tobool
+
+  IRBuilder<> Builder(PH->getTerminator());
+  SmallVector<const Value *, 2> Ops =
+      {InitX, ZeroCheck ? Builder.getTrue() : Builder.getFalse()};
+  ArrayRef<const Value *> Args(Ops);
+  if (CurLoop->getHeader()->size() != 6 &&
+      TTI->getIntrinsicCost(Intrinsic::ctlz, InitX->getType(), Args) >
+          TargetTransformInfo::TCC_Basic)
+    return false;
+
+  const DebugLoc DL = DefX->getDebugLoc();
+  transformLoopToCountable(PH, CntInst, CntPhi, InitX, DL, ZeroCheck,
+                           IsCntPhiUsedOutsideLoop);
+  return true;
+}
+
+/// Recognizes a population count idiom in a non-countable loop.
+///
+/// If detected, transforms the relevant code to issue the popcount intrinsic
+/// function call, and returns true; otherwise, returns false.
+bool LoopIdiomRecognize::recognizePopcount() {
+  if (TTI->getPopcntSupport(32) != TargetTransformInfo::PSK_FastHardware)
+    return false;
+
+  // Counting the population is usually done with a few arithmetic
+  // instructions.  Such instructions can be easily "absorbed" by vacant slots
+  // in a non-compact loop.  Therefore, recognizing a popcount idiom only
+  // makes sense in a compact loop.
+
+  // Give up if the loop has multiple blocks or multiple backedges.
+  if (CurLoop->getNumBackEdges() != 1 || CurLoop->getNumBlocks() != 1)
+    return false;
+
+  BasicBlock *LoopBody = *(CurLoop->block_begin());
+  if (LoopBody->size() >= 20) {
+    // The loop is too big, bail out.
+    return false;
+  }
+
+  // It should have a preheader containing nothing but an unconditional branch.
+  BasicBlock *PH = CurLoop->getLoopPreheader();
+  if (!PH || &PH->front() != PH->getTerminator())
+    return false;
+  auto *EntryBI = dyn_cast<BranchInst>(PH->getTerminator());
+  if (!EntryBI || EntryBI->isConditional())
+    return false;
+
+  // It should have a precondition block where the generated popcount intrinsic
+  // function can be inserted.
+  auto *PreCondBB = PH->getSinglePredecessor();
+  if (!PreCondBB)
+    return false;
+  auto *PreCondBI = dyn_cast<BranchInst>(PreCondBB->getTerminator());
+  if (!PreCondBI || PreCondBI->isUnconditional())
+    return false;
+
+  Instruction *CntInst;
+  PHINode *CntPhi;
+  Value *Val;
+  if (!detectPopcountIdiom(CurLoop, PreCondBB, CntInst, CntPhi, Val))
+    return false;
+
+  transformLoopToPopcount(PreCondBB, CntInst, CntPhi, Val);
+  return true;
+}
+
+static CallInst *createPopcntIntrinsic(IRBuilder<> &IRBuilder, Value *Val,
+                                       const DebugLoc &DL) {
+  Value *Ops[] = {Val};
+  Type *Tys[] = {Val->getType()};
+
+  Module *M = IRBuilder.GetInsertBlock()->getParent()->getParent();
+  Value *Func = Intrinsic::getDeclaration(M, Intrinsic::ctpop, Tys);
+  CallInst *CI = IRBuilder.CreateCall(Func, Ops);
+  CI->setDebugLoc(DL);
+
+  return CI;
+}
+
+static CallInst *createCTLZIntrinsic(IRBuilder<> &IRBuilder, Value *Val,
+                                     const DebugLoc &DL, bool ZeroCheck) {
+  Value *Ops[] = {Val, ZeroCheck ?
IRBuilder.getTrue() : IRBuilder.getFalse()}; + Type *Tys[] = {Val->getType()}; + + Module *M = IRBuilder.GetInsertBlock()->getParent()->getParent(); + Value *Func = Intrinsic::getDeclaration(M, Intrinsic::ctlz, Tys); + CallInst *CI = IRBuilder.CreateCall(Func, Ops); + CI->setDebugLoc(DL); + + return CI; +} + +/// Transform the following loop: +/// loop: +/// CntPhi = PHI [Cnt0, CntInst] +/// PhiX = PHI [InitX, DefX] +/// CntInst = CntPhi + 1 +/// DefX = PhiX >> 1 +/// LOOP_BODY +/// Br: loop if (DefX != 0) +/// Use(CntPhi) or Use(CntInst) +/// +/// Into: +/// If CntPhi used outside the loop: +/// CountPrev = BitWidth(InitX) - CTLZ(InitX >> 1) +/// Count = CountPrev + 1 +/// else +/// Count = BitWidth(InitX) - CTLZ(InitX) +/// loop: +/// CntPhi = PHI [Cnt0, CntInst] +/// PhiX = PHI [InitX, DefX] +/// PhiCount = PHI [Count, Dec] +/// CntInst = CntPhi + 1 +/// DefX = PhiX >> 1 +/// Dec = PhiCount - 1 +/// LOOP_BODY +/// Br: loop if (Dec != 0) +/// Use(CountPrev + Cnt0) // Use(CntPhi) +/// or +/// Use(Count + Cnt0) // Use(CntInst) +/// +/// If LOOP_BODY is empty the loop will be deleted. +/// If CntInst and DefX are not used in LOOP_BODY they will be removed. +void LoopIdiomRecognize::transformLoopToCountable( + BasicBlock *Preheader, Instruction *CntInst, PHINode *CntPhi, Value *InitX, + const DebugLoc DL, bool ZeroCheck, bool IsCntPhiUsedOutsideLoop) { + BranchInst *PreheaderBr = dyn_cast<BranchInst>(Preheader->getTerminator()); + + // Step 1: Insert the CTLZ instruction at the end of the preheader block + // Count = BitWidth - CTLZ(InitX); + // If there are uses of CntPhi create: + // CountPrev = BitWidth - CTLZ(InitX >> 1); + IRBuilder<> Builder(PreheaderBr); + Builder.SetCurrentDebugLocation(DL); + Value *CTLZ, *Count, *CountPrev, *NewCount, *InitXNext; + + if (IsCntPhiUsedOutsideLoop) + InitXNext = Builder.CreateAShr(InitX, + ConstantInt::get(InitX->getType(), 1)); + else + InitXNext = InitX; + CTLZ = createCTLZIntrinsic(Builder, InitXNext, DL, ZeroCheck); + Count = Builder.CreateSub( + ConstantInt::get(CTLZ->getType(), + CTLZ->getType()->getIntegerBitWidth()), + CTLZ); + if (IsCntPhiUsedOutsideLoop) { + CountPrev = Count; + Count = Builder.CreateAdd( + CountPrev, + ConstantInt::get(CountPrev->getType(), 1)); + } + if (IsCntPhiUsedOutsideLoop) + NewCount = Builder.CreateZExtOrTrunc(CountPrev, + cast<IntegerType>(CntInst->getType())); + else + NewCount = Builder.CreateZExtOrTrunc(Count, + cast<IntegerType>(CntInst->getType())); + + // If the CTLZ counter's initial value is not zero, insert Add Inst. + Value *CntInitVal = CntPhi->getIncomingValueForBlock(Preheader); + ConstantInt *InitConst = dyn_cast<ConstantInt>(CntInitVal); + if (!InitConst || !InitConst->isZero()) + NewCount = Builder.CreateAdd(NewCount, CntInitVal); + + // Step 2: Insert new IV and loop condition: + // loop: + // ... + // PhiCount = PHI [Count, Dec] + // ... + // Dec = PhiCount - 1 + // ... + // Br: loop if (Dec != 0) + BasicBlock *Body = *(CurLoop->block_begin()); + auto *LbBr = dyn_cast<BranchInst>(Body->getTerminator()); + ICmpInst *LbCond = cast<ICmpInst>(LbBr->getCondition()); + Type *Ty = Count->getType(); + + PHINode *TcPhi = PHINode::Create(Ty, 2, "tcphi", &Body->front()); + + Builder.SetInsertPoint(LbCond); + Instruction *TcDec = cast<Instruction>( + Builder.CreateSub(TcPhi, ConstantInt::get(Ty, 1), + "tcdec", false, true)); + + TcPhi->addIncoming(Count, Preheader); + TcPhi->addIncoming(TcDec, Body); + + CmpInst::Predicate Pred = + (LbBr->getSuccessor(0) == Body) ? 
CmpInst::ICMP_NE : CmpInst::ICMP_EQ;
+  LbCond->setPredicate(Pred);
+  LbCond->setOperand(0, TcDec);
+  LbCond->setOperand(1, ConstantInt::get(Ty, 0));
+
+  // Step 3: All the references to the original counter outside
+  //  the loop are replaced with the NewCount -- the value returned from
+  //  __builtin_ctlz(x).
+  if (IsCntPhiUsedOutsideLoop)
+    CntPhi->replaceUsesOutsideBlock(NewCount, Body);
+  else
+    CntInst->replaceUsesOutsideBlock(NewCount, Body);
+
+  // Step 4: Forget the "non-computable" trip-count SCEV associated with the
+  //   loop.  The loop would otherwise not be deleted even if it becomes empty.
+  SE->forgetLoop(CurLoop);
+}
+
+void LoopIdiomRecognize::transformLoopToPopcount(BasicBlock *PreCondBB,
+                                                 Instruction *CntInst,
+                                                 PHINode *CntPhi, Value *Var) {
+  BasicBlock *PreHead = CurLoop->getLoopPreheader();
+  auto *PreCondBr = dyn_cast<BranchInst>(PreCondBB->getTerminator());
+  const DebugLoc DL = CntInst->getDebugLoc();
+
+  // Assuming that before the transformation the loop looks like:
+  //  if (x) // the precondition
+  //     do { cnt++; x &= x - 1; } while(x);
+
+  // Step 1: Insert the ctpop instruction at the end of the precondition block
+  IRBuilder<> Builder(PreCondBr);
+  Value *PopCnt, *PopCntZext, *NewCount, *TripCnt;
+  {
+    PopCnt = createPopcntIntrinsic(Builder, Var, DL);
+    NewCount = PopCntZext =
+        Builder.CreateZExtOrTrunc(PopCnt, cast<IntegerType>(CntPhi->getType()));
+
+    if (NewCount != PopCnt)
+      (cast<Instruction>(NewCount))->setDebugLoc(DL);
+
+    // TripCnt is exactly the number of iterations the loop has
+    TripCnt = NewCount;
+
+    // If the population counter's initial value is not zero, insert Add Inst.
+    Value *CntInitVal = CntPhi->getIncomingValueForBlock(PreHead);
+    ConstantInt *InitConst = dyn_cast<ConstantInt>(CntInitVal);
+    if (!InitConst || !InitConst->isZero()) {
+      NewCount = Builder.CreateAdd(NewCount, CntInitVal);
+      (cast<Instruction>(NewCount))->setDebugLoc(DL);
+    }
+  }
+
+  // Step 2: Replace the precondition from "if (x == 0) goto loop-exit" to
+  //   "if (NewCount == 0) loop-exit".  Without this change, the intrinsic
+  //   function would be partially dead code, and downstream passes would drag
+  //   it back from the precondition block to the preheader.
+  {
+    ICmpInst *PreCond = cast<ICmpInst>(PreCondBr->getCondition());
+
+    Value *Opnd0 = PopCntZext;
+    Value *Opnd1 = ConstantInt::get(PopCntZext->getType(), 0);
+    if (PreCond->getOperand(0) != Var)
+      std::swap(Opnd0, Opnd1);
+
+    ICmpInst *NewPreCond = cast<ICmpInst>(
+        Builder.CreateICmp(PreCond->getPredicate(), Opnd0, Opnd1));
+    PreCondBr->setCondition(NewPreCond);
+
+    RecursivelyDeleteTriviallyDeadInstructions(PreCond, TLI);
+  }
+
+  // Step 3: Note that the population count is exactly the trip count of the
+  // loop in question, which enables us to convert the loop from a noncountable
+  // loop into a countable one.  The benefit is twofold:
+  //
+  //  - If the loop only counts population, the entire loop becomes dead after
+  //    the transformation.  It is a lot easier to prove a countable loop dead
+  //    than to prove a noncountable one.  (In some C dialects, an infinite
+  //    loop isn't dead even if it computes nothing useful.  In general, DCE
+  //    needs to prove a noncountable loop finite before safely deleting it.)
+  //
+  //  - If the loop also performs something else, it remains alive.
+  //    Since it is transformed to countable form, it can be aggressively
+  //    optimized by some optimizations which are in general not applicable
+  //    to a noncountable loop.
+  //
+  // After this step, this loop (conceptually) would look like the following:
+  //   newcnt = __builtin_ctpop(x);
+  //   t = newcnt;
+  //   if (x)
+  //     do { cnt++; x &= x - 1; t--; } while (t > 0);
+  BasicBlock *Body = *(CurLoop->block_begin());
+  {
+    auto *LbBr = dyn_cast<BranchInst>(Body->getTerminator());
+    ICmpInst *LbCond = cast<ICmpInst>(LbBr->getCondition());
+    Type *Ty = TripCnt->getType();
+
+    PHINode *TcPhi = PHINode::Create(Ty, 2, "tcphi", &Body->front());
+
+    Builder.SetInsertPoint(LbCond);
+    Instruction *TcDec = cast<Instruction>(
+        Builder.CreateSub(TcPhi, ConstantInt::get(Ty, 1),
+                          "tcdec", false, true));
+
+    TcPhi->addIncoming(TripCnt, PreHead);
+    TcPhi->addIncoming(TcDec, Body);
+
+    CmpInst::Predicate Pred =
+        (LbBr->getSuccessor(0) == Body) ? CmpInst::ICMP_UGT : CmpInst::ICMP_SLE;
+    LbCond->setPredicate(Pred);
+    LbCond->setOperand(0, TcDec);
+    LbCond->setOperand(1, ConstantInt::get(Ty, 0));
+  }
+
+  // Step 4: All the references to the original population counter outside
+  //  the loop are replaced with the NewCount -- the value returned from
+  //  __builtin_ctpop().
+  CntInst->replaceUsesOutsideBlock(NewCount, Body);
+
+  // Step 5: Forget the "non-computable" trip-count SCEV associated with the
+  //   loop.  The loop would otherwise not be deleted even if it becomes empty.
+  SE->forgetLoop(CurLoop);
+}
diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp
new file mode 100644
index 000000000000..40d468a084d4
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/LoopInstSimplify.cpp
@@ -0,0 +1,223 @@
+//===- LoopInstSimplify.cpp - Loop Instruction Simplification Pass --------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass performs lightweight instruction simplification on loop bodies.
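+//
+// For example (illustrative only), a loop body containing
+//   %add = add i32 %x, 0
+//   %mul = mul i32 %add, 2
+// is simplified so that %add folds to %x and the dead add is deleted; users
+// of a simplified value are queued for another simplification round.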
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LoopInstSimplify.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/User.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include <algorithm>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-instsimplify"
+
+STATISTIC(NumSimplified, "Number of redundant instructions simplified");
+
+static bool SimplifyLoopInst(Loop *L, DominatorTree *DT, LoopInfo *LI,
+                             AssumptionCache *AC,
+                             const TargetLibraryInfo *TLI) {
+  SmallVector<BasicBlock *, 8> ExitBlocks;
+  L->getUniqueExitBlocks(ExitBlocks);
+  array_pod_sort(ExitBlocks.begin(), ExitBlocks.end());
+
+  SmallPtrSet<const Instruction *, 8> S1, S2, *ToSimplify = &S1, *Next = &S2;
+
+  // The bit we are stealing from the pointer represents whether this basic
+  // block is the header of a subloop, in which case we only process its phis.
+  using WorklistItem = PointerIntPair<BasicBlock *, 1>;
+  SmallVector<WorklistItem, 16> VisitStack;
+  SmallPtrSet<BasicBlock *, 32> Visited;
+
+  bool Changed = false;
+  bool LocalChanged;
+  do {
+    LocalChanged = false;
+
+    VisitStack.clear();
+    Visited.clear();
+
+    VisitStack.push_back(WorklistItem(L->getHeader(), false));
+
+    while (!VisitStack.empty()) {
+      WorklistItem Item = VisitStack.pop_back_val();
+      BasicBlock *BB = Item.getPointer();
+      bool IsSubloopHeader = Item.getInt();
+      const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
+
+      // Simplify instructions in the current basic block.
+      for (BasicBlock::iterator BI = BB->begin(), BE = BB->end(); BI != BE;) {
+        Instruction *I = &*BI++;
+
+        // The first time through the loop ToSimplify is empty and we try to
+        // simplify all instructions.  On later iterations ToSimplify is not
+        // empty and we only bother simplifying instructions that are in it.
+        if (!ToSimplify->empty() && !ToSimplify->count(I))
+          continue;
+
+        // Don't bother simplifying unused instructions.
+        if (!I->use_empty()) {
+          Value *V = SimplifyInstruction(I, {DL, TLI, DT, AC});
+          if (V && LI->replacementPreservesLCSSAForm(I, V)) {
+            // Mark all uses for resimplification next time round the loop.
+            for (User *U : I->users())
+              Next->insert(cast<Instruction>(U));
+
+            I->replaceAllUsesWith(V);
+            LocalChanged = true;
+            ++NumSimplified;
+          }
+        }
+        if (RecursivelyDeleteTriviallyDeadInstructions(I, TLI)) {
+          // RecursivelyDeleteTriviallyDeadInstructions can remove more than
+          // one instruction, so simply incrementing the iterator does not
+          // work.  When instructions get deleted, re-iterate instead.
+          BI = BB->begin();
+          BE = BB->end();
+          LocalChanged = true;
+        }
+
+        if (IsSubloopHeader && !isa<PHINode>(I))
+          break;
+      }
+
+      // Add all successors to the worklist, except for loop exit blocks and
+      // the bodies of subloops.  We visit the headers of loops so that we can
+      // process their phis, but we contract the rest of the subloop body and
+      // only follow edges leading back to the original loop.
+      for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); SI != SE;
+           ++SI) {
+        BasicBlock *SuccBB = *SI;
+        if (!Visited.insert(SuccBB).second)
+          continue;
+
+        const Loop *SuccLoop = LI->getLoopFor(SuccBB);
+        if (SuccLoop && SuccLoop->getHeader() == SuccBB &&
+            L->contains(SuccLoop)) {
+          VisitStack.push_back(WorklistItem(SuccBB, true));
+
+          SmallVector<BasicBlock *, 8> SubLoopExitBlocks;
+          SuccLoop->getExitBlocks(SubLoopExitBlocks);
+
+          for (unsigned i = 0; i < SubLoopExitBlocks.size(); ++i) {
+            BasicBlock *ExitBB = SubLoopExitBlocks[i];
+            if (LI->getLoopFor(ExitBB) == L && Visited.insert(ExitBB).second)
+              VisitStack.push_back(WorklistItem(ExitBB, false));
+          }
+
+          continue;
+        }
+
+        bool IsExitBlock =
+            std::binary_search(ExitBlocks.begin(), ExitBlocks.end(), SuccBB);
+        if (IsExitBlock)
+          continue;
+
+        VisitStack.push_back(WorklistItem(SuccBB, false));
+      }
+    }
+
+    // Place the list of instructions to simplify on the next loop iteration
+    // into ToSimplify.
+    std::swap(ToSimplify, Next);
+    Next->clear();
+
+    Changed |= LocalChanged;
+  } while (LocalChanged);
+
+  return Changed;
+}
+
+namespace {
+
+class LoopInstSimplifyLegacyPass : public LoopPass {
+public:
+  static char ID; // Pass ID, replacement for typeid
+
+  LoopInstSimplifyLegacyPass() : LoopPass(ID) {
+    initializeLoopInstSimplifyLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
+    if (skipLoop(L))
+      return false;
+    DominatorTreeWrapperPass *DTWP =
+        getAnalysisIfAvailable<DominatorTreeWrapperPass>();
+    DominatorTree *DT = DTWP ?
&DTWP->getDomTree() : nullptr;
+    LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+    AssumptionCache *AC =
+        &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
+            *L->getHeader()->getParent());
+    const TargetLibraryInfo *TLI =
+        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+
+    return SimplifyLoopInst(L, DT, LI, AC, TLI);
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<AssumptionCacheTracker>();
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.setPreservesCFG();
+    getLoopAnalysisUsage(AU);
+  }
+};
+
+} // end anonymous namespace
+
+PreservedAnalyses LoopInstSimplifyPass::run(Loop &L, LoopAnalysisManager &AM,
+                                            LoopStandardAnalysisResults &AR,
+                                            LPMUpdater &) {
+  if (!SimplifyLoopInst(&L, &AR.DT, &AR.LI, &AR.AC, &AR.TLI))
+    return PreservedAnalyses::all();
+
+  auto PA = getLoopPassPreservedAnalyses();
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
+
+char LoopInstSimplifyLegacyPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(LoopInstSimplifyLegacyPass, "loop-instsimplify",
+                      "Simplify instructions in loops", false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(LoopPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(LoopInstSimplifyLegacyPass, "loop-instsimplify",
+                    "Simplify instructions in loops", false, false)
+
+Pass *llvm::createLoopInstSimplifyPass() {
+  return new LoopInstSimplifyLegacyPass();
+}
diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopInterchange.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
new file mode 100644
index 000000000000..4f8dafef230a
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/LoopInterchange.cpp
@@ -0,0 +1,1427 @@
+//===- LoopInterchange.cpp - Loop interchange pass-------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass handles the loop interchange transform.
+// It interchanges loops to provide a more cache-friendly memory access
+// pattern.
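+//
+// For example (illustrative only), a column-major traversal such as
+//   for (int i = 0; i < N; ++i)
+//     for (int j = 0; j < M; ++j)
+//       A[j][i] += 1;
+// is rewritten so that the j-loop becomes the outer loop and the i-loop the
+// inner one, making the innermost accesses contiguous in memory.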
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/DependenceAnalysis.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DiagnosticInfo.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include <cassert>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-interchange"
+
+static cl::opt<int> LoopInterchangeCostThreshold(
+    "loop-interchange-threshold", cl::init(0), cl::Hidden,
+    cl::desc("Interchange if you gain more than this number"));
+
+namespace {
+
+using LoopVector = SmallVector<Loop *, 8>;
+
+// TODO: Check if we can use a sparse matrix here.
+using CharMatrix = std::vector<std::vector<char>>;
+
+} // end anonymous namespace
+
+// Maximum number of dependencies that can be handled in the dependency matrix.
+static const unsigned MaxMemInstrCount = 100;
+
+// Maximum loop depth supported.
+static const unsigned MaxLoopNestDepth = 10;
+
+#ifdef DUMP_DEP_MATRICIES
+static void printDepMatrix(CharMatrix &DepMatrix) {
+  for (auto &Row : DepMatrix) {
+    for (auto D : Row)
+      DEBUG(dbgs() << D << " ");
+    DEBUG(dbgs() << "\n");
+  }
+}
+#endif
+
+static bool populateDependencyMatrix(CharMatrix &DepMatrix, unsigned Level,
+                                     Loop *L, DependenceInfo *DI) {
+  using ValueVector = SmallVector<Value *, 16>;
+
+  ValueVector MemInstr;
+
+  // For each block.
+  for (BasicBlock *BB : L->blocks()) {
+    // Scan the BB and collect legal loads and stores.
+    for (Instruction &I : *BB) {
+      if (!isa<Instruction>(I))
+        return false;
+      if (auto *Ld = dyn_cast<LoadInst>(&I)) {
+        if (!Ld->isSimple())
+          return false;
+        MemInstr.push_back(&I);
+      } else if (auto *St = dyn_cast<StoreInst>(&I)) {
+        if (!St->isSimple())
+          return false;
+        MemInstr.push_back(&I);
+      }
+    }
+  }
+
+  DEBUG(dbgs() << "Found " << MemInstr.size()
+               << " Loads and Stores to analyze\n");
+
+  ValueVector::iterator I, IE, J, JE;
+
+  for (I = MemInstr.begin(), IE = MemInstr.end(); I != IE; ++I) {
+    for (J = I, JE = MemInstr.end(); J != JE; ++J) {
+      std::vector<char> Dep;
+      Instruction *Src = cast<Instruction>(*I);
+      Instruction *Dst = cast<Instruction>(*J);
+      if (Src == Dst)
+        continue;
+      // Ignore Input dependencies.
+      if (isa<LoadInst>(Src) && isa<LoadInst>(Dst))
+        continue;
+      // Track Output, Flow, and Anti dependencies.
+      if (auto D = DI->depends(Src, Dst, true)) {
+        assert(D->isOrdered() && "Expected an output, flow or anti dep.");
+        DEBUG(StringRef DepType =
+                  D->isFlow() ? "flow" : D->isAnti() ? "anti" : "output";
+              dbgs() << "Found " << DepType
+                     << " dependency between Src and Dst\n"
+                     << " Src:" << *Src << "\n Dst:" << *Dst << '\n');
+        unsigned Levels = D->getLevels();
+        char Direction;
+        for (unsigned II = 1; II <= Levels; ++II) {
+          const SCEV *Distance = D->getDistance(II);
+          const SCEVConstant *SCEVConst =
+              dyn_cast_or_null<SCEVConstant>(Distance);
+          if (SCEVConst) {
+            const ConstantInt *CI = SCEVConst->getValue();
+            if (CI->isNegative())
+              Direction = '<';
+            else if (CI->isZero())
+              Direction = '=';
+            else
+              Direction = '>';
+            Dep.push_back(Direction);
+          } else if (D->isScalar(II)) {
+            Direction = 'S';
+            Dep.push_back(Direction);
+          } else {
+            unsigned Dir = D->getDirection(II);
+            if (Dir == Dependence::DVEntry::LT ||
+                Dir == Dependence::DVEntry::LE)
+              Direction = '<';
+            else if (Dir == Dependence::DVEntry::GT ||
+                     Dir == Dependence::DVEntry::GE)
+              Direction = '>';
+            else if (Dir == Dependence::DVEntry::EQ)
+              Direction = '=';
+            else
+              Direction = '*';
+            Dep.push_back(Direction);
+          }
+        }
+        while (Dep.size() != Level) {
+          Dep.push_back('I');
+        }
+
+        DepMatrix.push_back(Dep);
+        if (DepMatrix.size() > MaxMemInstrCount) {
+          DEBUG(dbgs() << "Cannot handle more than " << MaxMemInstrCount
+                       << " dependencies inside loop\n");
+          return false;
+        }
+      }
+    }
+  }
+
+  // If we don't have a DepMatrix to check legality against, return false.
+  if (DepMatrix.empty())
+    return false;
+  return true;
+}
+
+// A loop is moved from index 'from' to an index 'to'.  Update the dependence
+// matrix by exchanging the two columns.
+static void interChangeDependencies(CharMatrix &DepMatrix, unsigned FromIndx,
+                                    unsigned ToIndx) {
+  unsigned numRows = DepMatrix.size();
+  for (unsigned i = 0; i < numRows; ++i) {
+    char TmpVal = DepMatrix[i][ToIndx];
+    DepMatrix[i][ToIndx] = DepMatrix[i][FromIndx];
+    DepMatrix[i][FromIndx] = TmpVal;
+  }
+}
+
+// Checks if the outermost non-'=', non-'S', non-'I' dependence in the
+// dependence matrix is '>'.
+static bool isOuterMostDepPositive(CharMatrix &DepMatrix, unsigned Row,
+                                   unsigned Column) {
+  for (unsigned i = 0; i <= Column; ++i) {
+    if (DepMatrix[Row][i] == '<')
+      return false;
+    if (DepMatrix[Row][i] == '>')
+      return true;
+  }
+  // All dependencies were '=', 'S' or 'I'.
+  return false;
+}
+
+// Checks if no dependence exists in the dependency matrix in Row before Column.
+static bool containsNoDependence(CharMatrix &DepMatrix, unsigned Row,
+                                 unsigned Column) {
+  for (unsigned i = 0; i < Column; ++i) {
+    if (DepMatrix[Row][i] != '=' && DepMatrix[Row][i] != 'S' &&
+        DepMatrix[Row][i] != 'I')
+      return false;
+  }
+  return true;
+}
+
+static bool validDepInterchange(CharMatrix &DepMatrix, unsigned Row,
+                                unsigned OuterLoopId, char InnerDep,
+                                char OuterDep) {
+  if (isOuterMostDepPositive(DepMatrix, Row, OuterLoopId))
+    return false;
+
+  if (InnerDep == OuterDep)
+    return true;
+
+  // It is legal to interchange if and only if after interchange no row has a
+  // '>' direction as the leftmost non-'='.
+
+  if (InnerDep == '=' || InnerDep == 'S' || InnerDep == 'I')
+    return true;
+
+  if (InnerDep == '<')
+    return true;
+
+  if (InnerDep == '>') {
+    // If OuterLoopId represents the outermost loop, then interchanging will
+    // make the first dependency '>'.
+    if (OuterLoopId == 0)
+      return false;
+
+    // If all dependencies before OuterLoopId were '=', 'S' or 'I', then
+    // interchanging would give this row an outermost non-'=' dependency
+    // of '>'.
+    if (!containsNoDependence(DepMatrix, Row, OuterLoopId))
+      return true;
+  }
+
+  return false;
+}
+
+// Checks if it is legal to interchange 2 loops.
+// [Theorem] A permutation of the loops in a perfect nest is legal if and only +// if the direction matrix, after the same permutation is applied to its +// columns, has no ">" direction as the leftmost non-"=" direction in any row. +static bool isLegalToInterChangeLoops(CharMatrix &DepMatrix, + unsigned InnerLoopId, + unsigned OuterLoopId) { + unsigned NumRows = DepMatrix.size(); + // For each row check if it is valid to interchange. + for (unsigned Row = 0; Row < NumRows; ++Row) { + char InnerDep = DepMatrix[Row][InnerLoopId]; + char OuterDep = DepMatrix[Row][OuterLoopId]; + if (InnerDep == '*' || OuterDep == '*') + return false; + if (!validDepInterchange(DepMatrix, Row, OuterLoopId, InnerDep, OuterDep)) + return false; + } + return true; +} + +static void populateWorklist(Loop &L, SmallVector<LoopVector, 8> &V) { + DEBUG(dbgs() << "Calling populateWorklist on Func: " + << L.getHeader()->getParent()->getName() << " Loop: %" + << L.getHeader()->getName() << '\n'); + LoopVector LoopList; + Loop *CurrentLoop = &L; + const std::vector<Loop *> *Vec = &CurrentLoop->getSubLoops(); + while (!Vec->empty()) { + // The current loop has multiple subloops in it hence it is not tightly + // nested. + // Discard all loops above it added into Worklist. + if (Vec->size() != 1) { + LoopList.clear(); + return; + } + LoopList.push_back(CurrentLoop); + CurrentLoop = Vec->front(); + Vec = &CurrentLoop->getSubLoops(); + } + LoopList.push_back(CurrentLoop); + V.push_back(std::move(LoopList)); +} + +static PHINode *getInductionVariable(Loop *L, ScalarEvolution *SE) { + PHINode *InnerIndexVar = L->getCanonicalInductionVariable(); + if (InnerIndexVar) + return InnerIndexVar; + if (L->getLoopLatch() == nullptr || L->getLoopPredecessor() == nullptr) + return nullptr; + for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) { + PHINode *PhiVar = cast<PHINode>(I); + Type *PhiTy = PhiVar->getType(); + if (!PhiTy->isIntegerTy() && !PhiTy->isFloatingPointTy() && + !PhiTy->isPointerTy()) + return nullptr; + const SCEVAddRecExpr *AddRec = + dyn_cast<SCEVAddRecExpr>(SE->getSCEV(PhiVar)); + if (!AddRec || !AddRec->isAffine()) + continue; + const SCEV *Step = AddRec->getStepRecurrence(*SE); + if (!isa<SCEVConstant>(Step)) + continue; + // Found the induction variable. + // FIXME: Handle loops with more than one induction variable. Note that, + // currently, legality makes sure we have only one induction variable. + return PhiVar; + } + return nullptr; +} + +namespace { + +/// LoopInterchangeLegality checks if it is legal to interchange the loop. +class LoopInterchangeLegality { +public: + LoopInterchangeLegality(Loop *Outer, Loop *Inner, ScalarEvolution *SE, + LoopInfo *LI, DominatorTree *DT, bool PreserveLCSSA, + OptimizationRemarkEmitter *ORE) + : OuterLoop(Outer), InnerLoop(Inner), SE(SE), LI(LI), DT(DT), + PreserveLCSSA(PreserveLCSSA), ORE(ORE) {} + + /// Check if the loops can be interchanged. + bool canInterchangeLoops(unsigned InnerLoopId, unsigned OuterLoopId, + CharMatrix &DepMatrix); + + /// Check if the loop structure is understood. We do not handle triangular + /// loops for now. 
+ bool isLoopStructureUnderstood(PHINode *InnerInductionVar); + + bool currentLimitations(); + + bool hasInnerLoopReduction() { return InnerLoopHasReduction; } + +private: + bool tightlyNested(Loop *Outer, Loop *Inner); + bool containsUnsafeInstructionsInHeader(BasicBlock *BB); + bool areAllUsesReductions(Instruction *Ins, Loop *L); + bool containsUnsafeInstructionsInLatch(BasicBlock *BB); + bool findInductionAndReductions(Loop *L, + SmallVector<PHINode *, 8> &Inductions, + SmallVector<PHINode *, 8> &Reductions); + + Loop *OuterLoop; + Loop *InnerLoop; + + ScalarEvolution *SE; + LoopInfo *LI; + DominatorTree *DT; + bool PreserveLCSSA; + + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter *ORE; + + bool InnerLoopHasReduction = false; +}; + +/// LoopInterchangeProfitability checks if it is profitable to interchange the +/// loop. +class LoopInterchangeProfitability { +public: + LoopInterchangeProfitability(Loop *Outer, Loop *Inner, ScalarEvolution *SE, + OptimizationRemarkEmitter *ORE) + : OuterLoop(Outer), InnerLoop(Inner), SE(SE), ORE(ORE) {} + + /// Check if the loop interchange is profitable. + bool isProfitable(unsigned InnerLoopId, unsigned OuterLoopId, + CharMatrix &DepMatrix); + +private: + int getInstrOrderCost(); + + Loop *OuterLoop; + Loop *InnerLoop; + + /// Scev analysis. + ScalarEvolution *SE; + + /// Interface to emit optimization remarks. + OptimizationRemarkEmitter *ORE; +}; + +/// LoopInterchangeTransform interchanges the loop. +class LoopInterchangeTransform { +public: + LoopInterchangeTransform(Loop *Outer, Loop *Inner, ScalarEvolution *SE, + LoopInfo *LI, DominatorTree *DT, + BasicBlock *LoopNestExit, + bool InnerLoopContainsReductions) + : OuterLoop(Outer), InnerLoop(Inner), SE(SE), LI(LI), DT(DT), + LoopExit(LoopNestExit), + InnerLoopHasReduction(InnerLoopContainsReductions) {} + + /// Interchange OuterLoop and InnerLoop. + bool transform(); + void restructureLoops(Loop *InnerLoop, Loop *OuterLoop); + void removeChildLoop(Loop *OuterLoop, Loop *InnerLoop); + +private: + void splitInnerLoopLatch(Instruction *); + void splitInnerLoopHeader(); + bool adjustLoopLinks(); + void adjustLoopPreheaders(); + bool adjustLoopBranches(); + void updateIncomingBlock(BasicBlock *CurrBlock, BasicBlock *OldPred, + BasicBlock *NewPred); + + Loop *OuterLoop; + Loop *InnerLoop; + + /// Scev analysis. + ScalarEvolution *SE; + + LoopInfo *LI; + DominatorTree *DT; + BasicBlock *LoopExit; + bool InnerLoopHasReduction; +}; + +// Main LoopInterchange Pass. +struct LoopInterchange : public FunctionPass { + static char ID; + ScalarEvolution *SE = nullptr; + LoopInfo *LI = nullptr; + DependenceInfo *DI = nullptr; + DominatorTree *DT = nullptr; + bool PreserveLCSSA; + + /// Interface to emit optimization remarks. 
+  OptimizationRemarkEmitter *ORE;
+
+  LoopInterchange() : FunctionPass(ID) {
+    initializeLoopInterchangePass(*PassRegistry::getPassRegistry());
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<ScalarEvolutionWrapperPass>();
+    AU.addRequired<AAResultsWrapperPass>();
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.addRequired<DependenceAnalysisWrapperPass>();
+    AU.addRequiredID(LoopSimplifyID);
+    AU.addRequiredID(LCSSAID);
+    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
+  }
+
+  bool runOnFunction(Function &F) override {
+    if (skipFunction(F))
+      return false;
+
+    SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+    LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+    DI = &getAnalysis<DependenceAnalysisWrapperPass>().getDI();
+    auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
+    DT = DTWP ? &DTWP->getDomTree() : nullptr;
+    ORE = &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE();
+    PreserveLCSSA = mustPreserveAnalysisID(LCSSAID);
+
+    // Build up a worklist of loop pairs to analyze.
+    SmallVector<LoopVector, 8> Worklist;
+
+    for (Loop *L : *LI)
+      populateWorklist(*L, Worklist);
+
+    DEBUG(dbgs() << "Worklist size = " << Worklist.size() << "\n");
+    // Accumulate the result over all nests instead of overwriting it, so a
+    // change to any nest in the function is reported.
+    bool Changed = false;
+    while (!Worklist.empty()) {
+      LoopVector LoopList = Worklist.pop_back_val();
+      Changed |= processLoopList(LoopList, F);
+    }
+    return Changed;
+  }
+
+  bool isComputableLoopNest(LoopVector LoopList) {
+    for (Loop *L : LoopList) {
+      const SCEV *ExitCountOuter = SE->getBackedgeTakenCount(L);
+      if (ExitCountOuter == SE->getCouldNotCompute()) {
+        DEBUG(dbgs() << "Couldn't compute backedge count\n");
+        return false;
+      }
+      if (L->getNumBackEdges() != 1) {
+        DEBUG(dbgs() << "NumBackEdges is not equal to 1\n");
+        return false;
+      }
+      if (!L->getExitingBlock()) {
+        DEBUG(dbgs() << "Loop doesn't have unique exit block\n");
+        return false;
+      }
+    }
+    return true;
+  }
+
+  unsigned selectLoopForInterchange(const LoopVector &LoopList) {
+    // TODO: Add a better heuristic to select the loop to be interchanged based
+    // on the dependence matrix.  Currently we select the innermost loop.
+    return LoopList.size() - 1;
+  }
+
+  bool processLoopList(LoopVector LoopList, Function &F) {
+    bool Changed = false;
+    unsigned LoopNestDepth = LoopList.size();
+    if (LoopNestDepth < 2) {
+      DEBUG(dbgs() << "Loop doesn't contain minimum nesting level.\n");
+      return false;
+    }
+    if (LoopNestDepth > MaxLoopNestDepth) {
+      DEBUG(dbgs() << "Cannot handle loops of depth greater than "
+                   << MaxLoopNestDepth << "\n");
+      return false;
+    }
+    if (!isComputableLoopNest(LoopList)) {
+      DEBUG(dbgs() << "Not valid loop candidate for interchange\n");
+      return false;
+    }
+
+    DEBUG(dbgs() << "Processing LoopList of size = " << LoopNestDepth << "\n");
+
+    CharMatrix DependencyMatrix;
+    Loop *OuterMostLoop = *(LoopList.begin());
+    if (!populateDependencyMatrix(DependencyMatrix, LoopNestDepth,
+                                  OuterMostLoop, DI)) {
+      DEBUG(dbgs() << "Populating dependency matrix failed\n");
+      return false;
+    }
+#ifdef DUMP_DEP_MATRICIES
+    DEBUG(dbgs() << "Dependence before interchange\n");
+    printDepMatrix(DependencyMatrix);
+#endif
+
+    BasicBlock *OuterMostLoopLatch = OuterMostLoop->getLoopLatch();
+    BranchInst *OuterMostLoopLatchBI =
+        dyn_cast<BranchInst>(OuterMostLoopLatch->getTerminator());
+    if (!OuterMostLoopLatchBI)
+      return false;
+
+    // Since we currently do not handle LCSSA PHIs, any failure in the loop
+    // condition will now branch to LoopNestExit.
+    // TODO: This should be removed once we handle LCSSA PHI nodes.
+
+    // Get the Outermost loop exit.
+    BasicBlock *LoopNestExit;
+    if (OuterMostLoopLatchBI->getSuccessor(0) == OuterMostLoop->getHeader())
+      LoopNestExit = OuterMostLoopLatchBI->getSuccessor(1);
+    else
+      LoopNestExit = OuterMostLoopLatchBI->getSuccessor(0);
+
+    if (isa<PHINode>(LoopNestExit->begin())) {
+      DEBUG(dbgs() << "PHI nodes in the loop nest exit are not handled for "
+                      "now, since on failure all loops branch to the loop "
+                      "nest exit.\n");
+      return false;
+    }
+
+    unsigned SelecLoopId = selectLoopForInterchange(LoopList);
+    // Move the selected loop outwards to the best possible position.
+    for (unsigned i = SelecLoopId; i > 0; i--) {
+      bool Interchanged =
+          processLoop(LoopList, i, i - 1, LoopNestExit, DependencyMatrix);
+      if (!Interchanged)
+        return Changed;
+      // The loops were interchanged; reflect that in LoopList.
+      std::swap(LoopList[i - 1], LoopList[i]);
+
+      // Update the DependencyMatrix.
+      interChangeDependencies(DependencyMatrix, i, i - 1);
+      DT->recalculate(F);
+#ifdef DUMP_DEP_MATRICIES
+      DEBUG(dbgs() << "Dependence after interchange\n");
+      printDepMatrix(DependencyMatrix);
+#endif
+      Changed |= Interchanged;
+    }
+    return Changed;
+  }
+
+  bool processLoop(LoopVector LoopList, unsigned InnerLoopId,
+                   unsigned OuterLoopId, BasicBlock *LoopNestExit,
+                   std::vector<std::vector<char>> &DependencyMatrix) {
+    DEBUG(dbgs() << "Processing Inner Loop Id = " << InnerLoopId
+                 << " and OuterLoopId = " << OuterLoopId << "\n");
+    Loop *InnerLoop = LoopList[InnerLoopId];
+    Loop *OuterLoop = LoopList[OuterLoopId];
+
+    LoopInterchangeLegality LIL(OuterLoop, InnerLoop, SE, LI, DT,
+                                PreserveLCSSA, ORE);
+    if (!LIL.canInterchangeLoops(InnerLoopId, OuterLoopId, DependencyMatrix)) {
+      DEBUG(dbgs() << "Not interchanging loops.  Cannot prove legality\n");
+      return false;
+    }
+    DEBUG(dbgs() << "Loops are legal to interchange\n");
+    LoopInterchangeProfitability LIP(OuterLoop, InnerLoop, SE, ORE);
+    if (!LIP.isProfitable(InnerLoopId, OuterLoopId, DependencyMatrix)) {
+      DEBUG(dbgs() << "Interchanging loops not profitable\n");
+      return false;
+    }
+
+    ORE->emit([&]() {
+      return OptimizationRemark(DEBUG_TYPE, "Interchanged",
+                                InnerLoop->getStartLoc(),
+                                InnerLoop->getHeader())
+             << "Loop interchanged with enclosing loop.";
+    });
+
+    LoopInterchangeTransform LIT(OuterLoop, InnerLoop, SE, LI, DT,
+                                 LoopNestExit, LIL.hasInnerLoopReduction());
+    LIT.transform();
+    DEBUG(dbgs() << "Loops interchanged\n");
+    return true;
+  }
+};
+
+} // end anonymous namespace
+
+bool LoopInterchangeLegality::areAllUsesReductions(Instruction *Ins, Loop *L) {
+  return llvm::none_of(Ins->users(), [=](User *U) -> bool {
+    auto *UserIns = dyn_cast<PHINode>(U);
+    RecurrenceDescriptor RD;
+    return !UserIns || !RecurrenceDescriptor::isReductionPHI(UserIns, L, RD);
+  });
+}
+
+bool LoopInterchangeLegality::containsUnsafeInstructionsInHeader(
+    BasicBlock *BB) {
+  for (auto I = BB->begin(), E = BB->end(); I != E; ++I) {
+    // Loads corresponding to a reduction PHI are safe when concluding that
+    // the nest is tightly nested.
+    if (LoadInst *L = dyn_cast<LoadInst>(I)) {
+      if (!areAllUsesReductions(L, InnerLoop))
+        return true;
+    } else if (I->mayHaveSideEffects() || I->mayReadFromMemory())
+      return true;
+  }
+  return false;
+}
+
+bool LoopInterchangeLegality::containsUnsafeInstructionsInLatch(
+    BasicBlock *BB) {
+  for (auto I = BB->begin(), E = BB->end(); I != E; ++I) {
+    // Stores corresponding to reductions are safe when concluding that the
+    // nest is tightly nested.
+    if (StoreInst *L = dyn_cast<StoreInst>(I)) {
+      if (!isa<PHINode>(L->getOperand(0)))
+        return true;
+    } else if (I->mayHaveSideEffects() || I->mayReadFromMemory())
+      return true;
+  }
+  return false;
+}
+
+bool LoopInterchangeLegality::tightlyNested(Loop *OuterLoop, Loop *InnerLoop) {
+  BasicBlock *OuterLoopHeader = OuterLoop->getHeader();
+  BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader();
+  BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch();
+
+  DEBUG(dbgs() << "Checking if loops are tightly nested\n");
+
+  // A perfectly nested loop will not have any branch in between the outer and
+  // inner block, i.e. the outer header will branch only to the inner
+  // preheader or the outer loop latch.
+  BranchInst *OuterLoopHeaderBI =
+      dyn_cast<BranchInst>(OuterLoopHeader->getTerminator());
+  if (!OuterLoopHeaderBI)
+    return false;
+
+  for (BasicBlock *Succ : OuterLoopHeaderBI->successors())
+    if (Succ != InnerLoopPreHeader && Succ != OuterLoopLatch)
+      return false;
+
+  DEBUG(dbgs() << "Checking instructions in Loop header and Loop latch\n");
+  // There is no basic block in between; now make sure the outer header and
+  // the outer loop latch don't contain any unsafe instructions.
+  if (containsUnsafeInstructionsInHeader(OuterLoopHeader) ||
+      containsUnsafeInstructionsInLatch(OuterLoopLatch))
+    return false;
+
+  DEBUG(dbgs() << "Loops are perfectly nested\n");
+  // We have a perfect loop nest.
+  return true;
+}
+
+bool LoopInterchangeLegality::isLoopStructureUnderstood(
+    PHINode *InnerInduction) {
+  unsigned Num = InnerInduction->getNumOperands();
+  BasicBlock *InnerLoopPreheader = InnerLoop->getLoopPreheader();
+  for (unsigned i = 0; i < Num; ++i) {
+    Value *Val = InnerInduction->getOperand(i);
+    if (isa<Constant>(Val))
+      continue;
+    Instruction *I = dyn_cast<Instruction>(Val);
+    if (!I)
+      return false;
+    // TODO: Handle triangular loops.
+    // e.g. for(int i=0;i<N;i++)
+    //        for(int j=i;j<N;j++)
+    unsigned IncomBlockIndx = PHINode::getIncomingValueNumForOperand(i);
+    if (InnerInduction->getIncomingBlock(IncomBlockIndx) ==
+            InnerLoopPreheader &&
+        !OuterLoop->isLoopInvariant(I)) {
+      return false;
+    }
+  }
+  return true;
+}
+
+bool LoopInterchangeLegality::findInductionAndReductions(
+    Loop *L, SmallVector<PHINode *, 8> &Inductions,
+    SmallVector<PHINode *, 8> &Reductions) {
+  if (!L->getLoopLatch() || !L->getLoopPredecessor())
+    return false;
+  for (BasicBlock::iterator I = L->getHeader()->begin(); isa<PHINode>(I); ++I) {
+    RecurrenceDescriptor RD;
+    InductionDescriptor ID;
+    PHINode *PHI = cast<PHINode>(I);
+    if (InductionDescriptor::isInductionPHI(PHI, L, SE, ID))
+      Inductions.push_back(PHI);
+    else if (RecurrenceDescriptor::isReductionPHI(PHI, L, RD))
+      Reductions.push_back(PHI);
+    else {
+      DEBUG(
+          dbgs() << "Failed to recognize PHI as an induction or reduction.\n");
+      return false;
+    }
+  }
+  return true;
+}
+
+static bool containsSafePHI(BasicBlock *Block, bool isOuterLoopExitBlock) {
+  for (auto I = Block->begin(); isa<PHINode>(I); ++I) {
+    PHINode *PHI = cast<PHINode>(I);
+    // A reduction LCSSA PHI will have only one incoming block, the loop latch.
+    if (PHI->getNumIncomingValues() > 1)
+      return false;
+    Instruction *Ins = dyn_cast<Instruction>(PHI->getIncomingValue(0));
+    if (!Ins)
+      return false;
+    // The incoming value for an LCSSA PHI in the outer loop exit can only be
+    // the inner loop exit's LCSSA PHI; otherwise the nest would not be
+    // tightly nested.
+    if (!isa<PHINode>(Ins) && isOuterLoopExitBlock)
+      return false;
+  }
+  return true;
+}
+
+static BasicBlock *getLoopLatchExitBlock(BasicBlock *LatchBlock,
+                                         BasicBlock *LoopHeader) {
+  if (BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator())) {
+    assert(BI->getNumSuccessors() == 2 &&
+           "Branch leaving loop latch must have 2 successors");
+    for (BasicBlock *Succ : BI->successors()) {
+      if (Succ == LoopHeader)
+        continue;
+      return Succ;
+    }
+  }
+  return nullptr;
+}
+
+// This function indicates the current limitations in the transform as a result
+// of which we do not proceed.
+bool LoopInterchangeLegality::currentLimitations() {
+  BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader();
+  BasicBlock *InnerLoopHeader = InnerLoop->getHeader();
+  BasicBlock *InnerLoopLatch = InnerLoop->getLoopLatch();
+  BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch();
+  BasicBlock *OuterLoopHeader = OuterLoop->getHeader();
+
+  PHINode *InnerInductionVar;
+  SmallVector<PHINode *, 8> Inductions;
+  SmallVector<PHINode *, 8> Reductions;
+  if (!findInductionAndReductions(InnerLoop, Inductions, Reductions)) {
+    DEBUG(dbgs() << "Only inner loops with induction or reduction PHI nodes "
+                 << "are supported currently.\n");
+    ORE->emit([&]() {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedPHIInner",
+                                      InnerLoop->getStartLoc(),
+                                      InnerLoop->getHeader())
+             << "Only inner loops with induction or reduction PHI nodes can be"
+                " interchanged currently.";
+    });
+    return true;
+  }
+
+  // TODO: Currently we handle only loops with 1 induction variable.
+  if (Inductions.size() != 1) {
+    DEBUG(dbgs() << "We currently only support loops with 1 induction "
+                 << "variable.  Failed to interchange due to current "
+                 << "limitation\n");
+    ORE->emit([&]() {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "MultiInductionInner",
+                                      InnerLoop->getStartLoc(),
+                                      InnerLoop->getHeader())
+             << "Only inner loops with 1 induction variable can be "
+                "interchanged currently.";
+    });
+    return true;
+  }
+  if (Reductions.size() > 0)
+    InnerLoopHasReduction = true;
+
+  InnerInductionVar = Inductions.pop_back_val();
+  Reductions.clear();
+  if (!findInductionAndReductions(OuterLoop, Inductions, Reductions)) {
+    DEBUG(dbgs() << "Only outer loops with induction or reduction PHI nodes "
                 << "are supported currently.\n");
+    ORE->emit([&]() {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedPHIOuter",
+                                      OuterLoop->getStartLoc(),
+                                      OuterLoop->getHeader())
+             << "Only outer loops with induction or reduction PHI nodes can be"
+                " interchanged currently.";
+    });
+    return true;
+  }
+
+  // The outer loop cannot have a reduction, because then the loops would not
+  // be tightly nested.
+  if (!Reductions.empty()) {
+    DEBUG(dbgs() << "Outer loops with reductions are not supported "
+                 << "currently.\n");
+    ORE->emit([&]() {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "ReductionsOuter",
+                                      OuterLoop->getStartLoc(),
+                                      OuterLoop->getHeader())
+             << "Outer loops with reductions cannot be interchanged "
+                "currently.";
+    });
+    return true;
+  }
+  // TODO: Currently we handle only loops with 1 induction variable.
+  if (Inductions.size() != 1) {
+    DEBUG(dbgs() << "Loops with more than 1 induction variable are not "
+                 << "supported currently.\n");
+    ORE->emit([&]() {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "MultiIndutionOuter",
+                                      OuterLoop->getStartLoc(),
+                                      OuterLoop->getHeader())
+             << "Only outer loops with 1 induction variable can be "
+                "interchanged currently.";
+    });
+    return true;
+  }
+
+  // TODO: Triangular loops are not handled for now.
+  if (!isLoopStructureUnderstood(InnerInductionVar)) {
+    DEBUG(dbgs() << "Loop structure not understood by pass\n");
+    ORE->emit([&]() {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "UnsupportedStructureInner",
+                                      InnerLoop->getStartLoc(),
+                                      InnerLoop->getHeader())
+             << "Inner loop structure not understood currently.";
+    });
+    return true;
+  }
+
+  // TODO: We only handle LCSSA PHIs corresponding to reductions for now.
+  BasicBlock *LoopExitBlock =
+      getLoopLatchExitBlock(OuterLoopLatch, OuterLoopHeader);
+  if (!LoopExitBlock || !containsSafePHI(LoopExitBlock, true)) {
+    DEBUG(dbgs() << "Can only handle LCSSA PHIs in outer loops currently.\n");
+    ORE->emit([&]() {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "NoLCSSAPHIOuter",
+                                      OuterLoop->getStartLoc(),
+                                      OuterLoop->getHeader())
+             << "Only outer loops with LCSSA PHIs can be interchanged "
+                "currently.";
+    });
+    return true;
+  }
+
+  LoopExitBlock = getLoopLatchExitBlock(InnerLoopLatch, InnerLoopHeader);
+  if (!LoopExitBlock || !containsSafePHI(LoopExitBlock, false)) {
+    DEBUG(dbgs() << "Can only handle LCSSA PHIs in inner loops currently.\n");
+    ORE->emit([&]() {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "NoLCSSAPHIOuterInner",
+                                      InnerLoop->getStartLoc(),
+                                      InnerLoop->getHeader())
+             << "Only inner loops with LCSSA PHIs can be interchanged "
+                "currently.";
+    });
+    return true;
+  }
+
+  // TODO: Current limitation: Since we split the inner loop latch at the point
+  // where the induction variable is incremented (induction.next), we cannot
+  // have more than 1 user of induction.next, since that would result in broken
+  // code after the split.
+  // e.g.
+  // for(i=0;i<N;i++) {
+  //   for(j = 0;j<M;j++) {
+  //     A[j+1][i+2] = A[j][i]+k;
+  //   }
+  // }
+  Instruction *InnerIndexVarInc = nullptr;
+  if (InnerInductionVar->getIncomingBlock(0) == InnerLoopPreHeader)
+    InnerIndexVarInc =
+        dyn_cast<Instruction>(InnerInductionVar->getIncomingValue(1));
+  else
+    InnerIndexVarInc =
+        dyn_cast<Instruction>(InnerInductionVar->getIncomingValue(0));
+
+  if (!InnerIndexVarInc) {
+    DEBUG(dbgs() << "Did not find an instruction to increment the induction "
+                 << "variable.\n");
+    ORE->emit([&]() {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "NoIncrementInInner",
+                                      InnerLoop->getStartLoc(),
+                                      InnerLoop->getHeader())
+             << "The inner loop does not increment the induction variable.";
+    });
+    return true;
+  }
+
+  // Since we split the inner loop latch on this induction variable, make sure
+  // there is no instruction between the induction variable increment and the
+  // branch instruction.
+
+  bool FoundInduction = false;
+  for (const Instruction &I : llvm::reverse(*InnerLoopLatch)) {
+    if (isa<BranchInst>(I) || isa<CmpInst>(I) || isa<TruncInst>(I) ||
+        isa<ZExtInst>(I))
+      continue;
+
+    // We found an instruction. If it is not the induction variable increment,
+    // it is not safe to split this loop latch.
+    if (!I.isIdenticalTo(InnerIndexVarInc)) {
+      DEBUG(dbgs() << "Found unsupported instructions between induction "
+                   << "variable increment and branch.\n");
+      ORE->emit([&]() {
+        return OptimizationRemarkMissed(
+                   DEBUG_TYPE, "UnsupportedInsBetweenInduction",
+                   InnerLoop->getStartLoc(), InnerLoop->getHeader())
+               << "Found unsupported instruction between induction variable "
+                  "increment and branch.";
+      });
+      return true;
+    }
+
+    FoundInduction = true;
+    break;
+  }
+  // We reached the start of the loop latch without finding the induction
+  // variable; report this as a current limitation.
+  if (!FoundInduction) {
+    DEBUG(dbgs() << "Did not find the induction variable.\n");
+    ORE->emit([&]() {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "NoInductionVariable",
+                                      InnerLoop->getStartLoc(),
+                                      InnerLoop->getHeader())
+             << "Did not find the induction variable.";
+    });
+    return true;
+  }
+  return false;
+}
+
+bool LoopInterchangeLegality::canInterchangeLoops(unsigned InnerLoopId,
+                                                  unsigned OuterLoopId,
+                                                  CharMatrix &DepMatrix) {
+  if (!isLegalToInterChangeLoops(DepMatrix, InnerLoopId, OuterLoopId)) {
+    DEBUG(dbgs() << "Failed to interchange InnerLoopId = " << InnerLoopId
+                 << " and OuterLoopId = " << OuterLoopId
+                 << " due to dependence\n");
+    ORE->emit([&]() {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "Dependence",
+                                      InnerLoop->getStartLoc(),
+                                      InnerLoop->getHeader())
+             << "Cannot interchange loops due to dependences.";
+    });
+    return false;
+  }
+
+  // Check that the outer and inner loops contain only legal instructions.
+  for (auto *BB : OuterLoop->blocks())
+    for (Instruction &I : *BB)
+      if (CallInst *CI = dyn_cast<CallInst>(&I)) {
+        // readnone functions do not prevent interchanging.
+        if (CI->doesNotReadMemory())
+          continue;
+        DEBUG(dbgs() << "Loops with call instructions cannot be interchanged "
+                     << "safely.\n");
+        return false;
+      }
+
+  // Create unique preheaders if we do not already have them.
+  BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader();
+  BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader();
+
+  // Create a unique outer preheader:
+  // 1) If the OuterLoop preheader is not present.
+  // 2) If the OuterLoop preheader is the same as the OuterLoop header.
+  // 3) If the OuterLoop preheader is the same as the header of the previous
+  //    loop.
+  // 4) If the OuterLoop preheader is the entry node.
+  if (!OuterLoopPreHeader || OuterLoopPreHeader == OuterLoop->getHeader() ||
+      isa<PHINode>(OuterLoopPreHeader->begin()) ||
+      !OuterLoopPreHeader->getUniquePredecessor()) {
+    OuterLoopPreHeader =
+        InsertPreheaderForLoop(OuterLoop, DT, LI, PreserveLCSSA);
+  }
+
+  if (!InnerLoopPreHeader || InnerLoopPreHeader == InnerLoop->getHeader() ||
+      InnerLoopPreHeader == OuterLoop->getHeader()) {
+    InnerLoopPreHeader =
+        InsertPreheaderForLoop(InnerLoop, DT, LI, PreserveLCSSA);
+  }
+
+  // Bail out if the loops cannot be interchanged due to current limitations
+  // in the transform module.
+  if (currentLimitations()) {
+    DEBUG(dbgs() << "Not legal because of current transform limitation\n");
+    return false;
+  }
+
+  // Check if the loops are tightly nested.
+  if (!tightlyNested(OuterLoop, InnerLoop)) {
+    DEBUG(dbgs() << "Loops not tightly nested\n");
+    ORE->emit([&]() {
+      return OptimizationRemarkMissed(DEBUG_TYPE, "NotTightlyNested",
+                                      InnerLoop->getStartLoc(),
+                                      InnerLoop->getHeader())
+             << "Cannot interchange loops because they are not tightly "
+                "nested.";
+    });
+    return false;
+  }
+
+  return true;
+}
+
+int LoopInterchangeProfitability::getInstrOrderCost() {
+  unsigned GoodOrder, BadOrder;
+  BadOrder = GoodOrder = 0;
+  for (BasicBlock *BB : InnerLoop->blocks()) {
+    for (Instruction &Ins : *BB) {
+      if (const GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(&Ins)) {
+        unsigned NumOp = GEP->getNumOperands();
+        bool FoundInnerInduction = false;
+        bool FoundOuterInduction = false;
+        for (unsigned i = 0; i < NumOp; ++i) {
+          const SCEV *OperandVal = SE->getSCEV(GEP->getOperand(i));
+          const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(OperandVal);
+          if (!AR)
+            continue;
+
+          // If we find the inner induction after an outer induction e.g.
+ // for(int i=0;i<N;i++) + // for(int j=0;j<N;j++) + // A[i][j] = A[i-1][j-1]+k; + // then it is a good order. + if (AR->getLoop() == InnerLoop) { + // We found an InnerLoop induction after OuterLoop induction. It is + // a good order. + FoundInnerInduction = true; + if (FoundOuterInduction) { + GoodOrder++; + break; + } + } + // If we find the outer induction after an inner induction e.g. + // for(int i=0;i<N;i++) + // for(int j=0;j<N;j++) + // A[j][i] = A[j-1][i-1]+k; + // then it is a bad order. + if (AR->getLoop() == OuterLoop) { + // We found an OuterLoop induction after InnerLoop induction. It is + // a bad order. + FoundOuterInduction = true; + if (FoundInnerInduction) { + BadOrder++; + break; + } + } + } + } + } + } + return GoodOrder - BadOrder; +} + +static bool isProfitableForVectorization(unsigned InnerLoopId, + unsigned OuterLoopId, + CharMatrix &DepMatrix) { + // TODO: Improve this heuristic to catch more cases. + // If the inner loop is loop independent or doesn't carry any dependency it is + // profitable to move this to outer position. + for (auto &Row : DepMatrix) { + if (Row[InnerLoopId] != 'S' && Row[InnerLoopId] != 'I') + return false; + // TODO: We need to improve this heuristic. + if (Row[OuterLoopId] != '=') + return false; + } + // If outer loop has dependence and inner loop is loop independent then it is + // profitable to interchange to enable parallelism. + return true; +} + +bool LoopInterchangeProfitability::isProfitable(unsigned InnerLoopId, + unsigned OuterLoopId, + CharMatrix &DepMatrix) { + // TODO: Add better profitability checks. + // e.g + // 1) Construct dependency matrix and move the one with no loop carried dep + // inside to enable vectorization. + + // This is rough cost estimation algorithm. It counts the good and bad order + // of induction variables in the instruction and allows reordering if number + // of bad orders is more than good. + int Cost = getInstrOrderCost(); + DEBUG(dbgs() << "Cost = " << Cost << "\n"); + if (Cost < -LoopInterchangeCostThreshold) + return true; + + // It is not profitable as per current cache profitability model. But check if + // we can move this loop outside to improve parallelism. + if (isProfitableForVectorization(InnerLoopId, OuterLoopId, DepMatrix)) + return true; + + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, "InterchangeNotProfitable", + InnerLoop->getStartLoc(), + InnerLoop->getHeader()) + << "Interchanging loops is too costly (cost=" + << ore::NV("Cost", Cost) << ", threshold=" + << ore::NV("Threshold", LoopInterchangeCostThreshold) + << ") and it does not improve parallelism."; + }); + return false; +} + +void LoopInterchangeTransform::removeChildLoop(Loop *OuterLoop, + Loop *InnerLoop) { + for (Loop::iterator I = OuterLoop->begin(), E = OuterLoop->end(); I != E; + ++I) { + if (*I == InnerLoop) { + OuterLoop->removeChildLoop(I); + return; + } + } + llvm_unreachable("Couldn't find loop"); +} + +void LoopInterchangeTransform::restructureLoops(Loop *InnerLoop, + Loop *OuterLoop) { + Loop *OuterLoopParent = OuterLoop->getParentLoop(); + if (OuterLoopParent) { + // Remove the loop from its parent loop. 
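+    // A sketch of the intended LoopInfo shape (illustrative; "Parent" is
+    // OuterLoopParent):
+    //   before: Parent -> Outer -> Inner -> {Inner's children}
+    //   after:  Parent -> Inner -> Outer -> {Inner's children}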
+ removeChildLoop(OuterLoopParent, OuterLoop); + removeChildLoop(OuterLoop, InnerLoop); + OuterLoopParent->addChildLoop(InnerLoop); + } else { + removeChildLoop(OuterLoop, InnerLoop); + LI->changeTopLevelLoop(OuterLoop, InnerLoop); + } + + while (!InnerLoop->empty()) + OuterLoop->addChildLoop(InnerLoop->removeChildLoop(InnerLoop->begin())); + + InnerLoop->addChildLoop(OuterLoop); +} + +bool LoopInterchangeTransform::transform() { + bool Transformed = false; + Instruction *InnerIndexVar; + + if (InnerLoop->getSubLoops().empty()) { + BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); + DEBUG(dbgs() << "Calling Split Inner Loop\n"); + PHINode *InductionPHI = getInductionVariable(InnerLoop, SE); + if (!InductionPHI) { + DEBUG(dbgs() << "Failed to find the point to split loop latch \n"); + return false; + } + + if (InductionPHI->getIncomingBlock(0) == InnerLoopPreHeader) + InnerIndexVar = dyn_cast<Instruction>(InductionPHI->getIncomingValue(1)); + else + InnerIndexVar = dyn_cast<Instruction>(InductionPHI->getIncomingValue(0)); + + // Ensure that InductionPHI is the first Phi node as required by + // splitInnerLoopHeader + if (&InductionPHI->getParent()->front() != InductionPHI) + InductionPHI->moveBefore(&InductionPHI->getParent()->front()); + + // Split at the place were the induction variable is + // incremented/decremented. + // TODO: This splitting logic may not work always. Fix this. + splitInnerLoopLatch(InnerIndexVar); + DEBUG(dbgs() << "splitInnerLoopLatch done\n"); + + // Splits the inner loops phi nodes out into a separate basic block. + splitInnerLoopHeader(); + DEBUG(dbgs() << "splitInnerLoopHeader done\n"); + } + + Transformed |= adjustLoopLinks(); + if (!Transformed) { + DEBUG(dbgs() << "adjustLoopLinks failed\n"); + return false; + } + + restructureLoops(InnerLoop, OuterLoop); + return true; +} + +void LoopInterchangeTransform::splitInnerLoopLatch(Instruction *Inc) { + BasicBlock *InnerLoopLatch = InnerLoop->getLoopLatch(); + BasicBlock *InnerLoopLatchPred = InnerLoopLatch; + InnerLoopLatch = SplitBlock(InnerLoopLatchPred, Inc, DT, LI); +} + +void LoopInterchangeTransform::splitInnerLoopHeader() { + // Split the inner loop header out. Here make sure that the reduction PHI's + // stay in the innerloop body. + BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); + BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); + if (InnerLoopHasReduction) { + // Note: The induction PHI must be the first PHI for this to work + BasicBlock *New = InnerLoopHeader->splitBasicBlock( + ++(InnerLoopHeader->begin()), InnerLoopHeader->getName() + ".split"); + if (LI) + if (Loop *L = LI->getLoopFor(InnerLoopHeader)) + L->addBasicBlockToLoop(New, *LI); + + // Adjust Reduction PHI's in the block. 
+ SmallVector<PHINode *, 8> PHIVec; + for (auto I = New->begin(); isa<PHINode>(I); ++I) { + PHINode *PHI = dyn_cast<PHINode>(I); + Value *V = PHI->getIncomingValueForBlock(InnerLoopPreHeader); + PHI->replaceAllUsesWith(V); + PHIVec.push_back((PHI)); + } + for (PHINode *P : PHIVec) { + P->eraseFromParent(); + } + } else { + SplitBlock(InnerLoopHeader, InnerLoopHeader->getFirstNonPHI(), DT, LI); + } + + DEBUG(dbgs() << "Output of splitInnerLoopHeader InnerLoopHeaderSucc & " + "InnerLoopHeader\n"); +} + +/// \brief Move all instructions except the terminator from FromBB right before +/// InsertBefore +static void moveBBContents(BasicBlock *FromBB, Instruction *InsertBefore) { + auto &ToList = InsertBefore->getParent()->getInstList(); + auto &FromList = FromBB->getInstList(); + + ToList.splice(InsertBefore->getIterator(), FromList, FromList.begin(), + FromBB->getTerminator()->getIterator()); +} + +void LoopInterchangeTransform::updateIncomingBlock(BasicBlock *CurrBlock, + BasicBlock *OldPred, + BasicBlock *NewPred) { + for (auto I = CurrBlock->begin(); isa<PHINode>(I); ++I) { + PHINode *PHI = cast<PHINode>(I); + unsigned Num = PHI->getNumIncomingValues(); + for (unsigned i = 0; i < Num; ++i) { + if (PHI->getIncomingBlock(i) == OldPred) + PHI->setIncomingBlock(i, NewPred); + } + } +} + +bool LoopInterchangeTransform::adjustLoopBranches() { + DEBUG(dbgs() << "adjustLoopBranches called\n"); + // Adjust the loop preheader + BasicBlock *InnerLoopHeader = InnerLoop->getHeader(); + BasicBlock *OuterLoopHeader = OuterLoop->getHeader(); + BasicBlock *InnerLoopLatch = InnerLoop->getLoopLatch(); + BasicBlock *OuterLoopLatch = OuterLoop->getLoopLatch(); + BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader(); + BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); + BasicBlock *OuterLoopPredecessor = OuterLoopPreHeader->getUniquePredecessor(); + BasicBlock *InnerLoopLatchPredecessor = + InnerLoopLatch->getUniquePredecessor(); + BasicBlock *InnerLoopLatchSuccessor; + BasicBlock *OuterLoopLatchSuccessor; + + BranchInst *OuterLoopLatchBI = + dyn_cast<BranchInst>(OuterLoopLatch->getTerminator()); + BranchInst *InnerLoopLatchBI = + dyn_cast<BranchInst>(InnerLoopLatch->getTerminator()); + BranchInst *OuterLoopHeaderBI = + dyn_cast<BranchInst>(OuterLoopHeader->getTerminator()); + BranchInst *InnerLoopHeaderBI = + dyn_cast<BranchInst>(InnerLoopHeader->getTerminator()); + + if (!OuterLoopPredecessor || !InnerLoopLatchPredecessor || + !OuterLoopLatchBI || !InnerLoopLatchBI || !OuterLoopHeaderBI || + !InnerLoopHeaderBI) + return false; + + BranchInst *InnerLoopLatchPredecessorBI = + dyn_cast<BranchInst>(InnerLoopLatchPredecessor->getTerminator()); + BranchInst *OuterLoopPredecessorBI = + dyn_cast<BranchInst>(OuterLoopPredecessor->getTerminator()); + + if (!OuterLoopPredecessorBI || !InnerLoopLatchPredecessorBI) + return false; + BasicBlock *InnerLoopHeaderSuccessor = InnerLoopHeader->getUniqueSuccessor(); + if (!InnerLoopHeaderSuccessor) + return false; + + // Adjust Loop Preheader and headers + + unsigned NumSucc = OuterLoopPredecessorBI->getNumSuccessors(); + for (unsigned i = 0; i < NumSucc; ++i) { + if (OuterLoopPredecessorBI->getSuccessor(i) == OuterLoopPreHeader) + OuterLoopPredecessorBI->setSuccessor(i, InnerLoopPreHeader); + } + + NumSucc = OuterLoopHeaderBI->getNumSuccessors(); + for (unsigned i = 0; i < NumSucc; ++i) { + if (OuterLoopHeaderBI->getSuccessor(i) == OuterLoopLatch) + OuterLoopHeaderBI->setSuccessor(i, LoopExit); + else if (OuterLoopHeaderBI->getSuccessor(i) == 
InnerLoopPreHeader) + OuterLoopHeaderBI->setSuccessor(i, InnerLoopHeaderSuccessor); + } + + // Adjust reduction PHI's now that the incoming block has changed. + updateIncomingBlock(InnerLoopHeaderSuccessor, InnerLoopHeader, + OuterLoopHeader); + + BranchInst::Create(OuterLoopPreHeader, InnerLoopHeaderBI); + InnerLoopHeaderBI->eraseFromParent(); + + // -------------Adjust loop latches----------- + if (InnerLoopLatchBI->getSuccessor(0) == InnerLoopHeader) + InnerLoopLatchSuccessor = InnerLoopLatchBI->getSuccessor(1); + else + InnerLoopLatchSuccessor = InnerLoopLatchBI->getSuccessor(0); + + NumSucc = InnerLoopLatchPredecessorBI->getNumSuccessors(); + for (unsigned i = 0; i < NumSucc; ++i) { + if (InnerLoopLatchPredecessorBI->getSuccessor(i) == InnerLoopLatch) + InnerLoopLatchPredecessorBI->setSuccessor(i, InnerLoopLatchSuccessor); + } + + // Adjust PHI nodes in InnerLoopLatchSuccessor. Update all uses of PHI with + // the value and remove this PHI node from inner loop. + SmallVector<PHINode *, 8> LcssaVec; + for (auto I = InnerLoopLatchSuccessor->begin(); isa<PHINode>(I); ++I) { + PHINode *LcssaPhi = cast<PHINode>(I); + LcssaVec.push_back(LcssaPhi); + } + for (PHINode *P : LcssaVec) { + Value *Incoming = P->getIncomingValueForBlock(InnerLoopLatch); + P->replaceAllUsesWith(Incoming); + P->eraseFromParent(); + } + + if (OuterLoopLatchBI->getSuccessor(0) == OuterLoopHeader) + OuterLoopLatchSuccessor = OuterLoopLatchBI->getSuccessor(1); + else + OuterLoopLatchSuccessor = OuterLoopLatchBI->getSuccessor(0); + + if (InnerLoopLatchBI->getSuccessor(1) == InnerLoopLatchSuccessor) + InnerLoopLatchBI->setSuccessor(1, OuterLoopLatchSuccessor); + else + InnerLoopLatchBI->setSuccessor(0, OuterLoopLatchSuccessor); + + updateIncomingBlock(OuterLoopLatchSuccessor, OuterLoopLatch, InnerLoopLatch); + + if (OuterLoopLatchBI->getSuccessor(0) == OuterLoopLatchSuccessor) { + OuterLoopLatchBI->setSuccessor(0, InnerLoopLatch); + } else { + OuterLoopLatchBI->setSuccessor(1, InnerLoopLatch); + } + + return true; +} + +void LoopInterchangeTransform::adjustLoopPreheaders() { + // We have interchanged the preheaders so we need to interchange the data in + // the preheader as well. + // This is because the content of inner preheader was previously executed + // inside the outer loop. + BasicBlock *OuterLoopPreHeader = OuterLoop->getLoopPreheader(); + BasicBlock *InnerLoopPreHeader = InnerLoop->getLoopPreheader(); + BasicBlock *OuterLoopHeader = OuterLoop->getHeader(); + BranchInst *InnerTermBI = + cast<BranchInst>(InnerLoopPreHeader->getTerminator()); + + // These instructions should now be executed inside the loop. + // Move instruction into a new block after outer header. + moveBBContents(InnerLoopPreHeader, OuterLoopHeader->getTerminator()); + // These instructions were not executed previously in the loop so move them to + // the older inner loop preheader. + moveBBContents(OuterLoopPreHeader, InnerTermBI); +} + +bool LoopInterchangeTransform::adjustLoopLinks() { + // Adjust all branches in the inner and outer loop. 
+  bool Changed = adjustLoopBranches();
+  if (Changed)
+    adjustLoopPreheaders();
+  return Changed;
+}
+
+char LoopInterchange::ID = 0;
+
+INITIALIZE_PASS_BEGIN(LoopInterchange, "loop-interchange",
+                      "Interchanges loops for cache reuse", false, false)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DependenceAnalysisWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
+INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
+
+INITIALIZE_PASS_END(LoopInterchange, "loop-interchange",
+                    "Interchanges loops for cache reuse", false, false)
+
+Pass *llvm::createLoopInterchangePass() { return new LoopInterchange(); }
diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
new file mode 100644
index 000000000000..dfa5ec1f354d
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/LoopLoadElimination.cpp
@@ -0,0 +1,679 @@
+//===- LoopLoadElimination.cpp - Loop Load Elimination Pass ---------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements a loop-aware load elimination pass.
+//
+// It uses LoopAccessAnalysis to identify loop-carried dependences with a
+// distance of one between stores and loads. These form the candidates for the
+// transformation. The source value of each store is then propagated to the
+// users of the corresponding load. This makes the load dead.
+//
+// The pass can also version the loop and add memchecks in order to prove that
+// may-aliasing stores can't change the value in memory before it's read by the
+// load.
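+//
+// For example (an illustrative sketch; variable names are hypothetical):
+//
+//   for (unsigned i = 0; i < n; i++) {
+//     x = A[i];      // reads the value stored on the previous iteration
+//     A[i + 1] = y;  // store with a loop-carried distance of one
+//   }
+//
+// is rewritten so that a PHI carries the stored value across the backedge,
+// making the load dead.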
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LoopLoadElimination.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/LoopAccessAnalysis.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/LoopVersioning.h"
+#include <algorithm>
+#include <cassert>
+#include <forward_list>
+#include <set>
+#include <tuple>
+#include <utility>
+
+using namespace llvm;
+
+#define LLE_OPTION "loop-load-elim"
+#define DEBUG_TYPE LLE_OPTION
+
+static cl::opt<unsigned> CheckPerElim(
+    "runtime-check-per-loop-load-elim", cl::Hidden,
+    cl::desc("Max number of memchecks allowed per eliminated load on average"),
+    cl::init(1));
+
+static cl::opt<unsigned> LoadElimSCEVCheckThreshold(
+    "loop-load-elimination-scev-check-threshold", cl::init(8), cl::Hidden,
+    cl::desc("The maximum number of SCEV checks allowed for Loop "
+             "Load Elimination"));
+
+STATISTIC(NumLoopLoadEliminted, "Number of loads eliminated by LLE");
+
+namespace {
+
+/// \brief Represent a store-to-load forwarding candidate.
+struct StoreToLoadForwardingCandidate {
+  LoadInst *Load;
+  StoreInst *Store;
+
+  StoreToLoadForwardingCandidate(LoadInst *Load, StoreInst *Store)
+      : Load(Load), Store(Store) {}
+
+  /// \brief Return true if the dependence from the store to the load has a
+  /// distance of one. E.g. A[i+1] = A[i]
+  bool isDependenceDistanceOfOne(PredicatedScalarEvolution &PSE,
+                                 Loop *L) const {
+    Value *LoadPtr = Load->getPointerOperand();
+    Value *StorePtr = Store->getPointerOperand();
+    Type *LoadPtrType = LoadPtr->getType();
+    Type *LoadType = LoadPtrType->getPointerElementType();
+
+    assert(LoadPtrType->getPointerAddressSpace() ==
+               StorePtr->getType()->getPointerAddressSpace() &&
+           LoadType == StorePtr->getType()->getPointerElementType() &&
+           "Should be a known dependence");
+
+    // Currently we only support accesses with unit stride. FIXME: we should
+    // be able to handle non-unit stride as well, as long as the stride is
+    // equal to the dependence distance.
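+    // For a distance-one dependence such as A[i+1] = A[i] over i32 elements,
+    // the SCEV difference StorePtrSCEV - LoadPtrSCEV computed below is 4,
+    // i.e. exactly one element (an illustrative sketch; assumes the unit
+    // strides checked next).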
+ if (getPtrStride(PSE, LoadPtr, L) != 1 || + getPtrStride(PSE, StorePtr, L) != 1) + return false; + + auto &DL = Load->getParent()->getModule()->getDataLayout(); + unsigned TypeByteSize = DL.getTypeAllocSize(const_cast<Type *>(LoadType)); + + auto *LoadPtrSCEV = cast<SCEVAddRecExpr>(PSE.getSCEV(LoadPtr)); + auto *StorePtrSCEV = cast<SCEVAddRecExpr>(PSE.getSCEV(StorePtr)); + + // We don't need to check non-wrapping here because forward/backward + // dependence wouldn't be valid if these weren't monotonic accesses. + auto *Dist = cast<SCEVConstant>( + PSE.getSE()->getMinusSCEV(StorePtrSCEV, LoadPtrSCEV)); + const APInt &Val = Dist->getAPInt(); + return Val == TypeByteSize; + } + + Value *getLoadPtr() const { return Load->getPointerOperand(); } + +#ifndef NDEBUG + friend raw_ostream &operator<<(raw_ostream &OS, + const StoreToLoadForwardingCandidate &Cand) { + OS << *Cand.Store << " -->\n"; + OS.indent(2) << *Cand.Load << "\n"; + return OS; + } +#endif +}; + +} // end anonymous namespace + +/// \brief Check if the store dominates all latches, so as long as there is no +/// intervening store this value will be loaded in the next iteration. +static bool doesStoreDominatesAllLatches(BasicBlock *StoreBlock, Loop *L, + DominatorTree *DT) { + SmallVector<BasicBlock *, 8> Latches; + L->getLoopLatches(Latches); + return llvm::all_of(Latches, [&](const BasicBlock *Latch) { + return DT->dominates(StoreBlock, Latch); + }); +} + +/// \brief Return true if the load is not executed on all paths in the loop. +static bool isLoadConditional(LoadInst *Load, Loop *L) { + return Load->getParent() != L->getHeader(); +} + +namespace { + +/// \brief The per-loop class that does most of the work. +class LoadEliminationForLoop { +public: + LoadEliminationForLoop(Loop *L, LoopInfo *LI, const LoopAccessInfo &LAI, + DominatorTree *DT) + : L(L), LI(LI), LAI(LAI), DT(DT), PSE(LAI.getPSE()) {} + + /// \brief Look through the loop-carried and loop-independent dependences in + /// this loop and find store->load dependences. + /// + /// Note that no candidate is returned if LAA has failed to analyze the loop + /// (e.g. if it's not bottom-tested, contains volatile memops, etc.) + std::forward_list<StoreToLoadForwardingCandidate> + findStoreToLoadDependences(const LoopAccessInfo &LAI) { + std::forward_list<StoreToLoadForwardingCandidate> Candidates; + + const auto *Deps = LAI.getDepChecker().getDependences(); + if (!Deps) + return Candidates; + + // Find store->load dependences (consequently true dep). Both lexically + // forward and backward dependences qualify. Disqualify loads that have + // other unknown dependences. + + SmallSet<Instruction *, 4> LoadsWithUnknownDepedence; + + for (const auto &Dep : *Deps) { + Instruction *Source = Dep.getSource(LAI); + Instruction *Destination = Dep.getDestination(LAI); + + if (Dep.Type == MemoryDepChecker::Dependence::Unknown) { + if (isa<LoadInst>(Source)) + LoadsWithUnknownDepedence.insert(Source); + if (isa<LoadInst>(Destination)) + LoadsWithUnknownDepedence.insert(Destination); + continue; + } + + if (Dep.isBackward()) + // Note that the designations source and destination follow the program + // order, i.e. source is always first. (The direction is given by the + // DepType.) 
+        std::swap(Source, Destination);
+      else
+        assert(Dep.isForward() && "Needs to be a forward dependence");
+
+      auto *Store = dyn_cast<StoreInst>(Source);
+      if (!Store)
+        continue;
+      auto *Load = dyn_cast<LoadInst>(Destination);
+      if (!Load)
+        continue;
+
+      // Only propagate the value if they are of the same type.
+      if (Store->getPointerOperandType() != Load->getPointerOperandType())
+        continue;
+
+      Candidates.emplace_front(Load, Store);
+    }
+
+    if (!LoadsWithUnknownDepedence.empty())
+      Candidates.remove_if([&](const StoreToLoadForwardingCandidate &C) {
+        return LoadsWithUnknownDepedence.count(C.Load);
+      });
+
+    return Candidates;
+  }
+
+  /// \brief Return the index of the instruction according to program order.
+  unsigned getInstrIndex(Instruction *Inst) {
+    auto I = InstOrder.find(Inst);
+    assert(I != InstOrder.end() && "No index for instruction");
+    return I->second;
+  }
+
+  /// \brief If a load has multiple candidates associated (i.e. different
+  /// stores), it means that it could be forwarding from multiple stores
+  /// depending on control flow. Remove these candidates.
+  ///
+  /// Here, we rely on LAA to include the relevant loop-independent dependences.
+  /// LAA is known to omit these in the very simple case when the read and the
+  /// write within an alias set always take place using the *same* pointer.
+  ///
+  /// However, we know that this is not the case here, i.e. we can rely on LAA
+  /// to provide us with loop-independent dependences for the cases we're
+  /// interested in. Consider for example the case where a loop-independent
+  /// dependence S1->S2 invalidates the forwarding S3->S2.
+  ///
+  ///   A[i]   = ...   (S1)
+  ///   ...    = A[i]  (S2)
+  ///   A[i+1] = ...   (S3)
+  ///
+  /// LAA will perform dependence analysis here because there are two
+  /// *different* pointers involved in the same alias set (&A[i] and &A[i+1]).
+  void removeDependencesFromMultipleStores(
+      std::forward_list<StoreToLoadForwardingCandidate> &Candidates) {
+    // If Store is nullptr it means that we have multiple stores forwarding to
+    // this load.
+    using LoadToSingleCandT =
+        DenseMap<LoadInst *, const StoreToLoadForwardingCandidate *>;
+    LoadToSingleCandT LoadToSingleCand;
+
+    for (const auto &Cand : Candidates) {
+      bool NewElt;
+      LoadToSingleCandT::iterator Iter;
+
+      std::tie(Iter, NewElt) =
+          LoadToSingleCand.insert(std::make_pair(Cand.Load, &Cand));
+      if (!NewElt) {
+        const StoreToLoadForwardingCandidate *&OtherCand = Iter->second;
+        // Already multiple stores forward to this load.
+        if (OtherCand == nullptr)
+          continue;
+
+        // Handle the very basic case when the two stores are in the same block
+        // so deciding which one forwards is easy. The later one forwards as
+        // long as they both have a dependence distance of one to the load.
+        if (Cand.Store->getParent() == OtherCand->Store->getParent() &&
+            Cand.isDependenceDistanceOfOne(PSE, L) &&
+            OtherCand->isDependenceDistanceOfOne(PSE, L)) {
+          // They are in the same block, the later one will forward to the load.
+          if (getInstrIndex(OtherCand->Store) < getInstrIndex(Cand.Store))
+            OtherCand = &Cand;
+        } else
+          OtherCand = nullptr;
+      }
+    }
+
+    Candidates.remove_if([&](const StoreToLoadForwardingCandidate &Cand) {
+      if (LoadToSingleCand[Cand.Load] != &Cand) {
+        DEBUG(dbgs() << "Removing from candidates: \n" << Cand
+                     << "  The load may have multiple stores forwarding to "
+                     << "it\n");
+        return true;
+      }
+      return false;
+    });
+  }
+
+  /// \brief Given two pointer operations, identified by their
+  /// RuntimePointerChecking indices, return true if they require an alias
+  /// check.
+ /// + /// We need a check if one is a pointer for a candidate load and the other is + /// a pointer for a possibly intervening store. + bool needsChecking(unsigned PtrIdx1, unsigned PtrIdx2, + const SmallSet<Value *, 4> &PtrsWrittenOnFwdingPath, + const std::set<Value *> &CandLoadPtrs) { + Value *Ptr1 = + LAI.getRuntimePointerChecking()->getPointerInfo(PtrIdx1).PointerValue; + Value *Ptr2 = + LAI.getRuntimePointerChecking()->getPointerInfo(PtrIdx2).PointerValue; + return ((PtrsWrittenOnFwdingPath.count(Ptr1) && CandLoadPtrs.count(Ptr2)) || + (PtrsWrittenOnFwdingPath.count(Ptr2) && CandLoadPtrs.count(Ptr1))); + } + + /// \brief Return pointers that are possibly written to on the path from a + /// forwarding store to a load. + /// + /// These pointers need to be alias-checked against the forwarding candidates. + SmallSet<Value *, 4> findPointersWrittenOnForwardingPath( + const SmallVectorImpl<StoreToLoadForwardingCandidate> &Candidates) { + // From FirstStore to LastLoad neither of the elimination candidate loads + // should overlap with any of the stores. + // + // E.g.: + // + // st1 C[i] + // ld1 B[i] <-------, + // ld0 A[i] <----, | * LastLoad + // ... | | + // st2 E[i] | | + // st3 B[i+1] -- | -' * FirstStore + // st0 A[i+1] ---' + // st4 D[i] + // + // st0 forwards to ld0 if the accesses in st4 and st1 don't overlap with + // ld0. + + LoadInst *LastLoad = + std::max_element(Candidates.begin(), Candidates.end(), + [&](const StoreToLoadForwardingCandidate &A, + const StoreToLoadForwardingCandidate &B) { + return getInstrIndex(A.Load) < getInstrIndex(B.Load); + }) + ->Load; + StoreInst *FirstStore = + std::min_element(Candidates.begin(), Candidates.end(), + [&](const StoreToLoadForwardingCandidate &A, + const StoreToLoadForwardingCandidate &B) { + return getInstrIndex(A.Store) < + getInstrIndex(B.Store); + }) + ->Store; + + // We're looking for stores after the first forwarding store until the end + // of the loop, then from the beginning of the loop until the last + // forwarded-to load. Collect the pointer for the stores. + SmallSet<Value *, 4> PtrsWrittenOnFwdingPath; + + auto InsertStorePtr = [&](Instruction *I) { + if (auto *S = dyn_cast<StoreInst>(I)) + PtrsWrittenOnFwdingPath.insert(S->getPointerOperand()); + }; + const auto &MemInstrs = LAI.getDepChecker().getMemoryInstructions(); + std::for_each(MemInstrs.begin() + getInstrIndex(FirstStore) + 1, + MemInstrs.end(), InsertStorePtr); + std::for_each(MemInstrs.begin(), &MemInstrs[getInstrIndex(LastLoad)], + InsertStorePtr); + + return PtrsWrittenOnFwdingPath; + } + + /// \brief Determine the pointer alias checks to prove that there are no + /// intervening stores. + SmallVector<RuntimePointerChecking::PointerCheck, 4> collectMemchecks( + const SmallVectorImpl<StoreToLoadForwardingCandidate> &Candidates) { + + SmallSet<Value *, 4> PtrsWrittenOnFwdingPath = + findPointersWrittenOnForwardingPath(Candidates); + + // Collect the pointers of the candidate loads. + // FIXME: SmallSet does not work with std::inserter. 
+ std::set<Value *> CandLoadPtrs; + transform(Candidates, + std::inserter(CandLoadPtrs, CandLoadPtrs.begin()), + std::mem_fn(&StoreToLoadForwardingCandidate::getLoadPtr)); + + const auto &AllChecks = LAI.getRuntimePointerChecking()->getChecks(); + SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks; + + copy_if(AllChecks, std::back_inserter(Checks), + [&](const RuntimePointerChecking::PointerCheck &Check) { + for (auto PtrIdx1 : Check.first->Members) + for (auto PtrIdx2 : Check.second->Members) + if (needsChecking(PtrIdx1, PtrIdx2, PtrsWrittenOnFwdingPath, + CandLoadPtrs)) + return true; + return false; + }); + + DEBUG(dbgs() << "\nPointer Checks (count: " << Checks.size() << "):\n"); + DEBUG(LAI.getRuntimePointerChecking()->printChecks(dbgs(), Checks)); + + return Checks; + } + + /// \brief Perform the transformation for a candidate. + void + propagateStoredValueToLoadUsers(const StoreToLoadForwardingCandidate &Cand, + SCEVExpander &SEE) { + // loop: + // %x = load %gep_i + // = ... %x + // store %y, %gep_i_plus_1 + // + // => + // + // ph: + // %x.initial = load %gep_0 + // loop: + // %x.storeforward = phi [%x.initial, %ph] [%y, %loop] + // %x = load %gep_i <---- now dead + // = ... %x.storeforward + // store %y, %gep_i_plus_1 + + Value *Ptr = Cand.Load->getPointerOperand(); + auto *PtrSCEV = cast<SCEVAddRecExpr>(PSE.getSCEV(Ptr)); + auto *PH = L->getLoopPreheader(); + Value *InitialPtr = SEE.expandCodeFor(PtrSCEV->getStart(), Ptr->getType(), + PH->getTerminator()); + Value *Initial = + new LoadInst(InitialPtr, "load_initial", /* isVolatile */ false, + Cand.Load->getAlignment(), PH->getTerminator()); + + PHINode *PHI = PHINode::Create(Initial->getType(), 2, "store_forwarded", + &L->getHeader()->front()); + PHI->addIncoming(Initial, PH); + PHI->addIncoming(Cand.Store->getOperand(0), L->getLoopLatch()); + + Cand.Load->replaceAllUsesWith(PHI); + } + + /// \brief Top-level driver for each loop: find store->load forwarding + /// candidates, add run-time checks and perform transformation. + bool processLoop() { + DEBUG(dbgs() << "\nIn \"" << L->getHeader()->getParent()->getName() + << "\" checking " << *L << "\n"); + + // Look for store-to-load forwarding cases across the + // backedge. E.g.: + // + // loop: + // %x = load %gep_i + // = ... %x + // store %y, %gep_i_plus_1 + // + // => + // + // ph: + // %x.initial = load %gep_0 + // loop: + // %x.storeforward = phi [%x.initial, %ph] [%y, %loop] + // %x = load %gep_i <---- now dead + // = ... %x.storeforward + // store %y, %gep_i_plus_1 + + // First start with store->load dependences. + auto StoreToLoadDependences = findStoreToLoadDependences(LAI); + if (StoreToLoadDependences.empty()) + return false; + + // Generate an index for each load and store according to the original + // program order. This will be used later. + InstOrder = LAI.getDepChecker().generateInstructionOrderMap(); + + // To keep things simple for now, remove those where the load is potentially + // fed by multiple stores. + removeDependencesFromMultipleStores(StoreToLoadDependences); + if (StoreToLoadDependences.empty()) + return false; + + // Filter the candidates further. + SmallVector<StoreToLoadForwardingCandidate, 4> Candidates; + unsigned NumForwarding = 0; + for (const StoreToLoadForwardingCandidate Cand : StoreToLoadDependences) { + DEBUG(dbgs() << "Candidate " << Cand); + + // Make sure that the stored values is available everywhere in the loop in + // the next iteration. 
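+      // E.g. a store inside an `if (cond)` block in the loop body does not
+      // dominate the latch, so on iterations where the store does not execute
+      // the forwarded value would be stale (an illustrative note).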
+ if (!doesStoreDominatesAllLatches(Cand.Store->getParent(), L, DT)) + continue; + + // If the load is conditional we can't hoist its 0-iteration instance to + // the preheader because that would make it unconditional. Thus we would + // access a memory location that the original loop did not access. + if (isLoadConditional(Cand.Load, L)) + continue; + + // Check whether the SCEV difference is the same as the induction step, + // thus we load the value in the next iteration. + if (!Cand.isDependenceDistanceOfOne(PSE, L)) + continue; + + ++NumForwarding; + DEBUG(dbgs() + << NumForwarding + << ". Valid store-to-load forwarding across the loop backedge\n"); + Candidates.push_back(Cand); + } + if (Candidates.empty()) + return false; + + // Check intervening may-alias stores. These need runtime checks for alias + // disambiguation. + SmallVector<RuntimePointerChecking::PointerCheck, 4> Checks = + collectMemchecks(Candidates); + + // Too many checks are likely to outweigh the benefits of forwarding. + if (Checks.size() > Candidates.size() * CheckPerElim) { + DEBUG(dbgs() << "Too many run-time checks needed.\n"); + return false; + } + + if (LAI.getPSE().getUnionPredicate().getComplexity() > + LoadElimSCEVCheckThreshold) { + DEBUG(dbgs() << "Too many SCEV run-time checks needed.\n"); + return false; + } + + if (!Checks.empty() || !LAI.getPSE().getUnionPredicate().isAlwaysTrue()) { + if (L->getHeader()->getParent()->optForSize()) { + DEBUG(dbgs() << "Versioning is needed but not allowed when optimizing " + "for size.\n"); + return false; + } + + if (!L->isLoopSimplifyForm()) { + DEBUG(dbgs() << "Loop is not is loop-simplify form"); + return false; + } + + // Point of no-return, start the transformation. First, version the loop + // if necessary. + + LoopVersioning LV(LAI, L, LI, DT, PSE.getSE(), false); + LV.setAliasChecks(std::move(Checks)); + LV.setSCEVChecks(LAI.getPSE().getUnionPredicate()); + LV.versionLoop(); + } + + // Next, propagate the value stored by the store to the users of the load. + // Also for the first iteration, generate the initial value of the load. + SCEVExpander SEE(*PSE.getSE(), L->getHeader()->getModule()->getDataLayout(), + "storeforward"); + for (const auto &Cand : Candidates) + propagateStoredValueToLoadUsers(Cand, SEE); + NumLoopLoadEliminted += NumForwarding; + + return true; + } + +private: + Loop *L; + + /// \brief Maps the load/store instructions to their index according to + /// program order. + DenseMap<Instruction *, unsigned> InstOrder; + + // Analyses used. + LoopInfo *LI; + const LoopAccessInfo &LAI; + DominatorTree *DT; + PredicatedScalarEvolution PSE; +}; + +} // end anonymous namespace + +static bool +eliminateLoadsAcrossLoops(Function &F, LoopInfo &LI, DominatorTree &DT, + function_ref<const LoopAccessInfo &(Loop &)> GetLAI) { + // Build up a worklist of inner-loops to transform to avoid iterator + // invalidation. + // FIXME: This logic comes from other passes that actually change the loop + // nest structure. It isn't clear this is necessary (or useful) for a pass + // which merely optimizes the use of loads in a loop. + SmallVector<Loop *, 8> Worklist; + + for (Loop *TopLevelLoop : LI) + for (Loop *L : depth_first(TopLevelLoop)) + // We only handle inner-most loops. + if (L->empty()) + Worklist.push_back(L); + + // Now walk the identified inner loops. + bool Changed = false; + for (Loop *L : Worklist) { + // The actual work is performed by LoadEliminationForLoop. 
+ LoadEliminationForLoop LEL(L, &LI, GetLAI(*L), &DT); + Changed |= LEL.processLoop(); + } + return Changed; +} + +namespace { + +/// \brief The pass. Most of the work is delegated to the per-loop +/// LoadEliminationForLoop class. +class LoopLoadElimination : public FunctionPass { +public: + static char ID; + + LoopLoadElimination() : FunctionPass(ID) { + initializeLoopLoadEliminationPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &LAA = getAnalysis<LoopAccessLegacyAnalysis>(); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + + // Process each loop nest in the function. + return eliminateLoadsAcrossLoops( + F, LI, DT, + [&LAA](Loop &L) -> const LoopAccessInfo & { return LAA.getInfo(&L); }); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequiredID(LoopSimplifyID); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + AU.addRequired<LoopAccessLegacyAnalysis>(); + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } +}; + +} // end anonymous namespace + +char LoopLoadElimination::ID; + +static const char LLE_name[] = "Loop Load Elimination"; + +INITIALIZE_PASS_BEGIN(LoopLoadElimination, LLE_OPTION, LLE_name, false, false) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_END(LoopLoadElimination, LLE_OPTION, LLE_name, false, false) + +FunctionPass *llvm::createLoopLoadEliminationPass() { + return new LoopLoadElimination(); +} + +PreservedAnalyses LoopLoadEliminationPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + auto &AC = AM.getResult<AssumptionAnalysis>(F); + + auto &LAM = AM.getResult<LoopAnalysisManagerFunctionProxy>(F).getManager(); + bool Changed = eliminateLoadsAcrossLoops( + F, LI, DT, [&](Loop &L) -> const LoopAccessInfo & { + LoopStandardAnalysisResults AR = {AA, AC, DT, LI, + SE, TLI, TTI, nullptr}; + return LAM.getResult<LoopAccessAnalysis>(L, AR); + }); + + if (!Changed) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + return PA; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopPassManager.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopPassManager.cpp new file mode 100644 index 000000000000..10f6fcdcfdb7 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopPassManager.cpp @@ -0,0 +1,92 @@ +//===- LoopPassManager.cpp - Loop pass management -------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Analysis/LoopInfo.h" + +using namespace llvm; + +// Explicit template instantiations and specialization defininitions for core +// template typedefs. +namespace llvm { +template class PassManager<Loop, LoopAnalysisManager, + LoopStandardAnalysisResults &, LPMUpdater &>; + +/// Explicitly specialize the pass manager's run method to handle loop nest +/// structure updates. +template <> +PreservedAnalyses +PassManager<Loop, LoopAnalysisManager, LoopStandardAnalysisResults &, + LPMUpdater &>::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, LPMUpdater &U) { + PreservedAnalyses PA = PreservedAnalyses::all(); + + if (DebugLogging) + dbgs() << "Starting Loop pass manager run.\n"; + + for (auto &Pass : Passes) { + if (DebugLogging) + dbgs() << "Running pass: " << Pass->name() << " on " << L; + + PreservedAnalyses PassPA = Pass->run(L, AM, AR, U); + + // If the loop was deleted, abort the run and return to the outer walk. + if (U.skipCurrentLoop()) { + PA.intersect(std::move(PassPA)); + break; + } + +#ifndef NDEBUG + // Verify the loop structure and LCSSA form before visiting the loop. + L.verifyLoop(); + assert(L.isRecursivelyLCSSAForm(AR.DT, AR.LI) && + "Loops must remain in LCSSA form!"); +#endif + + // Update the analysis manager as each pass runs and potentially + // invalidates analyses. + AM.invalidate(L, PassPA); + + // Finally, we intersect the final preserved analyses to compute the + // aggregate preserved set for this pass manager. + PA.intersect(std::move(PassPA)); + + // FIXME: Historically, the pass managers all called the LLVM context's + // yield function here. We don't have a generic way to acquire the + // context and it isn't yet clear what the right pattern is for yielding + // in the new pass manager so it is currently omitted. + // ...getContext().yield(); + } + + // Invalidation for the current loop should be handled above, and other loop + // analysis results shouldn't be impacted by runs over this loop. Therefore, + // the remaining analysis results in the AnalysisManager are preserved. We + // mark this with a set so that we don't need to inspect each one + // individually. + // FIXME: This isn't correct! This loop and all nested loops' analyses should + // be preserved, but unrolling should invalidate the parent loop's analyses. + PA.preserveSet<AllAnalysesOn<Loop>>(); + + if (DebugLogging) + dbgs() << "Finished Loop pass manager run.\n"; + + return PA; +} +} + +PrintLoopPass::PrintLoopPass() : OS(dbgs()) {} +PrintLoopPass::PrintLoopPass(raw_ostream &OS, const std::string &Banner) + : OS(OS), Banner(Banner) {} + +PreservedAnalyses PrintLoopPass::run(Loop &L, LoopAnalysisManager &, + LoopStandardAnalysisResults &, + LPMUpdater &) { + printLoop(L, OS, Banner); + return PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopPredication.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopPredication.cpp new file mode 100644 index 000000000000..2e4c7b19e476 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopPredication.cpp @@ -0,0 +1,750 @@ +//===-- LoopPredication.cpp - Guard based loop predication pass -----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+//
+//===----------------------------------------------------------------------===//
+//
+// The LoopPredication pass tries to convert loop variant range checks to loop
+// invariant by widening checks across loop iterations. For example, it will
+// convert
+//
+//   for (i = 0; i < n; i++) {
+//     guard(i < len);
+//     ...
+//   }
+//
+// to
+//
+//   for (i = 0; i < n; i++) {
+//     guard(n - 1 < len);
+//     ...
+//   }
+//
+// After this transformation the condition of the guard is loop invariant, so
+// loop-unswitch can later unswitch the loop by this condition, which basically
+// predicates the loop by the widened condition:
+//
+//   if (n - 1 < len)
+//     for (i = 0; i < n; i++) {
+//       ...
+//     }
+//   else
+//     deoptimize
+//
+// It's tempting to rely on SCEV here, but it has proven to be problematic.
+// Generally the facts SCEV provides about the increment step of add
+// recurrences are true if the backedge of the loop is taken, which implicitly
+// assumes that the guard doesn't fail. Using these facts to optimize the
+// guard results in a circular logic where the guard is optimized under the
+// assumption that it never fails.
+//
+// For example, in the loop below the induction variable will be marked as nuw
+// based on the guard. Based on nuw, the guard predicate will be considered
+// monotonic. Given a monotonic condition it's tempting to replace the
+// induction variable in the condition with its value on the last iteration.
+// But this transformation is not correct, e.g. e = 4, b = 5 breaks the loop.
+//
+//   for (int i = b; i != e; i++)
+//     guard(i u< len)
+//
+// One of the ways to reason about this problem is to use an inductive proof
+// approach. Given the loop:
+//
+//   if (B(0)) {
+//     do {
+//       I = PHI(0, I.INC)
+//       I.INC = I + Step
+//       guard(G(I));
+//     } while (B(I));
+//   }
+//
+// where B(x) and G(x) are predicates that map integers to booleans, we want a
+// loop invariant expression M such that the following program has the same
+// semantics as the above:
+//
+//   if (B(0)) {
+//     do {
+//       I = PHI(0, I.INC)
+//       I.INC = I + Step
+//       guard(G(0) && M);
+//     } while (B(I));
+//   }
+//
+// One solution for M is M = forall X . (G(X) && B(X)) => G(X + Step)
+//
+// Informal proof that the transformation above is correct:
+//
+// By the definition of guards we can rewrite the guard condition to:
+//   G(I) && G(0) && M
+//
+// Let's prove that for each iteration of the loop:
+//   G(0) && M => G(I)
+// Then the condition above can be simplified to G(0) && M.
+//
+// Induction base.
+//   G(0) && M => G(0)
+//
+// Induction step. Assuming G(0) && M => G(I) holds on the current iteration,
+// prove it for the subsequent iteration:
+//
+//   B(I) is true because it's the backedge condition.
+//   G(I) is true because the backedge is guarded by this condition.
+//
+// So M = forall X . (G(X) && B(X)) => G(X + Step) implies G(I + Step).
+//
+// Note that we can use anything stronger than M, i.e. any condition which
+// implies M.
+//
+// When S = 1 (i.e. a forward iterating loop), the transformation is supported
+// when:
+//   * The loop has a single latch with the condition of the form:
+//     B(X) = latchStart + X <pred> latchLimit,
+//     where <pred> is u<, u<=, s<, or s<=.
+//   * The guard condition is of the form
+//     G(X) = guardStart + X u< guardLimit
+//
+// For the ult latch comparison case M is:
+//   forall X .
guardStart + X u< guardLimit && latchStart + X <u latchLimit => +// guardStart + X + 1 u< guardLimit +// +// The only way the antecedent can be true and the consequent can be false is +// if +// X == guardLimit - 1 - guardStart +// (and guardLimit is non-zero, but we won't use this latter fact). +// If X == guardLimit - 1 - guardStart then the second half of the antecedent is +// latchStart + guardLimit - 1 - guardStart u< latchLimit +// and its negation is +// latchStart + guardLimit - 1 - guardStart u>= latchLimit +// +// In other words, if +// latchLimit u<= latchStart + guardLimit - 1 - guardStart +// then: +// (the ranges below are written in ConstantRange notation, where [A, B) is the +// set for (I = A; I != B; I++ /*maywrap*/) yield(I);) +// +// forall X . guardStart + X u< guardLimit && +// latchStart + X u< latchLimit => +// guardStart + X + 1 u< guardLimit +// == forall X . guardStart + X u< guardLimit && +// latchStart + X u< latchStart + guardLimit - 1 - guardStart => +// guardStart + X + 1 u< guardLimit +// == forall X . (guardStart + X) in [0, guardLimit) && +// (latchStart + X) in [0, latchStart + guardLimit - 1 - guardStart) => +// (guardStart + X + 1) in [0, guardLimit) +// == forall X . X in [-guardStart, guardLimit - guardStart) && +// X in [-latchStart, guardLimit - 1 - guardStart) => +// X in [-guardStart - 1, guardLimit - guardStart - 1) +// == true +// +// So the widened condition is: +// guardStart u< guardLimit && +// latchStart + guardLimit - 1 - guardStart u>= latchLimit +// Similarly for ule condition the widened condition is: +// guardStart u< guardLimit && +// latchStart + guardLimit - 1 - guardStart u> latchLimit +// For slt condition the widened condition is: +// guardStart u< guardLimit && +// latchStart + guardLimit - 1 - guardStart s>= latchLimit +// For sle condition the widened condition is: +// guardStart u< guardLimit && +// latchStart + guardLimit - 1 - guardStart s> latchLimit +// +// When S = -1 (i.e. reverse iterating loop), the transformation is supported +// when: +// * The loop has a single latch with the condition of the form: +// B(X) = X <pred> latchLimit, where <pred> is u> or s>. +// * The guard condition is of the form +// G(X) = X - 1 u< guardLimit +// +// For the ugt latch comparison case M is: +// forall X. X-1 u< guardLimit and X u> latchLimit => X-2 u< guardLimit +// +// The only way the antecedent can be true and the consequent can be false is if +// X == 1. +// If X == 1 then the second half of the antecedent is +// 1 u> latchLimit, and its negation is latchLimit u>= 1. +// +// So the widened condition is: +// guardStart u< guardLimit && latchLimit u>= 1. +// Similarly for sgt condition the widened condition is: +// guardStart u< guardLimit && latchLimit s>= 1. 
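+//
+// As an informal sanity check of the forward case (illustrative numbers,
+// ignoring boundary edge cases): with guardStart = latchStart = 0,
+// guardLimit = len, latchLimit = n and a u< latch predicate, the widened
+// condition above reads
+//   0 u< len && len - 1 u>= n
+// which matches the guard(n - 1 < len) rewrite shown at the top of this
+// comment.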
+//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopPredication.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GlobalValue.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/LoopUtils.h" + +#define DEBUG_TYPE "loop-predication" + +using namespace llvm; + +static cl::opt<bool> EnableIVTruncation("loop-predication-enable-iv-truncation", + cl::Hidden, cl::init(true)); + +static cl::opt<bool> EnableCountDownLoop("loop-predication-enable-count-down-loop", + cl::Hidden, cl::init(true)); +namespace { +class LoopPredication { + /// Represents an induction variable check: + /// icmp Pred, <induction variable>, <loop invariant limit> + struct LoopICmp { + ICmpInst::Predicate Pred; + const SCEVAddRecExpr *IV; + const SCEV *Limit; + LoopICmp(ICmpInst::Predicate Pred, const SCEVAddRecExpr *IV, + const SCEV *Limit) + : Pred(Pred), IV(IV), Limit(Limit) {} + LoopICmp() {} + void dump() { + dbgs() << "LoopICmp Pred = " << Pred << ", IV = " << *IV + << ", Limit = " << *Limit << "\n"; + } + }; + + ScalarEvolution *SE; + + Loop *L; + const DataLayout *DL; + BasicBlock *Preheader; + LoopICmp LatchCheck; + + bool isSupportedStep(const SCEV* Step); + Optional<LoopICmp> parseLoopICmp(ICmpInst *ICI) { + return parseLoopICmp(ICI->getPredicate(), ICI->getOperand(0), + ICI->getOperand(1)); + } + Optional<LoopICmp> parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS, + Value *RHS); + + Optional<LoopICmp> parseLoopLatchICmp(); + + bool CanExpand(const SCEV* S); + Value *expandCheck(SCEVExpander &Expander, IRBuilder<> &Builder, + ICmpInst::Predicate Pred, const SCEV *LHS, const SCEV *RHS, + Instruction *InsertAt); + + Optional<Value *> widenICmpRangeCheck(ICmpInst *ICI, SCEVExpander &Expander, + IRBuilder<> &Builder); + Optional<Value *> widenICmpRangeCheckIncrementingLoop(LoopICmp LatchCheck, + LoopICmp RangeCheck, + SCEVExpander &Expander, + IRBuilder<> &Builder); + Optional<Value *> widenICmpRangeCheckDecrementingLoop(LoopICmp LatchCheck, + LoopICmp RangeCheck, + SCEVExpander &Expander, + IRBuilder<> &Builder); + bool widenGuardConditions(IntrinsicInst *II, SCEVExpander &Expander); + + // When the IV type is wider than the range operand type, we can still do loop + // predication, by generating SCEVs for the range and latch that are of the + // same type. We achieve this by generating a SCEV truncate expression for the + // latch IV. This is done iff truncation of the IV is a safe operation, + // without loss of information. + // Another way to achieve this is by generating a wider type SCEV for the + // range check operand, however, this needs a more involved check that + // operands do not overflow. This can lead to loss of information when the + // range operand is of the form: add i32 %offset, %iv. We need to prove that + // sext(x + y) is same as sext(x) + sext(y). + // This function returns true if we can safely represent the IV type in + // the RangeCheckType without loss of information. + bool isSafeToTruncateWideIVType(Type *RangeCheckType); + // Return the loopLatchCheck corresponding to the RangeCheckType if safe to do + // so. 
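+  // E.g. a latch IV {0,+,1} of type i64 compared against an i32 range check
+  // limit can be truncated to an i32 {0,+,1} when the i64 IV provably fits
+  // in 32 bits (an illustrative sketch of the intended use).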
+  Optional<LoopICmp> generateLoopLatchCheck(Type *RangeCheckType);
+
+public:
+  LoopPredication(ScalarEvolution *SE) : SE(SE) {}
+  bool runOnLoop(Loop *L);
+};
+
+class LoopPredicationLegacyPass : public LoopPass {
+public:
+  static char ID;
+  LoopPredicationLegacyPass() : LoopPass(ID) {
+    initializeLoopPredicationLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    getLoopAnalysisUsage(AU);
+  }
+
+  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
+    if (skipLoop(L))
+      return false;
+    auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+    LoopPredication LP(SE);
+    return LP.runOnLoop(L);
+  }
+};
+
+char LoopPredicationLegacyPass::ID = 0;
+} // end anonymous namespace
+
+INITIALIZE_PASS_BEGIN(LoopPredicationLegacyPass, "loop-predication",
+                      "Loop predication", false, false)
+INITIALIZE_PASS_DEPENDENCY(LoopPass)
+INITIALIZE_PASS_END(LoopPredicationLegacyPass, "loop-predication",
+                    "Loop predication", false, false)
+
+Pass *llvm::createLoopPredicationPass() {
+  return new LoopPredicationLegacyPass();
+}
+
+PreservedAnalyses LoopPredicationPass::run(Loop &L, LoopAnalysisManager &AM,
+                                           LoopStandardAnalysisResults &AR,
+                                           LPMUpdater &U) {
+  LoopPredication LP(&AR.SE);
+  if (!LP.runOnLoop(&L))
+    return PreservedAnalyses::all;
+
+  return getLoopPassPreservedAnalyses();
+}
+
+Optional<LoopPredication::LoopICmp>
+LoopPredication::parseLoopICmp(ICmpInst::Predicate Pred, Value *LHS,
+                               Value *RHS) {
+  const SCEV *LHSS = SE->getSCEV(LHS);
+  if (isa<SCEVCouldNotCompute>(LHSS))
+    return None;
+  const SCEV *RHSS = SE->getSCEV(RHS);
+  if (isa<SCEVCouldNotCompute>(RHSS))
+    return None;
+
+  // Canonicalize so that RHS is the loop-invariant bound and LHS is the
+  // loop-computable IV.
+  if (SE->isLoopInvariant(LHSS, L)) {
+    std::swap(LHS, RHS);
+    std::swap(LHSS, RHSS);
+    Pred = ICmpInst::getSwappedPredicate(Pred);
+  }
+
+  const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHSS);
+  if (!AR || AR->getLoop() != L)
+    return None;
+
+  return LoopICmp(Pred, AR, RHSS);
+}
+
+Value *LoopPredication::expandCheck(SCEVExpander &Expander,
+                                    IRBuilder<> &Builder,
+                                    ICmpInst::Predicate Pred, const SCEV *LHS,
+                                    const SCEV *RHS, Instruction *InsertAt) {
+  Type *Ty = LHS->getType();
+  assert(Ty == RHS->getType() && "expandCheck operands have different types?");
+
+  // If the condition is already known to hold on entry to the loop, there is
+  // no need to emit a runtime check.
+  if (SE->isLoopEntryGuardedByCond(L, Pred, LHS, RHS))
+    return Builder.getTrue();
+
+  Value *LHSV = Expander.expandCodeFor(LHS, Ty, InsertAt);
+  Value *RHSV = Expander.expandCodeFor(RHS, Ty, InsertAt);
+  return Builder.CreateICmp(Pred, LHSV, RHSV);
+}
+
+Optional<LoopPredication::LoopICmp>
+LoopPredication::generateLoopLatchCheck(Type *RangeCheckType) {
+  auto *LatchType = LatchCheck.IV->getType();
+  if (RangeCheckType == LatchType)
+    return LatchCheck;
+  // For now, bail out if the latch type is narrower than the range type.
+  if (DL->getTypeSizeInBits(LatchType) < DL->getTypeSizeInBits(RangeCheckType))
+    return None;
+  if (!isSafeToTruncateWideIVType(RangeCheckType))
+    return None;
+  // We can now safely identify the truncated version of the IV and limit for
+  // RangeCheckType.
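+  // For instance (an illustrative case): an i64 latch check
+  //   {0,+,1}<%loop> u< 512
+  // with an i32 range check operand becomes the i32 check
+  //   {0,+,1}<%loop> u< 512,
+  // which is only equivalent because isSafeToTruncateWideIVType has just
+  // verified that the (constant) latch start and limit fit in the narrower
+  // type and that the IV is monotonic.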
+  LoopICmp NewLatchCheck;
+  NewLatchCheck.Pred = LatchCheck.Pred;
+  NewLatchCheck.IV = dyn_cast<SCEVAddRecExpr>(
+      SE->getTruncateExpr(LatchCheck.IV, RangeCheckType));
+  if (!NewLatchCheck.IV)
+    return None;
+  NewLatchCheck.Limit = SE->getTruncateExpr(LatchCheck.Limit, RangeCheckType);
+  DEBUG(dbgs() << "IV of type: " << *LatchType
+               << " can be represented as range check type: "
+               << *RangeCheckType << "\n");
+  DEBUG(dbgs() << "LatchCheck.IV: " << *NewLatchCheck.IV << "\n");
+  DEBUG(dbgs() << "LatchCheck.Limit: " << *NewLatchCheck.Limit << "\n");
+  return NewLatchCheck;
+}
+
+bool LoopPredication::isSupportedStep(const SCEV* Step) {
+  return Step->isOne() || (Step->isAllOnesValue() && EnableCountDownLoop);
+}
+
+bool LoopPredication::CanExpand(const SCEV* S) {
+  return SE->isLoopInvariant(S, L) && isSafeToExpand(S, *SE);
+}
+
+Optional<Value *> LoopPredication::widenICmpRangeCheckIncrementingLoop(
+    LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
+    SCEVExpander &Expander, IRBuilder<> &Builder) {
+  auto *Ty = RangeCheck.IV->getType();
+  // Generate the widened condition for the forward loop:
+  //   guardStart u< guardLimit &&
+  //   latchLimit <pred> guardLimit - 1 - guardStart + latchStart
+  // where <pred> depends on the latch condition predicate. See the file
+  // header comment for the reasoning.
+  const SCEV *GuardStart = RangeCheck.IV->getStart();
+  const SCEV *GuardLimit = RangeCheck.Limit;
+  const SCEV *LatchStart = LatchCheck.IV->getStart();
+  const SCEV *LatchLimit = LatchCheck.Limit;
+
+  // guardLimit - guardStart + latchStart - 1
+  const SCEV *RHS =
+      SE->getAddExpr(SE->getMinusSCEV(GuardLimit, GuardStart),
+                     SE->getMinusSCEV(LatchStart, SE->getOne(Ty)));
+  if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
+      !CanExpand(LatchLimit) || !CanExpand(RHS)) {
+    DEBUG(dbgs() << "Can't expand limit check!\n");
+    return None;
+  }
+  ICmpInst::Predicate LimitCheckPred;
+  switch (LatchCheck.Pred) {
+  case ICmpInst::ICMP_ULT:
+    LimitCheckPred = ICmpInst::ICMP_ULE;
+    break;
+  case ICmpInst::ICMP_ULE:
+    LimitCheckPred = ICmpInst::ICMP_ULT;
+    break;
+  case ICmpInst::ICMP_SLT:
+    LimitCheckPred = ICmpInst::ICMP_SLE;
+    break;
+  case ICmpInst::ICMP_SLE:
+    LimitCheckPred = ICmpInst::ICMP_SLT;
+    break;
+  default:
+    llvm_unreachable("Unsupported loop latch!");
+  }
+
+  DEBUG(dbgs() << "LHS: " << *LatchLimit << "\n");
+  DEBUG(dbgs() << "RHS: " << *RHS << "\n");
+  DEBUG(dbgs() << "Pred: " << LimitCheckPred << "\n");
+
+  Instruction *InsertAt = Preheader->getTerminator();
+  auto *LimitCheck =
+      expandCheck(Expander, Builder, LimitCheckPred, LatchLimit, RHS, InsertAt);
+  auto *FirstIterationCheck = expandCheck(Expander, Builder, RangeCheck.Pred,
+                                          GuardStart, GuardLimit, InsertAt);
+  return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
+}
+
+Optional<Value *> LoopPredication::widenICmpRangeCheckDecrementingLoop(
+    LoopPredication::LoopICmp LatchCheck, LoopPredication::LoopICmp RangeCheck,
+    SCEVExpander &Expander, IRBuilder<> &Builder) {
+  auto *Ty = RangeCheck.IV->getType();
+  const SCEV *GuardStart = RangeCheck.IV->getStart();
+  const SCEV *GuardLimit = RangeCheck.Limit;
+  const SCEV *LatchLimit = LatchCheck.Limit;
+  if (!CanExpand(GuardStart) || !CanExpand(GuardLimit) ||
+      !CanExpand(LatchLimit)) {
+    DEBUG(dbgs() << "Can't expand limit check!\n");
+    return None;
+  }
+  // The post-decrement value of the latch check IV should be the same as the
+  // range check IV.
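+  // For example (illustrative values): if the latch IV is {%n,+,-1}, its
+  // post-decrement form is {%n - 1,+,-1}, so only a range check performed
+  // on {%n - 1,+,-1} (i.e. on i - 1, matching G(X) = X - 1 u< guardLimit
+  // from the file header) can be widened here.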
+  auto *PostDecLatchCheckIV = LatchCheck.IV->getPostIncExpr(*SE);
+  if (RangeCheck.IV != PostDecLatchCheckIV) {
+    DEBUG(dbgs() << "Not the same. PostDecLatchCheckIV: "
+                 << *PostDecLatchCheckIV
+                 << " and RangeCheckIV: " << *RangeCheck.IV << "\n");
+    return None;
+  }
+
+  // Generate the widened condition for the count-down loop:
+  //   guardStart u< guardLimit &&
+  //   latchLimit <pred> 1.
+  // See the file header comment for the reasoning behind these checks.
+  Instruction *InsertAt = Preheader->getTerminator();
+  auto LimitCheckPred = ICmpInst::isSigned(LatchCheck.Pred)
+                            ? ICmpInst::ICMP_SGE
+                            : ICmpInst::ICMP_UGE;
+  auto *FirstIterationCheck = expandCheck(Expander, Builder, ICmpInst::ICMP_ULT,
+                                          GuardStart, GuardLimit, InsertAt);
+  auto *LimitCheck = expandCheck(Expander, Builder, LimitCheckPred, LatchLimit,
+                                 SE->getOne(Ty), InsertAt);
+  return Builder.CreateAnd(FirstIterationCheck, LimitCheck);
+}
+
+/// If ICI can be widened to a loop-invariant condition, emits the loop-
+/// invariant condition in the loop preheader and returns it; otherwise
+/// returns None.
+Optional<Value *> LoopPredication::widenICmpRangeCheck(ICmpInst *ICI,
+                                                       SCEVExpander &Expander,
+                                                       IRBuilder<> &Builder) {
+  DEBUG(dbgs() << "Analyzing ICmpInst condition:\n");
+  DEBUG(ICI->dump());
+
+  // parseLoopLatchICmp guarantees that the latch condition is:
+  //   ++i <pred> latchLimit, where <pred> is u<, u<=, s<, or s<=.
+  // We are looking for range checks of the form:
+  //   i u< guardLimit
+  auto RangeCheck = parseLoopICmp(ICI);
+  if (!RangeCheck) {
+    DEBUG(dbgs() << "Failed to parse the range check condition!\n");
+    return None;
+  }
+  DEBUG(dbgs() << "Guard check:\n");
+  DEBUG(RangeCheck->dump());
+  if (RangeCheck->Pred != ICmpInst::ICMP_ULT) {
+    DEBUG(dbgs() << "Unsupported range check predicate(" << RangeCheck->Pred
+                 << ")!\n");
+    return None;
+  }
+  auto *RangeCheckIV = RangeCheck->IV;
+  if (!RangeCheckIV->isAffine()) {
+    DEBUG(dbgs() << "Range check IV is not affine!\n");
+    return None;
+  }
+  auto *Step = RangeCheckIV->getStepRecurrence(*SE);
+  // We cannot just compare with the latch IV step because the latch and range
+  // IVs may have different types.
+  if (!isSupportedStep(Step)) {
+    DEBUG(dbgs() << "Range check and latch IVs have different steps!\n");
+    return None;
+  }
+  auto *Ty = RangeCheckIV->getType();
+  auto CurrLatchCheckOpt = generateLoopLatchCheck(Ty);
+  if (!CurrLatchCheckOpt) {
+    DEBUG(dbgs() << "Failed to generate a loop latch check "
+                    "corresponding to range type: "
+                 << *Ty << "\n");
+    return None;
+  }
+
+  LoopICmp CurrLatchCheck = *CurrLatchCheckOpt;
+  // At this point, the range and latch step should have the same type, but
+  // need not have the same value (we support both 1 and -1 steps).
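+  // For example (illustrative): a range check IV with step 1 paired with a
+  // count-down latch IV (step -1) of the same type reaches this point and
+  // is rejected by the step comparison below.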
+  assert(Step->getType() ==
+             CurrLatchCheck.IV->getStepRecurrence(*SE)->getType() &&
+         "Range and latch steps should be of same type!");
+  if (Step != CurrLatchCheck.IV->getStepRecurrence(*SE)) {
+    DEBUG(dbgs() << "Range and latch have different step values!\n");
+    return None;
+  }
+
+  if (Step->isOne())
+    return widenICmpRangeCheckIncrementingLoop(CurrLatchCheck, *RangeCheck,
+                                               Expander, Builder);
+  else {
+    assert(Step->isAllOnesValue() && "Step should be -1!");
+    return widenICmpRangeCheckDecrementingLoop(CurrLatchCheck, *RangeCheck,
+                                               Expander, Builder);
+  }
+}
+
+bool LoopPredication::widenGuardConditions(IntrinsicInst *Guard,
+                                           SCEVExpander &Expander) {
+  DEBUG(dbgs() << "Processing guard:\n");
+  DEBUG(Guard->dump());
+
+  IRBuilder<> Builder(cast<Instruction>(Preheader->getTerminator()));
+
+  // The guard condition is expected to be in the form:
+  //   cond1 && cond2 && cond3 ...
+  // Iterate over the subconditions looking for icmp conditions which can be
+  // widened across loop iterations, and remember the resulting list of
+  // subconditions in the Checks vector.
+  SmallVector<Value *, 4> Worklist(1, Guard->getOperand(0));
+  SmallPtrSet<Value *, 4> Visited;
+
+  SmallVector<Value *, 4> Checks;
+
+  unsigned NumWidened = 0;
+  do {
+    Value *Condition = Worklist.pop_back_val();
+    if (!Visited.insert(Condition).second)
+      continue;
+
+    Value *LHS, *RHS;
+    using namespace llvm::PatternMatch;
+    if (match(Condition, m_And(m_Value(LHS), m_Value(RHS)))) {
+      Worklist.push_back(LHS);
+      Worklist.push_back(RHS);
+      continue;
+    }
+
+    if (ICmpInst *ICI = dyn_cast<ICmpInst>(Condition)) {
+      if (auto NewRangeCheck = widenICmpRangeCheck(ICI, Expander, Builder)) {
+        Checks.push_back(NewRangeCheck.getValue());
+        NumWidened++;
+        continue;
+      }
+    }
+
+    // Save the condition as is if we can't widen it.
+    Checks.push_back(Condition);
+  } while (!Worklist.empty());
+
+  if (NumWidened == 0)
+    return false;
+
+  // Emit the new guard condition.
+  Builder.SetInsertPoint(Guard);
+  Value *LastCheck = nullptr;
+  for (auto *Check : Checks)
+    if (!LastCheck)
+      LastCheck = Check;
+    else
+      LastCheck = Builder.CreateAnd(LastCheck, Check);
+  Guard->setOperand(0, LastCheck);
+
+  DEBUG(dbgs() << "Widened checks = " << NumWidened << "\n");
+  return true;
+}
+
+Optional<LoopPredication::LoopICmp> LoopPredication::parseLoopLatchICmp() {
+  using namespace PatternMatch;
+
+  BasicBlock *LoopLatch = L->getLoopLatch();
+  if (!LoopLatch) {
+    DEBUG(dbgs() << "The loop doesn't have a single latch!\n");
+    return None;
+  }
+
+  ICmpInst::Predicate Pred;
+  Value *LHS, *RHS;
+  BasicBlock *TrueDest, *FalseDest;
+
+  if (!match(LoopLatch->getTerminator(),
+             m_Br(m_ICmp(Pred, m_Value(LHS), m_Value(RHS)), TrueDest,
+                  FalseDest))) {
+    DEBUG(dbgs() << "Failed to match the latch terminator!\n");
+    return None;
+  }
+  assert((TrueDest == L->getHeader() || FalseDest == L->getHeader()) &&
+         "One of the latch's destinations must be the header");
+  if (TrueDest != L->getHeader())
+    Pred = ICmpInst::getInversePredicate(Pred);
+
+  auto Result = parseLoopICmp(Pred, LHS, RHS);
+  if (!Result) {
+    DEBUG(dbgs() << "Failed to parse the loop latch condition!\n");
+    return None;
+  }
+
+  // Check that the IV is affine first, so that if it is not we don't try to
+  // compute its step recurrence.
+  if (!Result->IV->isAffine()) {
+    DEBUG(dbgs() << "The induction variable is not affine!\n");
+    return None;
+  }
+
+  auto *Step = Result->IV->getStepRecurrence(*SE);
+  if (!isSupportedStep(Step)) {
+    DEBUG(dbgs() << "Unsupported loop stride(" << *Step << ")!\n");
+    return None;
+  }
+
+  auto IsUnsupportedPredicate = [](const SCEV *Step, ICmpInst::Predicate Pred) {
+    if (Step->isOne()) {
+      return Pred != ICmpInst::ICMP_ULT && Pred != ICmpInst::ICMP_SLT &&
+             Pred != ICmpInst::ICMP_ULE && Pred != ICmpInst::ICMP_SLE;
+    } else {
+      assert(Step->isAllOnesValue() && "Step should be -1!");
+      return Pred != ICmpInst::ICMP_UGT && Pred != ICmpInst::ICMP_SGT;
+    }
+  };
+
+  if (IsUnsupportedPredicate(Step, Result->Pred)) {
+    DEBUG(dbgs() << "Unsupported loop latch predicate(" << Result->Pred
+                 << ")!\n");
+    return None;
+  }
+  return Result;
+}
+
+// Returns true if it's safe to truncate the IV to RangeCheckType.
+bool LoopPredication::isSafeToTruncateWideIVType(Type *RangeCheckType) {
+  if (!EnableIVTruncation)
+    return false;
+  assert(DL->getTypeSizeInBits(LatchCheck.IV->getType()) >
+             DL->getTypeSizeInBits(RangeCheckType) &&
+         "Expected latch check IV type to be larger than range check operand "
+         "type!");
+  // The start and end values of the IV should be known. This is to guarantee
+  // that truncating the wide type will not lose information.
+  auto *Limit = dyn_cast<SCEVConstant>(LatchCheck.Limit);
+  auto *Start = dyn_cast<SCEVConstant>(LatchCheck.IV->getStart());
+  if (!Limit || !Start)
+    return false;
+  // This check makes sure that the IV does not change sign during loop
+  // iterations. Consider latchType = i64, LatchStart = 5, Pred = ICMP_SGE,
+  // LatchEnd = 2, rangeCheckType = i32. If the predicate is not monotonic,
+  // the IV wraps around, and truncating it would lose the range of
+  // iterations between 2^32 and 2^64.
+  bool Increasing;
+  if (!SE->isMonotonicPredicate(LatchCheck.IV, LatchCheck.Pred, Increasing))
+    return false;
+  // The active bits should be less than the bits in the RangeCheckType. This
+  // guarantees that truncating the latch check to RangeCheckType is a safe
+  // operation.
+  auto RangeCheckTypeBitSize = DL->getTypeSizeInBits(RangeCheckType);
+  return Start->getAPInt().getActiveBits() < RangeCheckTypeBitSize &&
+         Limit->getAPInt().getActiveBits() < RangeCheckTypeBitSize;
+}
+
+bool LoopPredication::runOnLoop(Loop *Loop) {
+  L = Loop;
+
+  DEBUG(dbgs() << "Analyzing ");
+  DEBUG(L->dump());
+
+  Module *M = L->getHeader()->getModule();
+
+  // There is nothing to do if the module doesn't use guards.
+  auto *GuardDecl =
+      M->getFunction(Intrinsic::getName(Intrinsic::experimental_guard));
+  if (!GuardDecl || GuardDecl->use_empty())
+    return false;
+
+  DL = &M->getDataLayout();
+
+  Preheader = L->getLoopPreheader();
+  if (!Preheader)
+    return false;
+
+  auto LatchCheckOpt = parseLoopLatchICmp();
+  if (!LatchCheckOpt)
+    return false;
+  LatchCheck = *LatchCheckOpt;
+
+  DEBUG(dbgs() << "Latch check:\n");
+  DEBUG(LatchCheck.dump());
+
+  // Collect all the guards into a vector and process them later, so as not
+  // to invalidate the instruction iterator.
+ SmallVector<IntrinsicInst *, 4> Guards; + for (const auto BB : L->blocks()) + for (auto &I : *BB) + if (auto *II = dyn_cast<IntrinsicInst>(&I)) + if (II->getIntrinsicID() == Intrinsic::experimental_guard) + Guards.push_back(II); + + if (Guards.empty()) + return false; + + SCEVExpander Expander(*SE, *DL, "loop-predication"); + + bool Changed = false; + for (auto *Guard : Guards) + Changed |= widenGuardConditions(Guard, Expander); + + return Changed; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp new file mode 100644 index 000000000000..d1a54b877950 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp @@ -0,0 +1,1801 @@ +//===- LoopReroll.cpp - Loop rerolling pass -------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass implements a simple loop reroller. +// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/BitVector.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AliasSetTracker.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionExpander.h" +#include "llvm/Analysis/ScalarEvolutionExpressions.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdlib> +#include <iterator> +#include <map> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "loop-reroll" + +STATISTIC(NumRerolledLoops, "Number of rerolled loops"); + +static cl::opt<unsigned> +MaxInc("max-reroll-increment", cl::init(2048), cl::Hidden, + cl::desc("The maximum increment for loop rerolling")); + +static cl::opt<unsigned> +NumToleratedFailedMatches("reroll-num-tolerated-failed-matches", cl::init(400), + cl::Hidden, + cl::desc("The maximum number of failures to tolerate" + " during fuzzy matching. 
(default: 400)"));
+
+// This loop re-rolling transformation aims to transform loops like this:
+//
+// int foo(int a);
+// void bar(int *x) {
+//   for (int i = 0; i < 500; i += 3) {
+//     foo(i);
+//     foo(i+1);
+//     foo(i+2);
+//   }
+// }
+//
+// into a loop like this:
+//
+// void bar(int *x) {
+//   for (int i = 0; i < 500; ++i)
+//     foo(i);
+// }
+//
+// It does this by looking for loops that, besides the latch code, are composed
+// of isomorphic DAGs of instructions, with each DAG rooted at some increment
+// to the induction variable, and where each DAG is isomorphic to the DAG
+// rooted at the induction variable (excepting the sub-DAGs which root the
+// other induction-variable increments). In other words, we're looking for loop
+// bodies of the form:
+//
+// %iv = phi [ (preheader, ...), (body, %iv.next) ]
+// f(%iv)
+// %iv.1 = add %iv, 1                <-- a root increment
+// f(%iv.1)
+// %iv.2 = add %iv, 2                <-- a root increment
+// f(%iv.2)
+// %iv.scale_m_1 = add %iv, scale-1  <-- a root increment
+// f(%iv.scale_m_1)
+// ...
+// %iv.next = add %iv, scale
+// %cmp = icmp(%iv, ...)
+// br %cmp, header, exit
+//
+// where each f(i) is a set of instructions that, collectively, are a function
+// only of i (and other loop-invariant values).
+//
+// As a special case, we can also reroll loops like this:
+//
+// int foo(int);
+// void bar(int *x) {
+//   for (int i = 0; i < 500; ++i) {
+//     x[3*i] = foo(0);
+//     x[3*i+1] = foo(0);
+//     x[3*i+2] = foo(0);
+//   }
+// }
+//
+// into this:
+//
+// void bar(int *x) {
+//   for (int i = 0; i < 1500; ++i)
+//     x[i] = foo(0);
+// }
+//
+// in which case, we're looking for inputs like this:
+//
+// %iv = phi [ (preheader, ...), (body, %iv.next) ]
+// %scaled.iv = mul %iv, scale
+// f(%scaled.iv)
+// %scaled.iv.1 = add %scaled.iv, 1
+// f(%scaled.iv.1)
+// %scaled.iv.2 = add %scaled.iv, 2
+// f(%scaled.iv.2)
+// %scaled.iv.scale_m_1 = add %scaled.iv, scale-1
+// f(%scaled.iv.scale_m_1)
+// ...
+// %iv.next = add %iv, 1
+// %cmp = icmp(%iv, ...)
+// br %cmp, header, exit
+
+namespace {
+
+  enum IterationLimits {
+    /// The maximum number of iterations that we'll try to reroll.
+    IL_MaxRerollIterations = 32,
+    /// The bitvector index used by loop induction variables and other
+    /// instructions that belong to all iterations.
+    IL_All,
+    IL_End
+  };
+
+  class LoopReroll : public LoopPass {
+  public:
+    static char ID; // Pass ID, replacement for typeid
+
+    LoopReroll() : LoopPass(ID) {
+      initializeLoopRerollPass(*PassRegistry::getPassRegistry());
+    }
+
+    bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+
+    void getAnalysisUsage(AnalysisUsage &AU) const override {
+      AU.addRequired<TargetLibraryInfoWrapperPass>();
+      getLoopAnalysisUsage(AU);
+    }
+
+  protected:
+    AliasAnalysis *AA;
+    LoopInfo *LI;
+    ScalarEvolution *SE;
+    TargetLibraryInfo *TLI;
+    DominatorTree *DT;
+    bool PreserveLCSSA;
+
+    using SmallInstructionVector = SmallVector<Instruction *, 16>;
+    using SmallInstructionSet = SmallSet<Instruction *, 16>;
+
+    // Map between induction variable and its increment.
+    DenseMap<Instruction *, int64_t> IVToIncMap;
+
+    // For a loop with multiple induction variables, remember the one used
+    // only to control the loop.
+    Instruction *LoopControlIV;
+
+    // A chain of isomorphic instructions, identified by a single-use PHI
+    // representing a reduction. Only the last value may be used outside the
+    // loop.
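+    // For example (illustrative IR, not taken from a real test case):
+    //   %x   = phi [ %init, %preheader ], [ %x.2, %body ]
+    //   %x.1 = fadd float %x, %a      ; single use: %x.2
+    //   %x.2 = fadd float %x.1, %b    ; used by the PHI, and possibly
+    //                                 ; outside the loop
+    // is recorded as the chain [ %x, %x.1, %x.2 ].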
+    struct SimpleLoopReduction {
+      SimpleLoopReduction(Instruction *P, Loop *L) : Instructions(1, P) {
+        assert(isa<PHINode>(P) && "First reduction instruction must be a PHI");
+        add(L);
+      }
+
+      bool valid() const {
+        return Valid;
+      }
+
+      Instruction *getPHI() const {
+        assert(Valid && "Using invalid reduction");
+        return Instructions.front();
+      }
+
+      Instruction *getReducedValue() const {
+        assert(Valid && "Using invalid reduction");
+        return Instructions.back();
+      }
+
+      Instruction *get(size_t i) const {
+        assert(Valid && "Using invalid reduction");
+        return Instructions[i+1];
+      }
+
+      Instruction *operator[](size_t i) const { return get(i); }
+
+      // The size, ignoring the initial PHI.
+      size_t size() const {
+        assert(Valid && "Using invalid reduction");
+        return Instructions.size()-1;
+      }
+
+      using iterator = SmallInstructionVector::iterator;
+      using const_iterator = SmallInstructionVector::const_iterator;
+
+      iterator begin() {
+        assert(Valid && "Using invalid reduction");
+        return std::next(Instructions.begin());
+      }
+
+      const_iterator begin() const {
+        assert(Valid && "Using invalid reduction");
+        return std::next(Instructions.begin());
+      }
+
+      iterator end() { return Instructions.end(); }
+      const_iterator end() const { return Instructions.end(); }
+
+    protected:
+      bool Valid = false;
+      SmallInstructionVector Instructions;
+
+      void add(Loop *L);
+    };
+
+    // The set of all reductions, and state tracking of possible reductions
+    // during loop instruction processing.
+    struct ReductionTracker {
+      using SmallReductionVector = SmallVector<SimpleLoopReduction, 16>;
+
+      // Add a new possible reduction.
+      void addSLR(SimpleLoopReduction &SLR) { PossibleReds.push_back(SLR); }
+
+      // Set up to track possible reductions corresponding to the provided
+      // rerolling scale. Only reductions with a number of non-PHI instructions
+      // that is divisible by the scale are considered. Three instruction sets
+      // are filled in:
+      //   - A set of all possible instructions in eligible reductions.
+      //   - A set of all PHIs in eligible reductions.
+      //   - A set of all reduced values (last instructions) in eligible
+      //     reductions.
+      void restrictToScale(uint64_t Scale,
+                           SmallInstructionSet &PossibleRedSet,
+                           SmallInstructionSet &PossibleRedPHISet,
+                           SmallInstructionSet &PossibleRedLastSet) {
+        PossibleRedIdx.clear();
+        PossibleRedIter.clear();
+        Reds.clear();
+
+        for (unsigned i = 0, e = PossibleReds.size(); i != e; ++i)
+          if (PossibleReds[i].size() % Scale == 0) {
+            PossibleRedLastSet.insert(PossibleReds[i].getReducedValue());
+            PossibleRedPHISet.insert(PossibleReds[i].getPHI());
+
+            PossibleRedSet.insert(PossibleReds[i].getPHI());
+            PossibleRedIdx[PossibleReds[i].getPHI()] = i;
+            for (Instruction *J : PossibleReds[i]) {
+              PossibleRedSet.insert(J);
+              PossibleRedIdx[J] = i;
+            }
+          }
+      }
+
+      // The functions below are used while processing the loop instructions.
+
+      // Are the two instructions both from reductions, and furthermore, from
+      // the same reduction?
+      bool isPairInSame(Instruction *J1, Instruction *J2) {
+        DenseMap<Instruction *, int>::iterator J1I = PossibleRedIdx.find(J1);
+        if (J1I != PossibleRedIdx.end()) {
+          DenseMap<Instruction *, int>::iterator J2I = PossibleRedIdx.find(J2);
+          if (J2I != PossibleRedIdx.end() && J1I->second == J2I->second)
+            return true;
+        }
+
+        return false;
+      }
+
+      // The two provided instructions, the first from the base iteration, and
+      // the second from iteration i, form a matched pair. If these are part of
+      // a reduction, record that fact.
+ void recordPair(Instruction *J1, Instruction *J2, unsigned i) { + if (PossibleRedIdx.count(J1)) { + assert(PossibleRedIdx.count(J2) && + "Recording reduction vs. non-reduction instruction?"); + + PossibleRedIter[J1] = 0; + PossibleRedIter[J2] = i; + + int Idx = PossibleRedIdx[J1]; + assert(Idx == PossibleRedIdx[J2] && + "Recording pair from different reductions?"); + Reds.insert(Idx); + } + } + + // The functions below can be called after we've finished processing all + // instructions in the loop, and we know which reductions were selected. + + bool validateSelected(); + void replaceSelected(); + + protected: + // The vector of all possible reductions (for any scale). + SmallReductionVector PossibleReds; + + DenseMap<Instruction *, int> PossibleRedIdx; + DenseMap<Instruction *, int> PossibleRedIter; + DenseSet<int> Reds; + }; + + // A DAGRootSet models an induction variable being used in a rerollable + // loop. For example, + // + // x[i*3+0] = y1 + // x[i*3+1] = y2 + // x[i*3+2] = y3 + // + // Base instruction -> i*3 + // +---+----+ + // / | \ + // ST[y1] +1 +2 <-- Roots + // | | + // ST[y2] ST[y3] + // + // There may be multiple DAGRoots, for example: + // + // x[i*2+0] = ... (1) + // x[i*2+1] = ... (1) + // x[i*2+4] = ... (2) + // x[i*2+5] = ... (2) + // x[(i+1234)*2+5678] = ... (3) + // x[(i+1234)*2+5679] = ... (3) + // + // The loop will be rerolled by adding a new loop induction variable, + // one for the Base instruction in each DAGRootSet. + // + struct DAGRootSet { + Instruction *BaseInst; + SmallInstructionVector Roots; + + // The instructions between IV and BaseInst (but not including BaseInst). + SmallInstructionSet SubsumedInsts; + }; + + // The set of all DAG roots, and state tracking of all roots + // for a particular induction variable. + struct DAGRootTracker { + DAGRootTracker(LoopReroll *Parent, Loop *L, Instruction *IV, + ScalarEvolution *SE, AliasAnalysis *AA, + TargetLibraryInfo *TLI, DominatorTree *DT, LoopInfo *LI, + bool PreserveLCSSA, + DenseMap<Instruction *, int64_t> &IncrMap, + Instruction *LoopCtrlIV) + : Parent(Parent), L(L), SE(SE), AA(AA), TLI(TLI), DT(DT), LI(LI), + PreserveLCSSA(PreserveLCSSA), IV(IV), IVToIncMap(IncrMap), + LoopControlIV(LoopCtrlIV) {} + + /// Stage 1: Find all the DAG roots for the induction variable. + bool findRoots(); + + /// Stage 2: Validate if the found roots are valid. + bool validate(ReductionTracker &Reductions); + + /// Stage 3: Assuming validate() returned true, perform the + /// replacement. + /// @param IterCount The maximum iteration count of L. 
+ void replace(const SCEV *IterCount); + + protected: + using UsesTy = MapVector<Instruction *, BitVector>; + + void findRootsRecursive(Instruction *IVU, + SmallInstructionSet SubsumedInsts); + bool findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts); + bool collectPossibleRoots(Instruction *Base, + std::map<int64_t,Instruction*> &Roots); + bool validateRootSet(DAGRootSet &DRS); + + bool collectUsedInstructions(SmallInstructionSet &PossibleRedSet); + void collectInLoopUserSet(const SmallInstructionVector &Roots, + const SmallInstructionSet &Exclude, + const SmallInstructionSet &Final, + DenseSet<Instruction *> &Users); + void collectInLoopUserSet(Instruction *Root, + const SmallInstructionSet &Exclude, + const SmallInstructionSet &Final, + DenseSet<Instruction *> &Users); + + UsesTy::iterator nextInstr(int Val, UsesTy &In, + const SmallInstructionSet &Exclude, + UsesTy::iterator *StartI=nullptr); + bool isBaseInst(Instruction *I); + bool isRootInst(Instruction *I); + bool instrDependsOn(Instruction *I, + UsesTy::iterator Start, + UsesTy::iterator End); + void replaceIV(Instruction *Inst, Instruction *IV, const SCEV *IterCount); + void updateNonLoopCtrlIncr(); + + LoopReroll *Parent; + + // Members of Parent, replicated here for brevity. + Loop *L; + ScalarEvolution *SE; + AliasAnalysis *AA; + TargetLibraryInfo *TLI; + DominatorTree *DT; + LoopInfo *LI; + bool PreserveLCSSA; + + // The loop induction variable. + Instruction *IV; + + // Loop step amount. + int64_t Inc; + + // Loop reroll count; if Inc == 1, this records the scaling applied + // to the indvar: a[i*2+0] = ...; a[i*2+1] = ... ; + // If Inc is not 1, Scale = Inc. + uint64_t Scale; + + // The roots themselves. + SmallVector<DAGRootSet,16> RootSets; + + // All increment instructions for IV. + SmallInstructionVector LoopIncs; + + // Map of all instructions in the loop (in order) to the iterations + // they are used in (or specially, IL_All for instructions + // used in the loop increment mechanism). + UsesTy Uses; + + // Map between induction variable and its increment + DenseMap<Instruction *, int64_t> &IVToIncMap; + + Instruction *LoopControlIV; + }; + + // Check if it is a compare-like instruction whose user is a branch + bool isCompareUsedByBranch(Instruction *I) { + auto *TI = I->getParent()->getTerminator(); + if (!isa<BranchInst>(TI) || !isa<CmpInst>(I)) + return false; + return I->hasOneUse() && TI->getOperand(0) == I; + }; + + bool isLoopControlIV(Loop *L, Instruction *IV); + void collectPossibleIVs(Loop *L, SmallInstructionVector &PossibleIVs); + void collectPossibleReductions(Loop *L, + ReductionTracker &Reductions); + bool reroll(Instruction *IV, Loop *L, BasicBlock *Header, const SCEV *IterCount, + ReductionTracker &Reductions); + }; + +} // end anonymous namespace + +char LoopReroll::ID = 0; + +INITIALIZE_PASS_BEGIN(LoopReroll, "loop-reroll", "Reroll loops", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(LoopReroll, "loop-reroll", "Reroll loops", false, false) + +Pass *llvm::createLoopRerollPass() { + return new LoopReroll; +} + +// Returns true if the provided instruction is used outside the given loop. +// This operates like Instruction::isUsedOutsideOfBlock, but considers PHIs in +// non-loop blocks to be outside the loop. 
+static bool hasUsesOutsideLoop(Instruction *I, Loop *L) {
+  for (User *U : I->users()) {
+    if (!L->contains(cast<Instruction>(U)))
+      return true;
+  }
+  return false;
+}
+
+static const SCEVConstant *getIncrmentFactorSCEV(ScalarEvolution *SE,
+                                                 const SCEV *SCEVExpr,
+                                                 Instruction &IV) {
+  const SCEVMulExpr *MulSCEV = dyn_cast<SCEVMulExpr>(SCEVExpr);
+
+  // If the StepRecurrence of SCEVExpr is a constant (c1 * c2, with
+  // c2 = sizeof(ptr)), return c1.
+  if (!MulSCEV && IV.getType()->isPointerTy())
+    if (const SCEVConstant *IncSCEV = dyn_cast<SCEVConstant>(SCEVExpr)) {
+      const PointerType *PTy = cast<PointerType>(IV.getType());
+      Type *ElTy = PTy->getElementType();
+      const SCEV *SizeOfExpr =
+          SE->getSizeOfExpr(SE->getEffectiveSCEVType(IV.getType()), ElTy);
+      if (IncSCEV->getValue()->getValue().isNegative()) {
+        const SCEV *NewSCEV =
+            SE->getUDivExpr(SE->getNegativeSCEV(SCEVExpr), SizeOfExpr);
+        return dyn_cast<SCEVConstant>(SE->getNegativeSCEV(NewSCEV));
+      } else {
+        return dyn_cast<SCEVConstant>(SE->getUDivExpr(SCEVExpr, SizeOfExpr));
+      }
+    }
+
+  if (!MulSCEV)
+    return nullptr;
+
+  // If the StepRecurrence of SCEVExpr is c * sizeof(x), where c is a
+  // constant, return c.
+  const SCEVConstant *CIncSCEV = nullptr;
+  for (const SCEV *Operand : MulSCEV->operands()) {
+    if (const SCEVConstant *Constant = dyn_cast<SCEVConstant>(Operand)) {
+      CIncSCEV = Constant;
+    } else if (const SCEVUnknown *Unknown = dyn_cast<SCEVUnknown>(Operand)) {
+      Type *AllocTy;
+      if (!Unknown->isSizeOf(AllocTy))
+        break;
+    } else {
+      return nullptr;
+    }
+  }
+  return CIncSCEV;
+}
+
+// Check if an IV is only used to control the loop. There are two cases:
+// 1. It has only one use, which is the loop increment; the increment is in
+//    turn used only by the comparison and the PHI (there could be a sext with
+//    nsw in between), and the comparison is used only by the branch.
+// 2. It is used by both the loop increment and the comparison; the loop
+//    increment is used only by the PHI, and the comparison is used only by
+//    the branch.
+bool LoopReroll::isLoopControlIV(Loop *L, Instruction *IV) {
+  unsigned IVUses = IV->getNumUses();
+  if (IVUses != 2 && IVUses != 1)
+    return false;
+
+  for (auto *User : IV->users()) {
+    int32_t IncOrCmpUses = User->getNumUses();
+    bool IsCompInst = isCompareUsedByBranch(cast<Instruction>(User));
+
+    // User can only have one or two uses.
+    if (IncOrCmpUses != 2 && IncOrCmpUses != 1)
+      return false;
+
+    // Case 1
+    if (IVUses == 1) {
+      // The only user must be the loop increment.
+      // The loop increment must have two uses.
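+      // (Case 1 shape, illustrated: %inc = add %iv, 1 is the sole user of
+      // %iv, and %inc in turn feeds exactly the header PHI %iv and the
+      // latch compare, which feeds the branch.)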
+      if (IsCompInst || IncOrCmpUses != 2)
+        return false;
+    }
+
+    // Case 2
+    if (IVUses == 2 && IncOrCmpUses != 1)
+      return false;
+
+    // The users of the IV must be a binary operation or a comparison.
+    if (auto *BO = dyn_cast<BinaryOperator>(User)) {
+      if (BO->getOpcode() == Instruction::Add) {
+        // Loop increment. Each user of the loop increment should be either
+        // the PHI or the CMP.
+        for (auto *UU : User->users()) {
+          if (PHINode *PN = dyn_cast<PHINode>(UU)) {
+            if (PN != IV)
+              return false;
+          }
+          // Must be a CMP, or a sext (of an nsw value) followed by a CMP.
+          else {
+            Instruction *UUser = dyn_cast<Instruction>(UU);
+            // Skip the SExt if we are extending an nsw value.
+            // TODO: Allow ZExt too.
+            if (BO->hasNoSignedWrap() && UUser && UUser->hasOneUse() &&
+                isa<SExtInst>(UUser))
+              UUser = dyn_cast<Instruction>(*(UUser->user_begin()));
+            if (!UUser || !isCompareUsedByBranch(UUser))
+              return false;
+          }
+        }
+      } else
+        return false;
+      // Compare: can only have one use, and must be a branch.
+    } else if (!IsCompInst)
+      return false;
+  }
+  return true;
+}
+
+// Collect the list of loop induction variables with respect to which it might
+// be possible to reroll the loop.
+void LoopReroll::collectPossibleIVs(Loop *L,
+                                    SmallInstructionVector &PossibleIVs) {
+  BasicBlock *Header = L->getHeader();
+  for (BasicBlock::iterator I = Header->begin(),
+       IE = Header->getFirstInsertionPt(); I != IE; ++I) {
+    if (!isa<PHINode>(I))
+      continue;
+    if (!I->getType()->isIntegerTy() && !I->getType()->isPointerTy())
+      continue;
+
+    if (const SCEVAddRecExpr *PHISCEV =
+            dyn_cast<SCEVAddRecExpr>(SE->getSCEV(&*I))) {
+      if (PHISCEV->getLoop() != L)
+        continue;
+      if (!PHISCEV->isAffine())
+        continue;
+      const SCEVConstant *IncSCEV = nullptr;
+      if (I->getType()->isPointerTy())
+        IncSCEV =
+            getIncrmentFactorSCEV(SE, PHISCEV->getStepRecurrence(*SE), *I);
+      else
+        IncSCEV = dyn_cast<SCEVConstant>(PHISCEV->getStepRecurrence(*SE));
+      if (IncSCEV) {
+        const APInt &AInt = IncSCEV->getValue()->getValue().abs();
+        if (IncSCEV->getValue()->isZero() || AInt.uge(MaxInc))
+          continue;
+        IVToIncMap[&*I] = IncSCEV->getValue()->getSExtValue();
+        DEBUG(dbgs() << "LRR: Possible IV: " << *I << " = " << *PHISCEV
+                     << "\n");
+
+        if (isLoopControlIV(L, &*I)) {
+          assert(!LoopControlIV && "Found two loop-control-only IVs");
+          LoopControlIV = &(*I);
+          DEBUG(dbgs() << "LRR: Possible loop control only IV: " << *I << " = "
+                       << *PHISCEV << "\n");
+        } else
+          PossibleIVs.push_back(&*I);
+      }
+    }
+  }
+}
+
+// Add the remainder of the reduction-variable chain to the instruction vector
+// (the initial PHINode has already been added). If successful, the object is
+// marked as valid.
+void LoopReroll::SimpleLoopReduction::add(Loop *L) {
+  assert(!Valid && "Cannot add to an already-valid chain");
+
+  // The reduction variable must be a chain of single-use instructions
+  // (including the PHI), except for the last value (which is used by the PHI
+  // and also outside the loop).
+  Instruction *C = Instructions.front();
+  if (C->user_empty())
+    return;
+
+  do {
+    C = cast<Instruction>(*C->user_begin());
+    if (C->hasOneUse()) {
+      if (!C->isBinaryOp())
+        return;
+
+      if (!(isa<PHINode>(Instructions.back()) ||
+            C->isSameOperationAs(Instructions.back())))
+        return;
+
+      Instructions.push_back(C);
+    }
+  } while (C->hasOneUse());
+
+  if (Instructions.size() < 2 ||
+      !C->isSameOperationAs(Instructions.back()) ||
+      C->use_empty())
+    return;
+
+  // C is now the (potential) last instruction in the reduction chain.
+  for (User *U : C->users()) {
+    // The only in-loop user can be the initial PHI.
+    if (L->contains(cast<Instruction>(U)))
+      if (cast<Instruction>(U) != Instructions.front())
+        return;
+  }
+
+  Instructions.push_back(C);
+  Valid = true;
+}
+
+// Collect the vector of possible reduction variables.
+void LoopReroll::collectPossibleReductions(Loop *L,
+                                           ReductionTracker &Reductions) {
+  BasicBlock *Header = L->getHeader();
+  for (BasicBlock::iterator I = Header->begin(),
+       IE = Header->getFirstInsertionPt(); I != IE; ++I) {
+    if (!isa<PHINode>(I))
+      continue;
+    if (!I->getType()->isSingleValueType())
+      continue;
+
+    SimpleLoopReduction SLR(&*I, L);
+    if (!SLR.valid())
+      continue;
+
+    DEBUG(dbgs() << "LRR: Possible reduction: " << *I << " (with " <<
+          SLR.size() << " chained instructions)\n");
+    Reductions.addSLR(SLR);
+  }
+}
+
+// Collect the set of all users of the provided root instruction. This set of
+// users contains not only the direct users of the root instruction, but also
+// all users of those users, and so on. There are two exceptions:
+//
+// 1. Instructions in the set of excluded instructions are never added to the
+//    use set (even if they are users). This is used, for example, to avoid
+//    including root increments in the use set of the primary IV.
+//
+// 2. Instructions in the set of final instructions are added to the use set
+//    if they are users, but their users are not added. This is used, for
+//    example, to prevent a reduction update from forcing all later reduction
+//    updates into the use set.
+void LoopReroll::DAGRootTracker::collectInLoopUserSet(
+    Instruction *Root, const SmallInstructionSet &Exclude,
+    const SmallInstructionSet &Final,
+    DenseSet<Instruction *> &Users) {
+  SmallInstructionVector Queue(1, Root);
+  while (!Queue.empty()) {
+    Instruction *I = Queue.pop_back_val();
+    if (!Users.insert(I).second)
+      continue;
+
+    if (!Final.count(I))
+      for (Use &U : I->uses()) {
+        Instruction *User = cast<Instruction>(U.getUser());
+        if (PHINode *PN = dyn_cast<PHINode>(User)) {
+          // Ignore "wrap-around" uses to PHIs of this loop's header.
+          if (PN->getIncomingBlock(U) == L->getHeader())
+            continue;
+        }
+
+        if (L->contains(User) && !Exclude.count(User)) {
+          Queue.push_back(User);
+        }
+      }
+
+    // We also want to collect single-user "feeder" values.
+    for (User::op_iterator OI = I->op_begin(),
+         OIE = I->op_end(); OI != OIE; ++OI) {
+      if (Instruction *Op = dyn_cast<Instruction>(*OI))
+        if (Op->hasOneUse() && L->contains(Op) && !Exclude.count(Op) &&
+            !Final.count(Op))
+          Queue.push_back(Op);
+    }
+  }
+}
+
+// Collect all of the users of all of the provided root instructions (combined
+// into a single set).
+void LoopReroll::DAGRootTracker::collectInLoopUserSet(
+    const SmallInstructionVector &Roots,
+    const SmallInstructionSet &Exclude,
+    const SmallInstructionSet &Final,
+    DenseSet<Instruction *> &Users) {
+  for (Instruction *Root : Roots)
+    collectInLoopUserSet(Root, Exclude, Final, Users);
+}
+
+static bool isUnorderedLoadStore(Instruction *I) {
+  if (LoadInst *LI = dyn_cast<LoadInst>(I))
+    return LI->isUnordered();
+  if (StoreInst *SI = dyn_cast<StoreInst>(I))
+    return SI->isUnordered();
+  if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I))
+    return !MI->isVolatile();
+  return false;
+}
+
+/// Return true if IVU is a "simple" arithmetic operation.
+/// This is used for narrowing the search space for DAGRoots; only arithmetic
+/// and GEPs can be part of a DAGRoot.
+static bool isSimpleArithmeticOp(User *IVU) { + if (Instruction *I = dyn_cast<Instruction>(IVU)) { + switch (I->getOpcode()) { + default: return false; + case Instruction::Add: + case Instruction::Sub: + case Instruction::Mul: + case Instruction::Shl: + case Instruction::AShr: + case Instruction::LShr: + case Instruction::GetElementPtr: + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + return true; + } + } + return false; +} + +static bool isLoopIncrement(User *U, Instruction *IV) { + BinaryOperator *BO = dyn_cast<BinaryOperator>(U); + + if ((BO && BO->getOpcode() != Instruction::Add) || + (!BO && !isa<GetElementPtrInst>(U))) + return false; + + for (auto *UU : U->users()) { + PHINode *PN = dyn_cast<PHINode>(UU); + if (PN && PN == IV) + return true; + } + return false; +} + +bool LoopReroll::DAGRootTracker:: +collectPossibleRoots(Instruction *Base, std::map<int64_t,Instruction*> &Roots) { + SmallInstructionVector BaseUsers; + + for (auto *I : Base->users()) { + ConstantInt *CI = nullptr; + + if (isLoopIncrement(I, IV)) { + LoopIncs.push_back(cast<Instruction>(I)); + continue; + } + + // The root nodes must be either GEPs, ORs or ADDs. + if (auto *BO = dyn_cast<BinaryOperator>(I)) { + if (BO->getOpcode() == Instruction::Add || + BO->getOpcode() == Instruction::Or) + CI = dyn_cast<ConstantInt>(BO->getOperand(1)); + } else if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) { + Value *LastOperand = GEP->getOperand(GEP->getNumOperands()-1); + CI = dyn_cast<ConstantInt>(LastOperand); + } + + if (!CI) { + if (Instruction *II = dyn_cast<Instruction>(I)) { + BaseUsers.push_back(II); + continue; + } else { + DEBUG(dbgs() << "LRR: Aborting due to non-instruction: " << *I << "\n"); + return false; + } + } + + int64_t V = std::abs(CI->getValue().getSExtValue()); + if (Roots.find(V) != Roots.end()) + // No duplicates, please. + return false; + + Roots[V] = cast<Instruction>(I); + } + + // Make sure we have at least two roots. + if (Roots.empty() || (Roots.size() == 1 && BaseUsers.empty())) + return false; + + // If we found non-loop-inc, non-root users of Base, assume they are + // for the zeroth root index. This is because "add %a, 0" gets optimized + // away. + if (BaseUsers.size()) { + if (Roots.find(0) != Roots.end()) { + DEBUG(dbgs() << "LRR: Multiple roots found for base - aborting!\n"); + return false; + } + Roots[0] = Base; + } + + // Calculate the number of users of the base, or lowest indexed, iteration. + unsigned NumBaseUses = BaseUsers.size(); + if (NumBaseUses == 0) + NumBaseUses = Roots.begin()->second->getNumUses(); + + // Check that every node has the same number of users. + for (auto &KV : Roots) { + if (KV.first == 0) + continue; + if (!KV.second->hasNUses(NumBaseUses)) { + DEBUG(dbgs() << "LRR: Aborting - Root and Base #users not the same: " + << "#Base=" << NumBaseUses << ", #Root=" << + KV.second->getNumUses() << "\n"); + return false; + } + } + + return true; +} + +void LoopReroll::DAGRootTracker:: +findRootsRecursive(Instruction *I, SmallInstructionSet SubsumedInsts) { + // Does the user look like it could be part of a root set? + // All its users must be simple arithmetic ops. + if (I->hasNUsesOrMore(IL_MaxRerollIterations + 1)) + return; + + if (I != IV && findRootsBase(I, SubsumedInsts)) + return; + + SubsumedInsts.insert(I); + + for (User *V : I->users()) { + Instruction *I = cast<Instruction>(V); + if (is_contained(LoopIncs, I)) + continue; + + if (!isSimpleArithmeticOp(I)) + continue; + + // The recursive call makes a copy of SubsumedInsts. 
+ findRootsRecursive(I, SubsumedInsts); + } +} + +bool LoopReroll::DAGRootTracker::validateRootSet(DAGRootSet &DRS) { + if (DRS.Roots.empty()) + return false; + + // Consider a DAGRootSet with N-1 roots (so N different values including + // BaseInst). + // Define d = Roots[0] - BaseInst, which should be the same as + // Roots[I] - Roots[I-1] for all I in [1..N). + // Define D = BaseInst@J - BaseInst@J-1, where "@J" means the value at the + // loop iteration J. + // + // Now, For the loop iterations to be consecutive: + // D = d * N + const auto *ADR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(DRS.BaseInst)); + if (!ADR) + return false; + unsigned N = DRS.Roots.size() + 1; + const SCEV *StepSCEV = SE->getMinusSCEV(SE->getSCEV(DRS.Roots[0]), ADR); + const SCEV *ScaleSCEV = SE->getConstant(StepSCEV->getType(), N); + if (ADR->getStepRecurrence(*SE) != SE->getMulExpr(StepSCEV, ScaleSCEV)) + return false; + + return true; +} + +bool LoopReroll::DAGRootTracker:: +findRootsBase(Instruction *IVU, SmallInstructionSet SubsumedInsts) { + // The base of a RootSet must be an AddRec, so it can be erased. + const auto *IVU_ADR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(IVU)); + if (!IVU_ADR || IVU_ADR->getLoop() != L) + return false; + + std::map<int64_t, Instruction*> V; + if (!collectPossibleRoots(IVU, V)) + return false; + + // If we didn't get a root for index zero, then IVU must be + // subsumed. + if (V.find(0) == V.end()) + SubsumedInsts.insert(IVU); + + // Partition the vector into monotonically increasing indexes. + DAGRootSet DRS; + DRS.BaseInst = nullptr; + + SmallVector<DAGRootSet, 16> PotentialRootSets; + + for (auto &KV : V) { + if (!DRS.BaseInst) { + DRS.BaseInst = KV.second; + DRS.SubsumedInsts = SubsumedInsts; + } else if (DRS.Roots.empty()) { + DRS.Roots.push_back(KV.second); + } else if (V.find(KV.first - 1) != V.end()) { + DRS.Roots.push_back(KV.second); + } else { + // Linear sequence terminated. + if (!validateRootSet(DRS)) + return false; + + // Construct a new DAGRootSet with the next sequence. + PotentialRootSets.push_back(DRS); + DRS.BaseInst = KV.second; + DRS.Roots.clear(); + } + } + + if (!validateRootSet(DRS)) + return false; + + PotentialRootSets.push_back(DRS); + + RootSets.append(PotentialRootSets.begin(), PotentialRootSets.end()); + + return true; +} + +bool LoopReroll::DAGRootTracker::findRoots() { + Inc = IVToIncMap[IV]; + + assert(RootSets.empty() && "Unclean state!"); + if (std::abs(Inc) == 1) { + for (auto *IVU : IV->users()) { + if (isLoopIncrement(IVU, IV)) + LoopIncs.push_back(cast<Instruction>(IVU)); + } + findRootsRecursive(IV, SmallInstructionSet()); + LoopIncs.push_back(IV); + } else { + if (!findRootsBase(IV, SmallInstructionSet())) + return false; + } + + // Ensure all sets have the same size. + if (RootSets.empty()) { + DEBUG(dbgs() << "LRR: Aborting because no root sets found!\n"); + return false; + } + for (auto &V : RootSets) { + if (V.Roots.empty() || V.Roots.size() != RootSets[0].Roots.size()) { + DEBUG(dbgs() + << "LRR: Aborting because not all root sets have the same size\n"); + return false; + } + } + + Scale = RootSets[0].Roots.size() + 1; + + if (Scale > IL_MaxRerollIterations) { + DEBUG(dbgs() << "LRR: Aborting - too many iterations found. 
" + << "#Found=" << Scale << ", #Max=" << IL_MaxRerollIterations + << "\n"); + return false; + } + + DEBUG(dbgs() << "LRR: Successfully found roots: Scale=" << Scale << "\n"); + + return true; +} + +bool LoopReroll::DAGRootTracker::collectUsedInstructions(SmallInstructionSet &PossibleRedSet) { + // Populate the MapVector with all instructions in the block, in order first, + // so we can iterate over the contents later in perfect order. + for (auto &I : *L->getHeader()) { + Uses[&I].resize(IL_End); + } + + SmallInstructionSet Exclude; + for (auto &DRS : RootSets) { + Exclude.insert(DRS.Roots.begin(), DRS.Roots.end()); + Exclude.insert(DRS.SubsumedInsts.begin(), DRS.SubsumedInsts.end()); + Exclude.insert(DRS.BaseInst); + } + Exclude.insert(LoopIncs.begin(), LoopIncs.end()); + + for (auto &DRS : RootSets) { + DenseSet<Instruction*> VBase; + collectInLoopUserSet(DRS.BaseInst, Exclude, PossibleRedSet, VBase); + for (auto *I : VBase) { + Uses[I].set(0); + } + + unsigned Idx = 1; + for (auto *Root : DRS.Roots) { + DenseSet<Instruction*> V; + collectInLoopUserSet(Root, Exclude, PossibleRedSet, V); + + // While we're here, check the use sets are the same size. + if (V.size() != VBase.size()) { + DEBUG(dbgs() << "LRR: Aborting - use sets are different sizes\n"); + return false; + } + + for (auto *I : V) { + Uses[I].set(Idx); + } + ++Idx; + } + + // Make sure our subsumed instructions are remembered too. + for (auto *I : DRS.SubsumedInsts) { + Uses[I].set(IL_All); + } + } + + // Make sure the loop increments are also accounted for. + + Exclude.clear(); + for (auto &DRS : RootSets) { + Exclude.insert(DRS.Roots.begin(), DRS.Roots.end()); + Exclude.insert(DRS.SubsumedInsts.begin(), DRS.SubsumedInsts.end()); + Exclude.insert(DRS.BaseInst); + } + + DenseSet<Instruction*> V; + collectInLoopUserSet(LoopIncs, Exclude, PossibleRedSet, V); + for (auto *I : V) { + Uses[I].set(IL_All); + } + + return true; +} + +/// Get the next instruction in "In" that is a member of set Val. +/// Start searching from StartI, and do not return anything in Exclude. +/// If StartI is not given, start from In.begin(). +LoopReroll::DAGRootTracker::UsesTy::iterator +LoopReroll::DAGRootTracker::nextInstr(int Val, UsesTy &In, + const SmallInstructionSet &Exclude, + UsesTy::iterator *StartI) { + UsesTy::iterator I = StartI ? *StartI : In.begin(); + while (I != In.end() && (I->second.test(Val) == 0 || + Exclude.count(I->first) != 0)) + ++I; + return I; +} + +bool LoopReroll::DAGRootTracker::isBaseInst(Instruction *I) { + for (auto &DRS : RootSets) { + if (DRS.BaseInst == I) + return true; + } + return false; +} + +bool LoopReroll::DAGRootTracker::isRootInst(Instruction *I) { + for (auto &DRS : RootSets) { + if (is_contained(DRS.Roots, I)) + return true; + } + return false; +} + +/// Return true if instruction I depends on any instruction between +/// Start and End. 
+bool LoopReroll::DAGRootTracker::instrDependsOn(Instruction *I, + UsesTy::iterator Start, + UsesTy::iterator End) { + for (auto *U : I->users()) { + for (auto It = Start; It != End; ++It) + if (U == It->first) + return true; + } + return false; +} + +static bool isIgnorableInst(const Instruction *I) { + if (isa<DbgInfoIntrinsic>(I)) + return true; + const IntrinsicInst* II = dyn_cast<IntrinsicInst>(I); + if (!II) + return false; + switch (II->getIntrinsicID()) { + default: + return false; + case Intrinsic::annotation: + case Intrinsic::ptr_annotation: + case Intrinsic::var_annotation: + // TODO: the following intrinsics may also be whitelisted: + // lifetime_start, lifetime_end, invariant_start, invariant_end + return true; + } + return false; +} + +bool LoopReroll::DAGRootTracker::validate(ReductionTracker &Reductions) { + // We now need to check for equivalence of the use graph of each root with + // that of the primary induction variable (excluding the roots). Our goal + // here is not to solve the full graph isomorphism problem, but rather to + // catch common cases without a lot of work. As a result, we will assume + // that the relative order of the instructions in each unrolled iteration + // is the same (although we will not make an assumption about how the + // different iterations are intermixed). Note that while the order must be + // the same, the instructions may not be in the same basic block. + + // An array of just the possible reductions for this scale factor. When we + // collect the set of all users of some root instructions, these reduction + // instructions are treated as 'final' (their uses are not considered). + // This is important because we don't want the root use set to search down + // the reduction chain. + SmallInstructionSet PossibleRedSet; + SmallInstructionSet PossibleRedLastSet; + SmallInstructionSet PossibleRedPHISet; + Reductions.restrictToScale(Scale, PossibleRedSet, + PossibleRedPHISet, PossibleRedLastSet); + + // Populate "Uses" with where each instruction is used. + if (!collectUsedInstructions(PossibleRedSet)) + return false; + + // Make sure we mark the reduction PHIs as used in all iterations. + for (auto *I : PossibleRedPHISet) { + Uses[I].set(IL_All); + } + + // Make sure we mark loop-control-only PHIs as used in all iterations. See + // comment above LoopReroll::isLoopControlIV for more information. + BasicBlock *Header = L->getHeader(); + if (LoopControlIV && LoopControlIV != IV) { + for (auto *U : LoopControlIV->users()) { + Instruction *IVUser = dyn_cast<Instruction>(U); + // IVUser could be loop increment or compare + Uses[IVUser].set(IL_All); + for (auto *UU : IVUser->users()) { + Instruction *UUser = dyn_cast<Instruction>(UU); + // UUser could be compare, PHI or branch + Uses[UUser].set(IL_All); + // Skip SExt + if (isa<SExtInst>(UUser)) { + UUser = dyn_cast<Instruction>(*(UUser->user_begin())); + Uses[UUser].set(IL_All); + } + // Is UUser a compare instruction? + if (UU->hasOneUse()) { + Instruction *BI = dyn_cast<BranchInst>(*UUser->user_begin()); + if (BI == cast<BranchInst>(Header->getTerminator())) + Uses[BI].set(IL_All); + } + } + } + } + + // Make sure all instructions in the loop are in one and only one + // set. 
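+  // (Equivalently: each instruction's use bitvector must have exactly one
+  // bit set, either a single iteration index in [0, Scale) or the shared
+  // IL_All slot; ignorable intrinsics are exempted just below.)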
+ for (auto &KV : Uses) { + if (KV.second.count() != 1 && !isIgnorableInst(KV.first)) { + DEBUG(dbgs() << "LRR: Aborting - instruction is not used in 1 iteration: " + << *KV.first << " (#uses=" << KV.second.count() << ")\n"); + return false; + } + } + + DEBUG( + for (auto &KV : Uses) { + dbgs() << "LRR: " << KV.second.find_first() << "\t" << *KV.first << "\n"; + } + ); + + for (unsigned Iter = 1; Iter < Scale; ++Iter) { + // In addition to regular aliasing information, we need to look for + // instructions from later (future) iterations that have side effects + // preventing us from reordering them past other instructions with side + // effects. + bool FutureSideEffects = false; + AliasSetTracker AST(*AA); + // The map between instructions in f(%iv.(i+1)) and f(%iv). + DenseMap<Value *, Value *> BaseMap; + + // Compare iteration Iter to the base. + SmallInstructionSet Visited; + auto BaseIt = nextInstr(0, Uses, Visited); + auto RootIt = nextInstr(Iter, Uses, Visited); + auto LastRootIt = Uses.begin(); + + while (BaseIt != Uses.end() && RootIt != Uses.end()) { + Instruction *BaseInst = BaseIt->first; + Instruction *RootInst = RootIt->first; + + // Skip over the IV or root instructions; only match their users. + bool Continue = false; + if (isBaseInst(BaseInst)) { + Visited.insert(BaseInst); + BaseIt = nextInstr(0, Uses, Visited); + Continue = true; + } + if (isRootInst(RootInst)) { + LastRootIt = RootIt; + Visited.insert(RootInst); + RootIt = nextInstr(Iter, Uses, Visited); + Continue = true; + } + if (Continue) continue; + + if (!BaseInst->isSameOperationAs(RootInst)) { + // Last chance saloon. We don't try and solve the full isomorphism + // problem, but try and at least catch the case where two instructions + // *of different types* are round the wrong way. We won't be able to + // efficiently tell, given two ADD instructions, which way around we + // should match them, but given an ADD and a SUB, we can at least infer + // which one is which. + // + // This should allow us to deal with a greater subset of the isomorphism + // problem. It does however change a linear algorithm into a quadratic + // one, so limit the number of probes we do. + auto TryIt = RootIt; + unsigned N = NumToleratedFailedMatches; + while (TryIt != Uses.end() && + !BaseInst->isSameOperationAs(TryIt->first) && + N--) { + ++TryIt; + TryIt = nextInstr(Iter, Uses, Visited, &TryIt); + } + + if (TryIt == Uses.end() || TryIt == RootIt || + instrDependsOn(TryIt->first, RootIt, TryIt)) { + DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst << + " vs. " << *RootInst << "\n"); + return false; + } + + RootIt = TryIt; + RootInst = TryIt->first; + } + + // All instructions between the last root and this root + // may belong to some other iteration. If they belong to a + // future iteration, then they're dangerous to alias with. + // + // Note that because we allow a limited amount of flexibility in the order + // that we visit nodes, LastRootIt might be *before* RootIt, in which + // case we've already checked this set of instructions so we shouldn't + // do anything. + for (; LastRootIt < RootIt; ++LastRootIt) { + Instruction *I = LastRootIt->first; + if (LastRootIt->second.find_first() < (int)Iter) + continue; + if (I->mayWriteToMemory()) + AST.add(I); + // Note: This is specifically guarded by a check on isa<PHINode>, + // which while a valid (somewhat arbitrary) micro-optimization, is + // needed because otherwise isSafeToSpeculativelyExecute returns + // false on PHI nodes. 
+        if (!isa<PHINode>(I) && !isUnorderedLoadStore(I) &&
+            !isSafeToSpeculativelyExecute(I))
+          // Intervening instructions cause side effects.
+          FutureSideEffects = true;
+      }
+
+      // Make sure that this instruction, which is in the use set of this
+      // root instruction, does not also belong to the base set or the set of
+      // some other root instruction.
+      if (RootIt->second.count() > 1) {
+        DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst <<
+                        " vs. " << *RootInst << " (prev. case overlap)\n");
+        return false;
+      }
+
+      // Make sure that we don't alias with any instruction in the alias set
+      // tracker. If we do, then we depend on a future iteration, and we
+      // can't reroll.
+      if (RootInst->mayReadFromMemory())
+        for (auto &K : AST) {
+          if (K.aliasesUnknownInst(RootInst, *AA)) {
+            DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst <<
+                            " vs. " << *RootInst << " (depends on future store)\n");
+            return false;
+          }
+        }
+
+      // If we've passed an instruction from a future iteration that may have
+      // side effects, and this instruction might also, then we can't reorder
+      // them, and this matching fails. As an exception, we allow the alias
+      // set tracker to handle regular (unordered) load/store dependencies.
+      if (FutureSideEffects && ((!isUnorderedLoadStore(BaseInst) &&
+                                 !isSafeToSpeculativelyExecute(BaseInst)) ||
+                                (!isUnorderedLoadStore(RootInst) &&
+                                 !isSafeToSpeculativelyExecute(RootInst)))) {
+        DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst <<
+                        " vs. " << *RootInst <<
+                        " (side effects prevent reordering)\n");
+        return false;
+      }
+
+      // For instructions that are part of a reduction, if the operation is
+      // associative, then don't bother matching the operands (because we
+      // already know that the instructions are isomorphic, and the order
+      // within the iteration does not matter). For non-associative reductions,
+      // we do need to match the operands, because we need to reject
+      // out-of-order instructions within an iteration!
+      // For example (assume floating-point addition), we need to reject this:
+      //   x += a[i]; x += b[i];
+      //   x += a[i+1]; x += b[i+1];
+      //   x += b[i+2]; x += a[i+2];
+      bool InReduction = Reductions.isPairInSame(BaseInst, RootInst);
+
+      if (!(InReduction && BaseInst->isAssociative())) {
+        bool Swapped = false, SomeOpMatched = false;
+        for (unsigned j = 0; j < BaseInst->getNumOperands(); ++j) {
+          Value *Op2 = RootInst->getOperand(j);
+
+          // If this is part of a reduction (and the operation is not
+          // associative), then we match all operands, but not those that are
+          // part of the reduction.
+          if (InReduction)
+            if (Instruction *Op2I = dyn_cast<Instruction>(Op2))
+              if (Reductions.isPairInSame(RootInst, Op2I))
+                continue;
+
+          DenseMap<Value *, Value *>::iterator BMI = BaseMap.find(Op2);
+          if (BMI != BaseMap.end()) {
+            Op2 = BMI->second;
+          } else {
+            for (auto &DRS : RootSets) {
+              if (DRS.Roots[Iter-1] == (Instruction*) Op2) {
+                Op2 = DRS.BaseInst;
+                break;
+              }
+            }
+          }
+
+          if (BaseInst->getOperand(Swapped ? unsigned(!j) : j) != Op2) {
+            // If we've not already decided to swap the matched operands, and
+            // we've not already matched our first operand (note that we could
+            // have skipped matching the first operand because it is part of a
+            // reduction above), and the instruction is commutative, then try
+            // the swapped match.
+            if (!Swapped && BaseInst->isCommutative() && !SomeOpMatched &&
+                BaseInst->getOperand(!j) == Op2) {
+              Swapped = true;
+            } else {
+              DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst
+                           << " vs. " << *RootInst << " (operand " << j << ")\n");
+              return false;
+            }
+          }
+
+          SomeOpMatched = true;
+        }
+      }
+
+      if ((!PossibleRedLastSet.count(BaseInst) &&
+           hasUsesOutsideLoop(BaseInst, L)) ||
+          (!PossibleRedLastSet.count(RootInst) &&
+           hasUsesOutsideLoop(RootInst, L))) {
+        DEBUG(dbgs() << "LRR: iteration root match failed at " << *BaseInst <<
+                        " vs. " << *RootInst << " (uses outside loop)\n");
+        return false;
+      }
+
+      Reductions.recordPair(BaseInst, RootInst, Iter);
+      BaseMap.insert(std::make_pair(RootInst, BaseInst));
+
+      LastRootIt = RootIt;
+      Visited.insert(BaseInst);
+      Visited.insert(RootInst);
+      BaseIt = nextInstr(0, Uses, Visited);
+      RootIt = nextInstr(Iter, Uses, Visited);
+    }
+    assert(BaseIt == Uses.end() && RootIt == Uses.end() &&
+           "Mismatched set sizes!");
+  }
+
+  DEBUG(dbgs() << "LRR: Matched all iteration increments for " <<
+        *IV << "\n");
+
+  return true;
+}
+
+void LoopReroll::DAGRootTracker::replace(const SCEV *IterCount) {
+  BasicBlock *Header = L->getHeader();
+  // Remove instructions associated with non-base iterations.
+  for (BasicBlock::reverse_iterator J = Header->rbegin(), JE = Header->rend();
+       J != JE;) {
+    unsigned I = Uses[&*J].find_first();
+    if (I > 0 && I < IL_All) {
+      DEBUG(dbgs() << "LRR: removing: " << *J << "\n");
+      J++->eraseFromParent();
+      continue;
+    }
+
+    ++J;
+  }
+
+  bool HasTwoIVs = LoopControlIV && LoopControlIV != IV;
+
+  if (HasTwoIVs) {
+    updateNonLoopCtrlIncr();
+    replaceIV(LoopControlIV, LoopControlIV, IterCount);
+  } else
+    // We need to create a new induction variable for each different BaseInst.
+    for (auto &DRS : RootSets)
+      // Insert the new induction variable.
+      replaceIV(DRS.BaseInst, IV, IterCount);
+
+  SimplifyInstructionsInBlock(Header, TLI);
+  DeleteDeadPHIs(Header, TLI);
+}
+
+// For non-loop-control IVs, we only need to update the last increment
+// with the right amount, then we are done.
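+// For example, with Scale == 3 an increment of 12 becomes 12 / 3 == 4 after
+// rerolling; negative increments are negated before and after the unsigned
+// division below, so the sign survives.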
+void LoopReroll::DAGRootTracker::updateNonLoopCtrlIncr() { + const SCEV *NewInc = nullptr; + for (auto *LoopInc : LoopIncs) { + GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(LoopInc); + const SCEVConstant *COp = nullptr; + if (GEP && LoopInc->getOperand(0)->getType()->isPointerTy()) { + COp = dyn_cast<SCEVConstant>(SE->getSCEV(LoopInc->getOperand(1))); + } else { + COp = dyn_cast<SCEVConstant>(SE->getSCEV(LoopInc->getOperand(0))); + if (!COp) + COp = dyn_cast<SCEVConstant>(SE->getSCEV(LoopInc->getOperand(1))); + } + + assert(COp && "Didn't find constant operand of LoopInc!\n"); + + const APInt &AInt = COp->getValue()->getValue(); + const SCEV *ScaleSCEV = SE->getConstant(COp->getType(), Scale); + if (AInt.isNegative()) { + NewInc = SE->getNegativeSCEV(COp); + NewInc = SE->getUDivExpr(NewInc, ScaleSCEV); + NewInc = SE->getNegativeSCEV(NewInc); + } else + NewInc = SE->getUDivExpr(COp, ScaleSCEV); + + LoopInc->setOperand(1, dyn_cast<SCEVConstant>(NewInc)->getValue()); + } +} + +void LoopReroll::DAGRootTracker::replaceIV(Instruction *Inst, + Instruction *InstIV, + const SCEV *IterCount) { + BasicBlock *Header = L->getHeader(); + int64_t Inc = IVToIncMap[InstIV]; + bool NeedNewIV = InstIV == LoopControlIV; + bool Negative = !NeedNewIV && Inc < 0; + + const SCEVAddRecExpr *RealIVSCEV = cast<SCEVAddRecExpr>(SE->getSCEV(Inst)); + const SCEV *Start = RealIVSCEV->getStart(); + + if (NeedNewIV) + Start = SE->getConstant(Start->getType(), 0); + + const SCEV *SizeOfExpr = nullptr; + const SCEV *IncrExpr = + SE->getConstant(RealIVSCEV->getType(), Negative ? -1 : 1); + if (auto *PTy = dyn_cast<PointerType>(Inst->getType())) { + Type *ElTy = PTy->getElementType(); + SizeOfExpr = + SE->getSizeOfExpr(SE->getEffectiveSCEVType(Inst->getType()), ElTy); + IncrExpr = SE->getMulExpr(IncrExpr, SizeOfExpr); + } + const SCEV *NewIVSCEV = + SE->getAddRecExpr(Start, IncrExpr, L, SCEV::FlagAnyWrap); + + { // Limit the lifetime of SCEVExpander. + const DataLayout &DL = Header->getModule()->getDataLayout(); + SCEVExpander Expander(*SE, DL, "reroll"); + Value *NewIV = Expander.expandCodeFor(NewIVSCEV, Inst->getType(), + Header->getFirstNonPHIOrDbg()); + + for (auto &KV : Uses) + if (KV.second.find_first() == 0) + KV.first->replaceUsesOfWith(Inst, NewIV); + + if (BranchInst *BI = dyn_cast<BranchInst>(Header->getTerminator())) { + // FIXME: Why do we need this check? + if (Uses[BI].find_first() == IL_All) { + const SCEV *ICSCEV = RealIVSCEV->evaluateAtIteration(IterCount, *SE); + + if (NeedNewIV) + ICSCEV = SE->getMulExpr(IterCount, + SE->getConstant(IterCount->getType(), Scale)); + + // Iteration count SCEV minus or plus 1 + const SCEV *MinusPlus1SCEV = + SE->getConstant(ICSCEV->getType(), Negative ? 
-1 : 1); + if (Inst->getType()->isPointerTy()) { + assert(SizeOfExpr && "SizeOfExpr is not initialized"); + MinusPlus1SCEV = SE->getMulExpr(MinusPlus1SCEV, SizeOfExpr); + } + + const SCEV *ICMinusPlus1SCEV = SE->getMinusSCEV(ICSCEV, MinusPlus1SCEV); + // Iteration count minus 1 + Instruction *InsertPtr = nullptr; + if (isa<SCEVConstant>(ICMinusPlus1SCEV)) { + InsertPtr = BI; + } else { + BasicBlock *Preheader = L->getLoopPreheader(); + if (!Preheader) + Preheader = InsertPreheaderForLoop(L, DT, LI, PreserveLCSSA); + InsertPtr = Preheader->getTerminator(); + } + + if (!isa<PointerType>(NewIV->getType()) && NeedNewIV && + (SE->getTypeSizeInBits(NewIV->getType()) < + SE->getTypeSizeInBits(ICMinusPlus1SCEV->getType()))) { + IRBuilder<> Builder(BI); + Builder.SetCurrentDebugLocation(BI->getDebugLoc()); + NewIV = Builder.CreateSExt(NewIV, ICMinusPlus1SCEV->getType()); + } + Value *ICMinusPlus1 = Expander.expandCodeFor( + ICMinusPlus1SCEV, NewIV->getType(), InsertPtr); + + Value *Cond = + new ICmpInst(BI, CmpInst::ICMP_EQ, NewIV, ICMinusPlus1, "exitcond"); + BI->setCondition(Cond); + + if (BI->getSuccessor(1) != Header) + BI->swapSuccessors(); + } + } + } +} + +// Validate the selected reductions. All iterations must have an isomorphic +// part of the reduction chain and, for non-associative reductions, the chain +// entries must appear in order. +bool LoopReroll::ReductionTracker::validateSelected() { + // For a non-associative reduction, the chain entries must appear in order. + for (int i : Reds) { + int PrevIter = 0, BaseCount = 0, Count = 0; + for (Instruction *J : PossibleReds[i]) { + // Note that all instructions in the chain must have been found because + // all instructions in the function must have been assigned to some + // iteration. + int Iter = PossibleRedIter[J]; + if (Iter != PrevIter && Iter != PrevIter + 1 && + !PossibleReds[i].getReducedValue()->isAssociative()) { + DEBUG(dbgs() << "LRR: Out-of-order non-associative reduction: " << + J << "\n"); + return false; + } + + if (Iter != PrevIter) { + if (Count != BaseCount) { + DEBUG(dbgs() << "LRR: Iteration " << PrevIter << + " reduction use count " << Count << + " is not equal to the base use count " << + BaseCount << "\n"); + return false; + } + + Count = 0; + } + + ++Count; + if (Iter == 0) + ++BaseCount; + + PrevIter = Iter; + } + } + + return true; +} + +// For all selected reductions, remove all parts except those in the first +// iteration (and the PHI). Replace outside uses of the reduced value with uses +// of the first-iteration reduced value (in other words, reroll the selected +// reductions). +void LoopReroll::ReductionTracker::replaceSelected() { + // Fixup reductions to refer to the last instruction associated with the + // first iteration (not the last). + for (int i : Reds) { + int j = 0; + for (int e = PossibleReds[i].size(); j != e; ++j) + if (PossibleRedIter[PossibleReds[i][j]] != 0) { + --j; + break; + } + + // Replace users with the new end-of-chain value. + SmallInstructionVector Users; + for (User *U : PossibleReds[i].getReducedValue()->users()) { + Users.push_back(cast<Instruction>(U)); + } + + for (Instruction *User : Users) + User->replaceUsesOfWith(PossibleReds[i].getReducedValue(), + PossibleReds[i][j]); + } +} + +// Reroll the provided loop with respect to the provided induction variable. 
+// Generally, we're looking for a loop like this:
+//
+// %iv = phi [ (preheader, ...), (body, %iv.next) ]
+// f(%iv)
+// %iv.1 = add %iv, 1        <-- a root increment
+// f(%iv.1)
+// %iv.2 = add %iv, 2        <-- a root increment
+// f(%iv.2)
+// %iv.scale_m_1 = add %iv, scale-1  <-- a root increment
+// f(%iv.scale_m_1)
+// ...
+// %iv.next = add %iv, scale
+// %cmp = icmp(%iv, ...)
+// br %cmp, header, exit
+//
+// Notably, we do not require that f(%iv), f(%iv.1), etc. be isolated groups of
+// instructions. In other words, the instructions in f(%iv), f(%iv.1), etc. can
+// be intermixed with each other. The restriction imposed by this algorithm is
+// that the relative order of the isomorphic instructions in f(%iv), f(%iv.1),
+// etc. be the same.
+//
+// First, we collect the use set of %iv, excluding the other increment roots.
+// This gives us f(%iv). Then we iterate over the loop instructions (scale-1)
+// times, having collected the use set of f(%iv.(i+1)), during which we:
+//   - Ensure that the next unmatched instruction in f(%iv) is isomorphic to
+//     the next unmatched instruction in f(%iv.(i+1)).
+//   - Ensure that both matched instructions don't have any external users
+//     (with the exception of last-in-chain reduction instructions).
+//   - Track the (aliasing) write set, and other side effects, of all
+//     instructions that belong to future iterations that come before the
+//     matched instructions. If the matched instructions read from that write
+//     set, then f(%iv) or f(%iv.(i+1)) has some dependency on instructions in
+//     f(%iv.(j+1)) for some j > i, and we cannot reroll the loop. Similarly,
+//     if any of these future instructions had side effects (could not be
+//     speculatively executed), and so do the matched instructions, then we
+//     cannot reorder those side-effect-producing instructions, and rerolling
+//     fails.
+//
+// Finally, we make sure that all loop instructions are either loop increment
+// roots, simple latch code, parts of validated reductions, or part of f(%iv)
+// or of some f(%iv.i). If all of that is true (and all reductions have been
+// validated), then we reroll the loop.
+bool LoopReroll::reroll(Instruction *IV, Loop *L, BasicBlock *Header,
+                        const SCEV *IterCount,
+                        ReductionTracker &Reductions) {
+  DAGRootTracker DAGRoots(this, L, IV, SE, AA, TLI, DT, LI, PreserveLCSSA,
+                          IVToIncMap, LoopControlIV);
+
+  if (!DAGRoots.findRoots())
+    return false;
+  DEBUG(dbgs() << "LRR: Found all root induction increments for: " <<
+                  *IV << "\n");
+
+  if (!DAGRoots.validate(Reductions))
+    return false;
+  if (!Reductions.validateSelected())
+    return false;
+  // At this point, we've validated the rerolling, and we're committed to
+  // making changes!
+
+  Reductions.replaceSelected();
+  DAGRoots.replace(IterCount);
+
+  ++NumRerolledLoops;
+  return true;
+}
+
+bool LoopReroll::runOnLoop(Loop *L, LPPassManager &LPM) {
+  if (skipLoop(L))
+    return false;
+
+  AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
+  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+  TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  PreserveLCSSA = mustPreserveAnalysisID(LCSSAID);
+
+  BasicBlock *Header = L->getHeader();
+  DEBUG(dbgs() << "LRR: F[" << Header->getParent()->getName() <<
+        "] Loop %" << Header->getName() << " (" <<
+        L->getNumBlocks() << " block(s))\n");
+
+  // For now, we'll handle only single BB loops.
+ if (L->getNumBlocks() > 1) + return false; + + if (!SE->hasLoopInvariantBackedgeTakenCount(L)) + return false; + + const SCEV *LIBETC = SE->getBackedgeTakenCount(L); + const SCEV *IterCount = SE->getAddExpr(LIBETC, SE->getOne(LIBETC->getType())); + DEBUG(dbgs() << "\n Before Reroll:\n" << *(L->getHeader()) << "\n"); + DEBUG(dbgs() << "LRR: iteration count = " << *IterCount << "\n"); + + // First, we need to find the induction variable with respect to which we can + // reroll (there may be several possible options). + SmallInstructionVector PossibleIVs; + IVToIncMap.clear(); + LoopControlIV = nullptr; + collectPossibleIVs(L, PossibleIVs); + + if (PossibleIVs.empty()) { + DEBUG(dbgs() << "LRR: No possible IVs found\n"); + return false; + } + + ReductionTracker Reductions; + collectPossibleReductions(L, Reductions); + bool Changed = false; + + // For each possible IV, collect the associated possible set of 'root' nodes + // (i+1, i+2, etc.). + for (Instruction *PossibleIV : PossibleIVs) + if (reroll(PossibleIV, L, Header, IterCount, Reductions)) { + Changed = true; + break; + } + DEBUG(dbgs() << "\n After Reroll:\n" << *(L->getHeader()) << "\n"); + + // Trip count of L has changed so SE must be re-evaluated. + if (Changed) + SE->forgetLoop(L); + + return Changed; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopRotation.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopRotation.cpp new file mode 100644 index 000000000000..a91f53ba663f --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopRotation.cpp @@ -0,0 +1,711 @@ +//===- LoopRotation.cpp - Loop Rotation Pass ------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements Loop Rotation Pass. 
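+//
+// Conceptually, rotation turns a loop that tests its exit condition in the
+// header:
+//
+//   header:
+//     %cond = ...
+//     br %cond, body, exit
+//   body:
+//     br header
+//
+// into one that tests it in the latch, by cloning the header's instructions
+// into the preheader (a rough sketch; real IR also needs PHI updates):
+//
+//   entry:
+//     %cond.1 = ...              ; cloned header test
+//     br %cond.1, body, exit
+//   body:                        ; new header
+//     %cond = ...
+//     br %cond, body, exit       ; loop now exits from the latch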
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LoopRotation.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/CodeMetrics.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/DebugInfoMetadata.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/LoopPassManager.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/SSAUpdater.h"
+#include "llvm/Transforms/Utils/ValueMapper.h"
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-rotate"
+
+static cl::opt<unsigned> DefaultRotationThreshold(
+    "rotation-max-header-size", cl::init(16), cl::Hidden,
+    cl::desc("The default maximum header size for automatic loop rotation"));
+
+STATISTIC(NumRotated, "Number of loops rotated");
+
+namespace {
+/// A simple loop rotation transformation.
+class LoopRotate {
+  const unsigned MaxHeaderSize;
+  LoopInfo *LI;
+  const TargetTransformInfo *TTI;
+  AssumptionCache *AC;
+  DominatorTree *DT;
+  ScalarEvolution *SE;
+  const SimplifyQuery &SQ;
+
+public:
+  LoopRotate(unsigned MaxHeaderSize, LoopInfo *LI,
+             const TargetTransformInfo *TTI, AssumptionCache *AC,
+             DominatorTree *DT, ScalarEvolution *SE, const SimplifyQuery &SQ)
+      : MaxHeaderSize(MaxHeaderSize), LI(LI), TTI(TTI), AC(AC), DT(DT), SE(SE),
+        SQ(SQ) {}
+  bool processLoop(Loop *L);
+
+private:
+  bool rotateLoop(Loop *L, bool SimplifiedLatch);
+  bool simplifyLoopLatch(Loop *L);
+};
+} // end anonymous namespace
+
+/// RewriteUsesOfClonedInstructions - We just cloned the instructions from the
+/// old header into the preheader. If there were uses of the values produced by
+/// these instructions that were outside of the loop, we have to insert PHI
+/// nodes to merge the two values. Do this now.
+static void RewriteUsesOfClonedInstructions(BasicBlock *OrigHeader,
+                                            BasicBlock *OrigPreheader,
+                                            ValueToValueMapTy &ValueMap,
+                                SmallVectorImpl<PHINode*> *InsertedPHIs) {
+  // Remove PHI node entries that are no longer live.
+  BasicBlock::iterator I, E = OrigHeader->end();
+  for (I = OrigHeader->begin(); PHINode *PN = dyn_cast<PHINode>(I); ++I)
+    PN->removeIncomingValue(PN->getBasicBlockIndex(OrigPreheader));
+
+  // Now fix up users of the instructions in OrigHeader, inserting PHI nodes
+  // as necessary.
+  SSAUpdater SSA(InsertedPHIs);
+  for (I = OrigHeader->begin(); I != E; ++I) {
+    Value *OrigHeaderVal = &*I;
+
+    // If there are no uses of the value (e.g. because it returns void), there
+    // is nothing to rewrite.
+    if (OrigHeaderVal->use_empty())
+      continue;
+
+    Value *OrigPreHeaderVal = ValueMap.lookup(OrigHeaderVal);
+
+    // The value now exists in two versions: the initial value in the preheader
+    // and the loop "next" value in the original header.
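+    // (E.g. if the header computed %x = add %n, 1, there is now a preheader
+    // clone of %x as well as the original %x inside the loop; uses outside
+    // the loop are rewritten below, inserting PHIs to merge the two.)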
+ SSA.Initialize(OrigHeaderVal->getType(), OrigHeaderVal->getName()); + SSA.AddAvailableValue(OrigHeader, OrigHeaderVal); + SSA.AddAvailableValue(OrigPreheader, OrigPreHeaderVal); + + // Visit each use of the OrigHeader instruction. + for (Value::use_iterator UI = OrigHeaderVal->use_begin(), + UE = OrigHeaderVal->use_end(); + UI != UE;) { + // Grab the use before incrementing the iterator. + Use &U = *UI; + + // Increment the iterator before removing the use from the list. + ++UI; + + // SSAUpdater can't handle a non-PHI use in the same block as an + // earlier def. We can easily handle those cases manually. + Instruction *UserInst = cast<Instruction>(U.getUser()); + if (!isa<PHINode>(UserInst)) { + BasicBlock *UserBB = UserInst->getParent(); + + // The original users in the OrigHeader are already using the + // original definitions. + if (UserBB == OrigHeader) + continue; + + // Users in the OrigPreHeader need to use the value to which the + // original definitions are mapped. + if (UserBB == OrigPreheader) { + U = OrigPreHeaderVal; + continue; + } + } + + // Anything else can be handled by SSAUpdater. + SSA.RewriteUse(U); + } + + // Replace MetadataAsValue(ValueAsMetadata(OrigHeaderVal)) uses in debug + // intrinsics. + SmallVector<DbgValueInst *, 1> DbgValues; + llvm::findDbgValues(DbgValues, OrigHeaderVal); + for (auto &DbgValue : DbgValues) { + // The original users in the OrigHeader are already using the original + // definitions. + BasicBlock *UserBB = DbgValue->getParent(); + if (UserBB == OrigHeader) + continue; + + // Users in the OrigPreHeader need to use the value to which the + // original definitions are mapped and anything else can be handled by + // the SSAUpdater. To avoid adding PHINodes, check if the value is + // available in UserBB, if not substitute undef. + Value *NewVal; + if (UserBB == OrigPreheader) + NewVal = OrigPreHeaderVal; + else if (SSA.HasValueForBlock(UserBB)) + NewVal = SSA.GetValueInMiddleOfBlock(UserBB); + else + NewVal = UndefValue::get(OrigHeaderVal->getType()); + DbgValue->setOperand(0, + MetadataAsValue::get(OrigHeaderVal->getContext(), + ValueAsMetadata::get(NewVal))); + } + } +} + +/// Propagate dbg.value intrinsics through the newly inserted Phis. +static void insertDebugValues(BasicBlock *OrigHeader, + SmallVectorImpl<PHINode*> &InsertedPHIs) { + ValueToValueMapTy DbgValueMap; + + // Map existing PHI nodes to their dbg.values. + for (auto &I : *OrigHeader) { + if (auto DbgII = dyn_cast<DbgInfoIntrinsic>(&I)) { + if (auto *Loc = dyn_cast_or_null<PHINode>(DbgII->getVariableLocation())) + DbgValueMap.insert({Loc, DbgII}); + } + } + + // Then iterate through the new PHIs and look to see if they use one of the + // previously mapped PHIs. If so, insert a new dbg.value intrinsic that will + // propagate the info through the new PHI. + LLVMContext &C = OrigHeader->getContext(); + for (auto PHI : InsertedPHIs) { + for (auto VI : PHI->operand_values()) { + auto V = DbgValueMap.find(VI); + if (V != DbgValueMap.end()) { + auto *DbgII = cast<DbgInfoIntrinsic>(V->second); + Instruction *NewDbgII = DbgII->clone(); + auto PhiMAV = MetadataAsValue::get(C, ValueAsMetadata::get(PHI)); + NewDbgII->setOperand(0, PhiMAV); + BasicBlock *Parent = PHI->getParent(); + NewDbgII->insertBefore(Parent->getFirstNonPHIOrDbgOrLifetime()); + } + } + } +} + +/// Rotate loop LP. Return true if the loop is rotated. +/// +/// \param SimplifiedLatch is true if the latch was just folded into the final +/// loop exit. 
In this case we may want to rotate even though the new latch is +/// now an exiting branch. This rotation would have happened had the latch not +/// been simplified. However, if SimplifiedLatch is false, then we avoid +/// rotating loops in which the latch exits to avoid excessive or endless +/// rotation. LoopRotate should be repeatable and converge to a canonical +/// form. This property is satisfied because simplifying the loop latch can only +/// happen once across multiple invocations of the LoopRotate pass. +bool LoopRotate::rotateLoop(Loop *L, bool SimplifiedLatch) { + // If the loop has only one block then there is not much to rotate. + if (L->getBlocks().size() == 1) + return false; + + BasicBlock *OrigHeader = L->getHeader(); + BasicBlock *OrigLatch = L->getLoopLatch(); + + BranchInst *BI = dyn_cast<BranchInst>(OrigHeader->getTerminator()); + if (!BI || BI->isUnconditional()) + return false; + + // If the loop header is not one of the loop exiting blocks then + // either this loop is already rotated or it is not + // suitable for loop rotation transformations. + if (!L->isLoopExiting(OrigHeader)) + return false; + + // If the loop latch already contains a branch that leaves the loop then the + // loop is already rotated. + if (!OrigLatch) + return false; + + // Rotate if either the loop latch does *not* exit the loop, or if the loop + // latch was just simplified. + if (L->isLoopExiting(OrigLatch) && !SimplifiedLatch) + return false; + + // Check size of original header and reject loop if it is very big or we can't + // duplicate blocks inside it. + { + SmallPtrSet<const Value *, 32> EphValues; + CodeMetrics::collectEphemeralValues(L, AC, EphValues); + + CodeMetrics Metrics; + Metrics.analyzeBasicBlock(OrigHeader, *TTI, EphValues); + if (Metrics.notDuplicatable) { + DEBUG(dbgs() << "LoopRotation: NOT rotating - contains non-duplicatable" + << " instructions: "; + L->dump()); + return false; + } + if (Metrics.convergent) { + DEBUG(dbgs() << "LoopRotation: NOT rotating - contains convergent " + "instructions: "; + L->dump()); + return false; + } + if (Metrics.NumInsts > MaxHeaderSize) + return false; + } + + // Now, this loop is suitable for rotation. + BasicBlock *OrigPreheader = L->getLoopPreheader(); + + // If the loop could not be converted to canonical form, it must have an + // indirectbr in it, just give up. + if (!OrigPreheader) + return false; + + // Anything ScalarEvolution may know about this loop or the PHI nodes + // in its header will soon be invalidated. + if (SE) + SE->forgetLoop(L); + + DEBUG(dbgs() << "LoopRotation: rotating "; L->dump()); + + // Find new Loop header. NewHeader is a Header's one and only successor + // that is inside loop. Header's other successor is outside the + // loop. Otherwise loop is not suitable for rotation. + BasicBlock *Exit = BI->getSuccessor(0); + BasicBlock *NewHeader = BI->getSuccessor(1); + if (L->contains(Exit)) + std::swap(Exit, NewHeader); + assert(NewHeader && "Unable to determine new loop header"); + assert(L->contains(NewHeader) && !L->contains(Exit) && + "Unable to determine loop header and exit blocks"); + + // This code assumes that the new header has exactly one predecessor. + // Remove any single-entry PHI nodes in it. + assert(NewHeader->getSinglePredecessor() && + "New header doesn't have one pred!"); + FoldSingleEntryPHINodes(NewHeader); + + // Begin by walking OrigHeader and populating ValueMap with an entry for + // each Instruction. 
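+  // (Each header instruction will end up mapped to its preheader clone, to a
+  // simplified value when the clone constant-folds, or, for PHIs, to the
+  // value incoming from the preheader.)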
+  BasicBlock::iterator I = OrigHeader->begin(), E = OrigHeader->end();
+  ValueToValueMapTy ValueMap;
+
+  // For PHI nodes, the value available in OldPreHeader is just the
+  // incoming value from OldPreHeader.
+  for (; PHINode *PN = dyn_cast<PHINode>(I); ++I)
+    ValueMap[PN] = PN->getIncomingValueForBlock(OrigPreheader);
+
+  // For the rest of the instructions, either hoist to the OrigPreheader if
+  // possible or create a clone in the OldPreHeader if not.
+  TerminatorInst *LoopEntryBranch = OrigPreheader->getTerminator();
+
+  // Record all debug intrinsics preceding LoopEntryBranch to avoid duplication.
+  using DbgIntrinsicHash =
+      std::pair<std::pair<Value *, DILocalVariable *>, DIExpression *>;
+  auto makeHash = [](DbgInfoIntrinsic *D) -> DbgIntrinsicHash {
+    return {{D->getVariableLocation(), D->getVariable()}, D->getExpression()};
+  };
+  SmallDenseSet<DbgIntrinsicHash, 8> DbgIntrinsics;
+  for (auto I = std::next(OrigPreheader->rbegin()), E = OrigPreheader->rend();
+       I != E; ++I) {
+    if (auto *DII = dyn_cast<DbgInfoIntrinsic>(&*I))
+      DbgIntrinsics.insert(makeHash(DII));
+    else
+      break;
+  }
+
+  while (I != E) {
+    Instruction *Inst = &*I++;
+
+    // If the instruction's operands are invariant and it doesn't read or write
+    // memory, then it is safe to hoist. Doing this doesn't change the order of
+    // execution in the preheader, but does prevent the instruction from
+    // executing in each iteration of the loop. This means it is safe to hoist
+    // something that might trap, but isn't safe to hoist something that reads
+    // memory (without proving that the loop doesn't write).
+    if (L->hasLoopInvariantOperands(Inst) && !Inst->mayReadFromMemory() &&
+        !Inst->mayWriteToMemory() && !isa<TerminatorInst>(Inst) &&
+        !isa<DbgInfoIntrinsic>(Inst) && !isa<AllocaInst>(Inst)) {
+      Inst->moveBefore(LoopEntryBranch);
+      continue;
+    }
+
+    // Otherwise, create a duplicate of the instruction.
+    Instruction *C = Inst->clone();
+
+    // Eagerly remap the operands of the instruction.
+    RemapInstruction(C, ValueMap,
+                     RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+
+    // Avoid inserting the same intrinsic twice.
+    if (auto *DII = dyn_cast<DbgInfoIntrinsic>(C))
+      if (DbgIntrinsics.count(makeHash(DII))) {
+        C->deleteValue();
+        continue;
+      }
+
+    // With the operands remapped, see if the instruction constant folds or is
+    // otherwise simplifiable. This commonly occurs because the entry from PHI
+    // nodes allows icmps and other instructions to fold.
+    Value *V = SimplifyInstruction(C, SQ);
+    if (V && LI->replacementPreservesLCSSAForm(C, V)) {
+      // If so, then delete the temporary instruction and stick the folded value
+      // in the map.
+      ValueMap[Inst] = V;
+      if (!C->mayHaveSideEffects()) {
+        C->deleteValue();
+        C = nullptr;
+      }
+    } else {
+      ValueMap[Inst] = C;
+    }
+    if (C) {
+      // Otherwise, stick the new instruction into the new block!
+      C->setName(Inst->getName());
+      C->insertBefore(LoopEntryBranch);
+
+      if (auto *II = dyn_cast<IntrinsicInst>(C))
+        if (II->getIntrinsicID() == Intrinsic::assume)
+          AC->registerAssumption(II);
+    }
+  }
+
+  // Along with all the other instructions, we just cloned OrigHeader's
+  // terminator into OrigPreHeader. Fix up the PHI nodes in each of OrigHeader's
+  // successors by duplicating their incoming values for OrigHeader.
+ TerminatorInst *TI = OrigHeader->getTerminator(); + for (BasicBlock *SuccBB : TI->successors()) + for (BasicBlock::iterator BI = SuccBB->begin(); + PHINode *PN = dyn_cast<PHINode>(BI); ++BI) + PN->addIncoming(PN->getIncomingValueForBlock(OrigHeader), OrigPreheader); + + // Now that OrigPreHeader has a clone of OrigHeader's terminator, remove + // OrigPreHeader's old terminator (the original branch into the loop), and + // remove the corresponding incoming values from the PHI nodes in OrigHeader. + LoopEntryBranch->eraseFromParent(); + + + SmallVector<PHINode*, 2> InsertedPHIs; + // If there were any uses of instructions in the duplicated block outside the + // loop, update them, inserting PHI nodes as required + RewriteUsesOfClonedInstructions(OrigHeader, OrigPreheader, ValueMap, + &InsertedPHIs); + + // Attach dbg.value intrinsics to the new phis if that phi uses a value that + // previously had debug metadata attached. This keeps the debug info + // up-to-date in the loop body. + if (!InsertedPHIs.empty()) + insertDebugValues(OrigHeader, InsertedPHIs); + + // NewHeader is now the header of the loop. + L->moveToHeader(NewHeader); + assert(L->getHeader() == NewHeader && "Latch block is our new header"); + + // Inform DT about changes to the CFG. + if (DT) { + // The OrigPreheader branches to the NewHeader and Exit now. Then, inform + // the DT about the removed edge to the OrigHeader (that got removed). + SmallVector<DominatorTree::UpdateType, 3> Updates; + Updates.push_back({DominatorTree::Insert, OrigPreheader, Exit}); + Updates.push_back({DominatorTree::Insert, OrigPreheader, NewHeader}); + Updates.push_back({DominatorTree::Delete, OrigPreheader, OrigHeader}); + DT->applyUpdates(Updates); + } + + // At this point, we've finished our major CFG changes. As part of cloning + // the loop into the preheader we've simplified instructions and the + // duplicated conditional branch may now be branching on a constant. If it is + // branching on a constant and if that constant means that we enter the loop, + // then we fold away the cond branch to an uncond branch. This simplifies the + // loop in cases important for nested loops, and it also means we don't have + // to split as many edges. + BranchInst *PHBI = cast<BranchInst>(OrigPreheader->getTerminator()); + assert(PHBI->isConditional() && "Should be clone of BI condbr!"); + if (!isa<ConstantInt>(PHBI->getCondition()) || + PHBI->getSuccessor(cast<ConstantInt>(PHBI->getCondition())->isZero()) != + NewHeader) { + // The conditional branch can't be folded, handle the general case. + // Split edges as necessary to preserve LoopSimplify form. + + // Right now OrigPreHeader has two successors, NewHeader and ExitBlock, and + // thus is not a preheader anymore. + // Split the edge to form a real preheader. + BasicBlock *NewPH = SplitCriticalEdge( + OrigPreheader, NewHeader, + CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA()); + NewPH->setName(NewHeader->getName() + ".lr.ph"); + + // Preserve canonical loop form, which means that 'Exit' should have only + // one predecessor. Note that Exit could be an exit block for multiple + // nested loops, causing both of the edges to now be critical and need to + // be split. + SmallVector<BasicBlock *, 4> ExitPreds(pred_begin(Exit), pred_end(Exit)); + bool SplitLatchEdge = false; + for (BasicBlock *ExitPred : ExitPreds) { + // We only need to split loop exit edges. 
+ Loop *PredLoop = LI->getLoopFor(ExitPred); + if (!PredLoop || PredLoop->contains(Exit)) + continue; + if (isa<IndirectBrInst>(ExitPred->getTerminator())) + continue; + SplitLatchEdge |= L->getLoopLatch() == ExitPred; + BasicBlock *ExitSplit = SplitCriticalEdge( + ExitPred, Exit, + CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA()); + ExitSplit->moveBefore(Exit); + } + assert(SplitLatchEdge && + "Despite splitting all preds, failed to split latch exit?"); + } else { + // We can fold the conditional branch in the preheader, this makes things + // simpler. The first step is to remove the extra edge to the Exit block. + Exit->removePredecessor(OrigPreheader, true /*preserve LCSSA*/); + BranchInst *NewBI = BranchInst::Create(NewHeader, PHBI); + NewBI->setDebugLoc(PHBI->getDebugLoc()); + PHBI->eraseFromParent(); + + // With our CFG finalized, update DomTree if it is available. + if (DT) DT->deleteEdge(OrigPreheader, Exit); + } + + assert(L->getLoopPreheader() && "Invalid loop preheader after loop rotation"); + assert(L->getLoopLatch() && "Invalid loop latch after loop rotation"); + + // Now that the CFG and DomTree are in a consistent state again, try to merge + // the OrigHeader block into OrigLatch. This will succeed if they are + // connected by an unconditional branch. This is just a cleanup so the + // emitted code isn't too gross in this common case. + MergeBlockIntoPredecessor(OrigHeader, DT, LI); + + DEBUG(dbgs() << "LoopRotation: into "; L->dump()); + + ++NumRotated; + return true; +} + +/// Determine whether the instructions in this range may be safely and cheaply +/// speculated. This is not an important enough situation to develop complex +/// heuristics. We handle a single arithmetic instruction along with any type +/// conversions. +static bool shouldSpeculateInstrs(BasicBlock::iterator Begin, + BasicBlock::iterator End, Loop *L) { + bool seenIncrement = false; + bool MultiExitLoop = false; + + if (!L->getExitingBlock()) + MultiExitLoop = true; + + for (BasicBlock::iterator I = Begin; I != End; ++I) { + + if (!isSafeToSpeculativelyExecute(&*I)) + return false; + + if (isa<DbgInfoIntrinsic>(I)) + continue; + + switch (I->getOpcode()) { + default: + return false; + case Instruction::GetElementPtr: + // GEPs are cheap if all indices are constant. + if (!cast<GEPOperator>(I)->hasAllConstantIndices()) + return false; + // fall-thru to increment case + LLVM_FALLTHROUGH; + case Instruction::Add: + case Instruction::Sub: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: { + Value *IVOpnd = + !isa<Constant>(I->getOperand(0)) + ? I->getOperand(0) + : !isa<Constant>(I->getOperand(1)) ? I->getOperand(1) : nullptr; + if (!IVOpnd) + return false; + + // If increment operand is used outside of the loop, this speculation + // could cause extra live range interference. + if (MultiExitLoop) { + for (User *UseI : IVOpnd->users()) { + auto *UserInst = cast<Instruction>(UseI); + if (!L->contains(UserInst)) + return false; + } + } + + if (seenIncrement) + return false; + seenIncrement = true; + break; + } + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + // ignore type conversions + break; + } + } + return true; +} + +/// Fold the loop tail into the loop exit by speculating the loop tail +/// instructions. Typically, this is a single post-increment. 
In the case of a +/// simple 2-block loop, hoisting the increment can be much better than +/// duplicating the entire loop header. In the case of loops with early exits, +/// rotation will not work anyway, but simplifyLoopLatch will put the loop in +/// canonical form so downstream passes can handle it. +/// +/// I don't believe this invalidates SCEV. +bool LoopRotate::simplifyLoopLatch(Loop *L) { + BasicBlock *Latch = L->getLoopLatch(); + if (!Latch || Latch->hasAddressTaken()) + return false; + + BranchInst *Jmp = dyn_cast<BranchInst>(Latch->getTerminator()); + if (!Jmp || !Jmp->isUnconditional()) + return false; + + BasicBlock *LastExit = Latch->getSinglePredecessor(); + if (!LastExit || !L->isLoopExiting(LastExit)) + return false; + + BranchInst *BI = dyn_cast<BranchInst>(LastExit->getTerminator()); + if (!BI) + return false; + + if (!shouldSpeculateInstrs(Latch->begin(), Jmp->getIterator(), L)) + return false; + + DEBUG(dbgs() << "Folding loop latch " << Latch->getName() << " into " + << LastExit->getName() << "\n"); + + // Hoist the instructions from Latch into LastExit. + LastExit->getInstList().splice(BI->getIterator(), Latch->getInstList(), + Latch->begin(), Jmp->getIterator()); + + unsigned FallThruPath = BI->getSuccessor(0) == Latch ? 0 : 1; + BasicBlock *Header = Jmp->getSuccessor(0); + assert(Header == L->getHeader() && "expected a backward branch"); + + // Remove Latch from the CFG so that LastExit becomes the new Latch. + BI->setSuccessor(FallThruPath, Header); + Latch->replaceSuccessorsPhiUsesWith(LastExit); + Jmp->eraseFromParent(); + + // Nuke the Latch block. + assert(Latch->empty() && "unable to evacuate Latch"); + LI->removeBlock(Latch); + if (DT) + DT->eraseNode(Latch); + Latch->eraseFromParent(); + return true; +} + +/// Rotate \c L, and return true if any modification was made. +bool LoopRotate::processLoop(Loop *L) { + // Save the loop metadata. + MDNode *LoopMD = L->getLoopID(); + + // Simplify the loop latch before attempting to rotate the header + // upward. Rotation may not be needed if the loop tail can be folded into the + // loop exit. + bool SimplifiedLatch = simplifyLoopLatch(L); + + bool MadeChange = rotateLoop(L, SimplifiedLatch); + assert((!MadeChange || L->isLoopExiting(L->getLoopLatch())) && + "Loop latch should be exiting after loop-rotate."); + + // Restore the loop metadata. + // NB! We presume LoopRotation DOESN'T ADD its own metadata. + if ((MadeChange || SimplifiedLatch) && LoopMD) + L->setLoopID(LoopMD); + + return MadeChange || SimplifiedLatch; +} + +LoopRotatePass::LoopRotatePass(bool EnableHeaderDuplication) + : EnableHeaderDuplication(EnableHeaderDuplication) {} + +PreservedAnalyses LoopRotatePass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + int Threshold = EnableHeaderDuplication ? 
DefaultRotationThreshold : 0; + const DataLayout &DL = L.getHeader()->getModule()->getDataLayout(); + const SimplifyQuery SQ = getBestSimplifyQuery(AR, DL); + LoopRotate LR(Threshold, &AR.LI, &AR.TTI, &AR.AC, &AR.DT, &AR.SE, + SQ); + + bool Changed = LR.processLoop(&L); + if (!Changed) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} + +namespace { + +class LoopRotateLegacyPass : public LoopPass { + unsigned MaxHeaderSize; + +public: + static char ID; // Pass ID, replacement for typeid + LoopRotateLegacyPass(int SpecifiedMaxHeaderSize = -1) : LoopPass(ID) { + initializeLoopRotateLegacyPassPass(*PassRegistry::getPassRegistry()); + if (SpecifiedMaxHeaderSize == -1) + MaxHeaderSize = DefaultRotationThreshold; + else + MaxHeaderSize = unsigned(SpecifiedMaxHeaderSize); + } + + // LCSSA form makes instruction renaming easier. + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + getLoopAnalysisUsage(AU); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + Function &F = *L->getHeader()->getParent(); + + auto *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + const auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + auto *DT = DTWP ? &DTWP->getDomTree() : nullptr; + auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>(); + auto *SE = SEWP ? &SEWP->getSE() : nullptr; + const SimplifyQuery SQ = getBestSimplifyQuery(*this, F); + LoopRotate LR(MaxHeaderSize, LI, TTI, AC, DT, SE, SQ); + return LR.processLoop(L); + } +}; +} + +char LoopRotateLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopRotateLegacyPass, "loop-rotate", "Rotate Loops", + false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(LoopRotateLegacyPass, "loop-rotate", "Rotate Loops", false, + false) + +Pass *llvm::createLoopRotatePass(int MaxHeaderSize) { + return new LoopRotateLegacyPass(MaxHeaderSize); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp new file mode 100644 index 000000000000..35c05e84fd68 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopSimplifyCFG.cpp @@ -0,0 +1,109 @@ +//===--------- LoopSimplifyCFG.cpp - Loop CFG Simplification Pass ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements the Loop SimplifyCFG Pass. This pass is responsible for +// basic loop CFG cleanup, primarily to assist other loop passes. If you +// encounter a noncanonical CFG construct that causes another loop pass to +// perform suboptimally, this is the place to fix it up. 
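+//
+// The only cleanup currently performed is trivial block merging: a block with
+// a single predecessor whose predecessor has a single successor is folded into
+// that predecessor. Schematically:
+//
+//   bb1:                      bb1:
+//     br label %bb2    =>       %x = ...
+//   bb2:                        br ...
+//     %x = ...
+//     br ...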
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopSimplifyCFG.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/DependenceAnalysis.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +using namespace llvm; + +#define DEBUG_TYPE "loop-simplifycfg" + +static bool simplifyLoopCFG(Loop &L, DominatorTree &DT, LoopInfo &LI) { + bool Changed = false; + // Copy blocks into a temporary array to avoid iterator invalidation issues + // as we remove them. + SmallVector<WeakTrackingVH, 16> Blocks(L.blocks()); + + for (auto &Block : Blocks) { + // Attempt to merge blocks in the trivial case. Don't modify blocks which + // belong to other loops. + BasicBlock *Succ = cast_or_null<BasicBlock>(Block); + if (!Succ) + continue; + + BasicBlock *Pred = Succ->getSinglePredecessor(); + if (!Pred || !Pred->getSingleSuccessor() || LI.getLoopFor(Pred) != &L) + continue; + + // Pred is going to disappear, so we need to update the loop info. + if (L.getHeader() == Pred) + L.moveToHeader(Succ); + LI.removeBlock(Pred); + MergeBasicBlockIntoOnlyPred(Succ, &DT); + Changed = true; + } + + return Changed; +} + +PreservedAnalyses LoopSimplifyCFGPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + if (!simplifyLoopCFG(L, AR.DT, AR.LI)) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} + +namespace { +class LoopSimplifyCFGLegacyPass : public LoopPass { +public: + static char ID; // Pass ID, replacement for typeid + LoopSimplifyCFGLegacyPass() : LoopPass(ID) { + initializeLoopSimplifyCFGLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &) override { + if (skipLoop(L)) + return false; + + DominatorTree &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LoopInfo &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + return simplifyLoopCFG(*L, DT, LI); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addPreserved<DependenceAnalysisWrapperPass>(); + getLoopAnalysisUsage(AU); + } +}; +} + +char LoopSimplifyCFGLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(LoopSimplifyCFGLegacyPass, "loop-simplifycfg", + "Simplify loop CFG", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_END(LoopSimplifyCFGLegacyPass, "loop-simplifycfg", + "Simplify loop CFG", false, false) + +Pass *llvm::createLoopSimplifyCFGPass() { + return new LoopSimplifyCFGLegacyPass(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopSink.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopSink.cpp new file mode 100644 index 000000000000..430a7085d93f --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopSink.cpp @@ -0,0 +1,373 @@ +//===-- LoopSink.cpp - Loop Sink Pass -------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source 
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass does the inverse transformation of what LICM does.
+// It traverses all of the instructions in the loop's preheader and sinks
+// them into the loop body where the block frequency is lower than that of
+// the loop's preheader.
+// This pass is a reverse transformation of LICM. It differs from the Sink
+// pass in the following ways:
+//
+// * It only handles sinking of instructions from the loop's preheader to the
+//   loop's body
+// * It uses the alias set tracker to get more accurate alias info
+// * It uses block frequency info to find the optimal sinking locations
+//
+// Overall algorithm:
+//
+// For I in Preheader:
+//   InsertBBs = BBs that use I
+//   For BB in sorted(LoopBBs):
+//     DomBBs = BBs in InsertBBs that are dominated by BB
+//     if freq(DomBBs) > freq(BB)
+//       InsertBBs = InsertBBs - DomBBs + BB
+// For BB in InsertBBs:
+//   Insert I at BB's beginning
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LoopSink.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AliasSetTracker.h"
+#include "llvm/Analysis/BasicAliasAnalysis.h"
+#include "llvm/Analysis/BlockFrequencyInfo.h"
+#include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/LoopPassManager.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+using namespace llvm;
+
+#define DEBUG_TYPE "loopsink"
+
+STATISTIC(NumLoopSunk, "Number of instructions sunk into loop");
+STATISTIC(NumLoopSunkCloned, "Number of cloned instructions sunk into loop");
+
+static cl::opt<unsigned> SinkFrequencyPercentThreshold(
+    "sink-freq-percent-threshold", cl::Hidden, cl::init(90),
+    cl::desc("Do not sink instructions that require cloning unless they "
+             "execute less than this percent of the time."));
+
+static cl::opt<unsigned> MaxNumberOfUseBBsForSinking(
+    "max-uses-for-sinking", cl::Hidden, cl::init(30),
+    cl::desc("Do not sink instructions that have too many uses."));
+
+/// Return the adjusted total frequency of \p BBs.
+///
+/// * If there is only one BB, sinking the instruction will not introduce any
+///   code size increase. Thus there is no need to adjust the frequency.
+/// * If there is more than one BB, sinking would lead to a code size increase.
+///   In this case, we add some "tax" to the total frequency to make it harder
+///   to sink. E.g.
+///     Freq(Preheader) = 100
+///     Freq(BBs) = sum(50, 49) = 99
+///   Even if Freq(BBs) < Freq(Preheader), we will not sink from Preheader to
+///   BBs as the difference is too small to justify the code size increase.
+/// To model this, the adjusted Freq(BBs) will be:
+///   AdjustedFreq(BBs) = 99 / SinkFrequencyPercentThreshold%
+static BlockFrequency adjustedSumFreq(SmallPtrSetImpl<BasicBlock *> &BBs,
+                                      BlockFrequencyInfo &BFI) {
+  BlockFrequency T = 0;
+  for (BasicBlock *B : BBs)
+    T += BFI.getBlockFreq(B);
+  if (BBs.size() > 1)
+    T /= BranchProbability(SinkFrequencyPercentThreshold, 100);
+  return T;
+}
+
+/// Return a set of basic blocks to insert sunk instructions into.
+///
+/// The returned set of basic blocks (BBsToSinkInto) should satisfy:
+///
+/// * Inside the loop \p L
+/// * For each UseBB in \p UseBBs, there is at least one BB in BBsToSinkInto
+///   that dominates the UseBB
+/// * Has minimum total frequency that is no greater than preheader frequency
+///
+/// The purpose of the function is to find the optimal sinking points to
+/// minimize execution cost, which is defined as "sum of frequency of
+/// BBsToSinkInto".
+/// As a result, the returned BBsToSinkInto needs to have minimum total
+/// frequency.
+/// Additionally, if the total frequency of BBsToSinkInto exceeds preheader
+/// frequency, the optimal solution is not sinking (return empty set).
+///
+/// \p ColdLoopBBs is used to help find the optimal sinking locations.
+/// It stores a list of BBs that is:
+///
+/// * Inside the loop \p L
+/// * Has a frequency no larger than the loop's preheader
+/// * Sorted by BB frequency
+///
+/// The complexity of the function is O(UseBBs.size() * ColdLoopBBs.size()).
+/// To avoid expensive computation, we cap the maximum UseBBs.size() in its
+/// caller.
+static SmallPtrSet<BasicBlock *, 2>
+findBBsToSinkInto(const Loop &L, const SmallPtrSetImpl<BasicBlock *> &UseBBs,
+                  const SmallVectorImpl<BasicBlock *> &ColdLoopBBs,
+                  DominatorTree &DT, BlockFrequencyInfo &BFI) {
+  SmallPtrSet<BasicBlock *, 2> BBsToSinkInto;
+  if (UseBBs.size() == 0)
+    return BBsToSinkInto;
+
+  BBsToSinkInto.insert(UseBBs.begin(), UseBBs.end());
+  SmallPtrSet<BasicBlock *, 2> BBsDominatedByColdestBB;
+
+  // For every iteration:
+  // * Pick the ColdestBB from ColdLoopBBs
+  // * Find the set BBsDominatedByColdestBB that satisfies:
+  //   - BBsDominatedByColdestBB is a subset of BBsToSinkInto
+  //   - Every BB in BBsDominatedByColdestBB is dominated by ColdestBB
+  // * If Freq(ColdestBB) < Freq(BBsDominatedByColdestBB), remove
+  //   BBsDominatedByColdestBB from BBsToSinkInto, add ColdestBB to
+  //   BBsToSinkInto
+  for (BasicBlock *ColdestBB : ColdLoopBBs) {
+    BBsDominatedByColdestBB.clear();
+    for (BasicBlock *SinkedBB : BBsToSinkInto)
+      if (DT.dominates(ColdestBB, SinkedBB))
+        BBsDominatedByColdestBB.insert(SinkedBB);
+    if (BBsDominatedByColdestBB.size() == 0)
+      continue;
+    if (adjustedSumFreq(BBsDominatedByColdestBB, BFI) >
+        BFI.getBlockFreq(ColdestBB)) {
+      for (BasicBlock *DominatedBB : BBsDominatedByColdestBB) {
+        BBsToSinkInto.erase(DominatedBB);
+      }
+      BBsToSinkInto.insert(ColdestBB);
+    }
+  }
+
+  // If the total frequency of BBsToSinkInto is larger than preheader frequency,
+  // do not sink.
+  if (adjustedSumFreq(BBsToSinkInto, BFI) >
+      BFI.getBlockFreq(L.getLoopPreheader()))
+    BBsToSinkInto.clear();
+  return BBsToSinkInto;
+}
+
+// Sinks \p I from the loop \p L's preheader to its uses. Returns true if
+// sinking is successful.
+// \p LoopBlockNumber is used to sort the insertion blocks to ensure
+// determinism.
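+// A worked example of the cost model above (illustrative numbers): if I is
+// used in B1 (freq 30) and B2 (freq 40), both dominated by a colder block D
+// (freq 60), the adjusted cost of {B1, B2} is (30 + 40) / 90% ~= 78, which
+// exceeds 60, so findBBsToSinkInto settles on a single copy in D. The first
+// (coldest) insertion block receives the moved instruction; any others
+// receive clones.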
+static bool sinkInstruction(Loop &L, Instruction &I, + const SmallVectorImpl<BasicBlock *> &ColdLoopBBs, + const SmallDenseMap<BasicBlock *, int, 16> &LoopBlockNumber, + LoopInfo &LI, DominatorTree &DT, + BlockFrequencyInfo &BFI) { + // Compute the set of blocks in loop L which contain a use of I. + SmallPtrSet<BasicBlock *, 2> BBs; + for (auto &U : I.uses()) { + Instruction *UI = cast<Instruction>(U.getUser()); + // We cannot sink I to PHI-uses. + if (dyn_cast<PHINode>(UI)) + return false; + // We cannot sink I if it has uses outside of the loop. + if (!L.contains(LI.getLoopFor(UI->getParent()))) + return false; + BBs.insert(UI->getParent()); + } + + // findBBsToSinkInto is O(BBs.size() * ColdLoopBBs.size()). We cap the max + // BBs.size() to avoid expensive computation. + // FIXME: Handle code size growth for min_size and opt_size. + if (BBs.size() > MaxNumberOfUseBBsForSinking) + return false; + + // Find the set of BBs that we should insert a copy of I. + SmallPtrSet<BasicBlock *, 2> BBsToSinkInto = + findBBsToSinkInto(L, BBs, ColdLoopBBs, DT, BFI); + if (BBsToSinkInto.empty()) + return false; + + // Copy the final BBs into a vector and sort them using the total ordering + // of the loop block numbers as iterating the set doesn't give a useful + // order. No need to stable sort as the block numbers are a total ordering. + SmallVector<BasicBlock *, 2> SortedBBsToSinkInto; + SortedBBsToSinkInto.insert(SortedBBsToSinkInto.begin(), BBsToSinkInto.begin(), + BBsToSinkInto.end()); + std::sort(SortedBBsToSinkInto.begin(), SortedBBsToSinkInto.end(), + [&](BasicBlock *A, BasicBlock *B) { + return *LoopBlockNumber.find(A) < *LoopBlockNumber.find(B); + }); + + BasicBlock *MoveBB = *SortedBBsToSinkInto.begin(); + // FIXME: Optimize the efficiency for cloned value replacement. The current + // implementation is O(SortedBBsToSinkInto.size() * I.num_uses()). + for (BasicBlock *N : SortedBBsToSinkInto) { + if (N == MoveBB) + continue; + // Clone I and replace its uses. + Instruction *IC = I.clone(); + IC->setName(I.getName()); + IC->insertBefore(&*N->getFirstInsertionPt()); + // Replaces uses of I with IC in N + for (Value::use_iterator UI = I.use_begin(), UE = I.use_end(); UI != UE;) { + Use &U = *UI++; + auto *I = cast<Instruction>(U.getUser()); + if (I->getParent() == N) + U.set(IC); + } + // Replaces uses of I with IC in blocks dominated by N + replaceDominatedUsesWith(&I, IC, DT, N); + DEBUG(dbgs() << "Sinking a clone of " << I << " To: " << N->getName() + << '\n'); + NumLoopSunkCloned++; + } + DEBUG(dbgs() << "Sinking " << I << " To: " << MoveBB->getName() << '\n'); + NumLoopSunk++; + I.moveBefore(&*MoveBB->getFirstInsertionPt()); + + return true; +} + +/// Sinks instructions from loop's preheader to the loop body if the +/// sum frequency of inserted copy is smaller than preheader's frequency. +static bool sinkLoopInvariantInstructions(Loop &L, AAResults &AA, LoopInfo &LI, + DominatorTree &DT, + BlockFrequencyInfo &BFI, + ScalarEvolution *SE) { + BasicBlock *Preheader = L.getLoopPreheader(); + if (!Preheader) + return false; + + // Enable LoopSink only when runtime profile is available. + // With static profile, the sinking decision may be sub-optimal. + if (!Preheader->getParent()->hasProfileData()) + return false; + + const BlockFrequency PreheaderFreq = BFI.getBlockFreq(Preheader); + // If there are no basic blocks with lower frequency than the preheader then + // we can avoid the detailed analysis as we will never find profitable sinking + // opportunities. 
+  if (all_of(L.blocks(), [&](const BasicBlock *BB) {
+        return BFI.getBlockFreq(BB) > PreheaderFreq;
+      }))
+    return false;
+
+  bool Changed = false;
+  AliasSetTracker CurAST(AA);
+
+  // Compute the alias set.
+  for (BasicBlock *BB : L.blocks())
+    CurAST.add(*BB);
+
+  // Sort the loop's basic blocks by frequency.
+  SmallVector<BasicBlock *, 10> ColdLoopBBs;
+  SmallDenseMap<BasicBlock *, int, 16> LoopBlockNumber;
+  int i = 0;
+  for (BasicBlock *B : L.blocks())
+    if (BFI.getBlockFreq(B) < BFI.getBlockFreq(L.getLoopPreheader())) {
+      ColdLoopBBs.push_back(B);
+      LoopBlockNumber[B] = ++i;
+    }
+  std::stable_sort(ColdLoopBBs.begin(), ColdLoopBBs.end(),
+                   [&](BasicBlock *A, BasicBlock *B) {
+                     return BFI.getBlockFreq(A) < BFI.getBlockFreq(B);
+                   });
+
+  // Traverse the preheader's instructions in reverse order because if A
+  // depends on B (A appears after B), A needs to be sunk first before B can
+  // be sunk.
+  for (auto II = Preheader->rbegin(), E = Preheader->rend(); II != E;) {
+    Instruction *I = &*II++;
+    // No need to check whether the instruction's operands are loop invariant.
+    assert(L.hasLoopInvariantOperands(I) &&
+           "Insts in a loop's preheader should have loop invariant operands!");
+    if (!canSinkOrHoistInst(*I, &AA, &DT, &L, &CurAST, nullptr))
+      continue;
+    if (sinkInstruction(L, *I, ColdLoopBBs, LoopBlockNumber, LI, DT, BFI))
+      Changed = true;
+  }
+
+  if (Changed && SE)
+    SE->forgetLoopDispositions(&L);
+  return Changed;
+}
+
+PreservedAnalyses LoopSinkPass::run(Function &F, FunctionAnalysisManager &FAM) {
+  LoopInfo &LI = FAM.getResult<LoopAnalysis>(F);
+  // Nothing to do if there are no loops.
+  if (LI.empty())
+    return PreservedAnalyses::all();
+
+  AAResults &AA = FAM.getResult<AAManager>(F);
+  DominatorTree &DT = FAM.getResult<DominatorTreeAnalysis>(F);
+  BlockFrequencyInfo &BFI = FAM.getResult<BlockFrequencyAnalysis>(F);
+
+  // We want to do a postorder walk over the loops. Since loops are a tree this
+  // is equivalent to a reversed preorder walk and preorder is easy to compute
+  // without recursion. Since we reverse the preorder, we will visit siblings
+  // in reverse program order. This isn't expected to matter at all but is more
+  // consistent with sinking algorithms which generally work bottom-up.
+  SmallVector<Loop *, 4> PreorderLoops = LI.getLoopsInPreorder();
+
+  bool Changed = false;
+  do {
+    Loop &L = *PreorderLoops.pop_back_val();
+
+    // Note that we don't pass SCEV here because it is only used to invalidate
+    // loops in SCEV and we don't preserve (or request) SCEV at all making that
+    // unnecessary.
+    Changed |= sinkLoopInvariantInstructions(L, AA, LI, DT, BFI,
+                                             /*ScalarEvolution*/ nullptr);
+  } while (!PreorderLoops.empty());
+
+  if (!Changed)
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  return PA;
+}
+
+namespace {
+struct LegacyLoopSinkPass : public LoopPass {
+  static char ID;
+  LegacyLoopSinkPass() : LoopPass(ID) {
+    initializeLegacyLoopSinkPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnLoop(Loop *L, LPPassManager &LPM) override {
+    if (skipLoop(L))
+      return false;
+
+    auto *SE = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>();
+    return sinkLoopInvariantInstructions(
+        *L, getAnalysis<AAResultsWrapperPass>().getAAResults(),
+        getAnalysis<LoopInfoWrapperPass>().getLoopInfo(),
+        getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
+        getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI(),
+        SE ? &SE->getSE() : nullptr);
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    AU.addRequired<BlockFrequencyInfoWrapperPass>();
+    getLoopAnalysisUsage(AU);
+  }
+};
+}
+
+char LegacyLoopSinkPass::ID = 0;
+INITIALIZE_PASS_BEGIN(LegacyLoopSinkPass, "loop-sink", "Loop Sink", false,
+                      false)
+INITIALIZE_PASS_DEPENDENCY(LoopPass)
+INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
+INITIALIZE_PASS_END(LegacyLoopSinkPass, "loop-sink", "Loop Sink", false, false)
+
+Pass *llvm::createLoopSinkPass() { return new LegacyLoopSinkPass(); }
diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
new file mode 100644
index 000000000000..ff3e9eef16d9
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp
@@ -0,0 +1,5518 @@
+//===- LoopStrengthReduce.cpp - Strength Reduce IVs in Loops --------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This transformation analyzes and transforms the induction variables (and
+// computations derived from them) into forms suitable for efficient execution
+// on the target.
+//
+// This pass performs a strength reduction on array references inside loops
+// that have the loop induction variable as one or more of their components.
+// It rewrites expressions to take advantage of scaled-index addressing modes
+// available on the target, and it performs a variety of other optimizations
+// related to loop induction variables.
+//
+// Terminology note: this code has a lot of handling for "post-increment" or
+// "post-inc" users. This is not talking about post-increment addressing modes;
+// it is instead talking about code like this:
+//
+//   %i = phi [ 0, %entry ], [ %i.next, %latch ]
+//   ...
+//   %i.next = add %i, 1
+//   %c = icmp eq %i.next, %n
+//
+// The SCEV for %i is {0,+,1}<%L>. The SCEV for %i.next is {1,+,1}<%L>;
+// however, it's useful to think about these as the same register, with some
+// uses using the value of the register before the add and some using it after.
+// In this example, the icmp is a post-increment user, since it uses %i.next,
+// which is the value of the induction variable after the increment. The other
+// common case of post-increment users is users outside the loop.
+//
+// TODO: More sophistication in the way Formulae are generated and filtered.
+//
+// TODO: Handle multiple loops at a time.
+//
+// TODO: Should the addressing mode BaseGV be changed to a ConstantExpr instead
+//       of a GlobalValue?
+//
+// TODO: When truncation is free, truncate ICmp users' operands to make it a
+//       smaller encoding (on x86 at least).
+//
+// TODO: When a negated register is used by an add (such as in a list of
+//       multiple base registers, or as the increment expression in an addrec),
+//       we may not actually need both reg and (-1 * reg) in registers; the
+//       negation can be implemented by using a sub instead of an add. The
+//       lack of support for taking this into consideration when making
+//       register pressure decisions is partly worked around by the "Special"
+//       use kind.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LoopStrengthReduce.h"
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/SmallBitVector.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/IVUsers.h"
+#include "llvm/Analysis/LoopAnalysisManager.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/ScalarEvolutionExpander.h"
+#include "llvm/Analysis/ScalarEvolutionExpressions.h"
+#include "llvm/Analysis/ScalarEvolutionNormalization.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/GlobalValue.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/OperandTraits.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Compiler.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <algorithm>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
+#include <cstdlib>
+#include <iterator>
+#include <limits>
+#include <map>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-reduce"
+
+/// MaxIVUsers is an arbitrary threshold that provides an early opportunity to
+/// bail out. This threshold is far beyond the number of users that LSR can
+/// conceivably solve, so it should not affect generated code, but catches the
+/// worst cases before LSR burns too much compile time and stack space.
+static const unsigned MaxIVUsers = 200;
+
+// Temporary flag to clean up congruent phis after LSR phi expansion.
+// It's currently disabled until we can determine whether it's truly useful or
+// not. The flag should be removed after the v3.0 release.
+// This is now needed for ivchains.
+static cl::opt<bool> EnablePhiElim(
+    "enable-lsr-phielim", cl::Hidden, cl::init(true),
+    cl::desc("Enable LSR phi elimination"));
+
+// The flag adds the instruction count to the solution's cost comparison.
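+// For example (hypothetical invocation), the instruction count can be excluded
+// from the cost model when running the pass through opt:
+//   opt -loop-reduce -lsr-insns-cost=false input.ll -S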
+static cl::opt<bool> InsnsCost(
+    "lsr-insns-cost", cl::Hidden, cl::init(true),
+    cl::desc("Add instruction count to an LSR cost model"));
+
+// Flag to choose how to narrow a complex LSR solution.
+static cl::opt<bool> LSRExpNarrow(
+    "lsr-exp-narrow", cl::Hidden, cl::init(false),
+    cl::desc("Narrow a complex LSR solution using the"
+             " expected number of registers"));
+
+// Flag to narrow the search space by filtering non-optimal formulae with
+// the same ScaledReg and Scale.
+static cl::opt<bool> FilterSameScaledReg(
+    "lsr-filter-same-scaled-reg", cl::Hidden, cl::init(true),
+    cl::desc("Narrow LSR search space by filtering non-optimal formulae"
+             " with the same ScaledReg and Scale"));
+
+#ifndef NDEBUG
+// Stress test IV chain generation.
+static cl::opt<bool> StressIVChain(
+    "stress-ivchain", cl::Hidden, cl::init(false),
+    cl::desc("Stress test LSR IV chains"));
+#else
+static bool StressIVChain = false;
+#endif
+
+namespace {
+
+struct MemAccessTy {
+  /// Used in situations where the accessed memory type is unknown.
+  static const unsigned UnknownAddressSpace =
+      std::numeric_limits<unsigned>::max();
+
+  Type *MemTy = nullptr;
+  unsigned AddrSpace = UnknownAddressSpace;
+
+  MemAccessTy() = default;
+  MemAccessTy(Type *Ty, unsigned AS) : MemTy(Ty), AddrSpace(AS) {}
+
+  bool operator==(MemAccessTy Other) const {
+    return MemTy == Other.MemTy && AddrSpace == Other.AddrSpace;
+  }
+
+  bool operator!=(MemAccessTy Other) const { return !(*this == Other); }
+
+  static MemAccessTy getUnknown(LLVMContext &Ctx,
+                                unsigned AS = UnknownAddressSpace) {
+    return MemAccessTy(Type::getVoidTy(Ctx), AS);
+  }
+};
+
+/// This class holds data which is used to order reuse candidates.
+class RegSortData {
+public:
+  /// This represents the set of LSRUse indices which reference
+  /// a particular register.
+  SmallBitVector UsedByIndices;
+
+  void print(raw_ostream &OS) const;
+  void dump() const;
+};
+
+} // end anonymous namespace
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void RegSortData::print(raw_ostream &OS) const {
+  OS << "[NumUses=" << UsedByIndices.count() << ']';
+}
+
+LLVM_DUMP_METHOD void RegSortData::dump() const {
+  print(errs()); errs() << '\n';
+}
+#endif
+
+namespace {
+
+/// Map register candidates to information about how they are used.
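+///
+/// For example, after countRegister(R, 0) and countRegister(R, 2), the bit
+/// vector returned by getUsedByIndices(R) has bits 0 and 2 set, and
+/// isRegUsedByUsesOtherThan(R, 0) returns true because LSRUse index 2 also
+/// references R.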
+class RegUseTracker {
+  using RegUsesTy = DenseMap<const SCEV *, RegSortData>;
+
+  RegUsesTy RegUsesMap;
+  SmallVector<const SCEV *, 16> RegSequence;
+
+public:
+  void countRegister(const SCEV *Reg, size_t LUIdx);
+  void dropRegister(const SCEV *Reg, size_t LUIdx);
+  void swapAndDropUse(size_t LUIdx, size_t LastLUIdx);
+
+  bool isRegUsedByUsesOtherThan(const SCEV *Reg, size_t LUIdx) const;
+
+  const SmallBitVector &getUsedByIndices(const SCEV *Reg) const;
+
+  void clear();
+
+  using iterator = SmallVectorImpl<const SCEV *>::iterator;
+  using const_iterator = SmallVectorImpl<const SCEV *>::const_iterator;
+
+  iterator begin() { return RegSequence.begin(); }
+  iterator end() { return RegSequence.end(); }
+  const_iterator begin() const { return RegSequence.begin(); }
+  const_iterator end() const { return RegSequence.end(); }
+};
+
+} // end anonymous namespace
+
+void
+RegUseTracker::countRegister(const SCEV *Reg, size_t LUIdx) {
+  std::pair<RegUsesTy::iterator, bool> Pair =
+      RegUsesMap.insert(std::make_pair(Reg, RegSortData()));
+  RegSortData &RSD = Pair.first->second;
+  if (Pair.second)
+    RegSequence.push_back(Reg);
+  RSD.UsedByIndices.resize(std::max(RSD.UsedByIndices.size(), LUIdx + 1));
+  RSD.UsedByIndices.set(LUIdx);
+}
+
+void
+RegUseTracker::dropRegister(const SCEV *Reg, size_t LUIdx) {
+  RegUsesTy::iterator It = RegUsesMap.find(Reg);
+  assert(It != RegUsesMap.end());
+  RegSortData &RSD = It->second;
+  assert(RSD.UsedByIndices.size() > LUIdx);
+  RSD.UsedByIndices.reset(LUIdx);
+}
+
+void
+RegUseTracker::swapAndDropUse(size_t LUIdx, size_t LastLUIdx) {
+  assert(LUIdx <= LastLUIdx);
+
+  // Update RegUses. The data structure is not optimized for this purpose;
+  // we must iterate through it and update each of the bit vectors.
+  for (auto &Pair : RegUsesMap) {
+    SmallBitVector &UsedByIndices = Pair.second.UsedByIndices;
+    if (LUIdx < UsedByIndices.size())
+      UsedByIndices[LUIdx] =
+          LastLUIdx < UsedByIndices.size() ? UsedByIndices[LastLUIdx] : false;
+    UsedByIndices.resize(std::min(UsedByIndices.size(), LastLUIdx));
+  }
+}
+
+bool
+RegUseTracker::isRegUsedByUsesOtherThan(const SCEV *Reg, size_t LUIdx) const {
+  RegUsesTy::const_iterator I = RegUsesMap.find(Reg);
+  if (I == RegUsesMap.end())
+    return false;
+  const SmallBitVector &UsedByIndices = I->second.UsedByIndices;
+  int i = UsedByIndices.find_first();
+  if (i == -1) return false;
+  if ((size_t)i != LUIdx) return true;
+  return UsedByIndices.find_next(i) != -1;
+}
+
+const SmallBitVector &RegUseTracker::getUsedByIndices(const SCEV *Reg) const {
+  RegUsesTy::const_iterator I = RegUsesMap.find(Reg);
+  assert(I != RegUsesMap.end() && "Unknown register!");
+  return I->second.UsedByIndices;
+}
+
+void RegUseTracker::clear() {
+  RegUsesMap.clear();
+  RegSequence.clear();
+}
+
+namespace {
+
+/// This class holds information that describes a formula for computing a
+/// value that satisfies a use. It may include broken-out immediates and
+/// scaled registers.
+struct Formula {
+  /// Global base address used for complex addressing.
+  GlobalValue *BaseGV = nullptr;
+
+  /// Base offset for complex addressing.
+  int64_t BaseOffset = 0;
+
+  /// Whether any complex addressing has a base register.
+  bool HasBaseReg = false;
+
+  /// The scale of any complex addressing.
+  int64_t Scale = 0;
+
+  /// The list of "base" registers for this use. When this is non-empty, the
+  /// canonical representation of a formula is
+  /// 1. BaseRegs.size > 1 implies ScaledReg != NULL and
+  /// 2. ScaledReg != NULL implies Scale != 1 || !BaseRegs.empty().
+  /// 3. The reg containing the recurrent expr related to the current loop in
+  ///    the formula should be put in the ScaledReg.
+  /// #1 enforces that the scaled register is always used when at least two
+  /// registers are needed by the formula: e.g., reg1 + reg2 is reg1 + 1 * reg2.
+  /// #2 enforces that 1 * reg is reg.
+  /// #3 ensures invariant regs with respect to the current loop can be
+  /// combined together in LSR codegen.
+  /// This invariant can be temporarily broken while building a formula.
+  /// However, every formula inserted into the LSRInstance must be in canonical
+  /// form.
+  SmallVector<const SCEV *, 4> BaseRegs;
+
+  /// The 'scaled' register for this use. This should be non-null when Scale is
+  /// not zero.
+  const SCEV *ScaledReg = nullptr;
+
+  /// An additional constant offset which is added near the use. This requires
+  /// a temporary register, but the offset itself can live in an add immediate
+  /// field rather than a register.
+  int64_t UnfoldedOffset = 0;
+
+  Formula() = default;
+
+  void initialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE);
+
+  bool isCanonical(const Loop &L) const;
+
+  void canonicalize(const Loop &L);
+
+  bool unscale();
+
+  bool hasZeroEnd() const;
+
+  size_t getNumRegs() const;
+  Type *getType() const;
+
+  void deleteBaseReg(const SCEV *&S);
+
+  bool referencesReg(const SCEV *S) const;
+  bool hasRegsUsedByUsesOtherThan(size_t LUIdx,
+                                  const RegUseTracker &RegUses) const;
+
+  void print(raw_ostream &OS) const;
+  void dump() const;
+};
+
+} // end anonymous namespace
+
+/// Recursion helper for initialMatch.
+static void DoInitialMatch(const SCEV *S, Loop *L,
+                           SmallVectorImpl<const SCEV *> &Good,
+                           SmallVectorImpl<const SCEV *> &Bad,
+                           ScalarEvolution &SE) {
+  // Collect expressions which properly dominate the loop header.
+  if (SE.properlyDominates(S, L->getHeader())) {
+    Good.push_back(S);
+    return;
+  }
+
+  // Look at add operands.
+  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) {
+    for (const SCEV *S : Add->operands())
+      DoInitialMatch(S, L, Good, Bad, SE);
+    return;
+  }
+
+  // Look at addrec operands.
+  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S))
+    if (!AR->getStart()->isZero() && AR->isAffine()) {
+      DoInitialMatch(AR->getStart(), L, Good, Bad, SE);
+      DoInitialMatch(SE.getAddRecExpr(SE.getConstant(AR->getType(), 0),
+                                      AR->getStepRecurrence(SE),
+                                      // FIXME: AR->getNoWrapFlags()
+                                      AR->getLoop(), SCEV::FlagAnyWrap),
+                     L, Good, Bad, SE);
+      return;
+    }
+
+  // Handle a multiplication by -1 (negation) if it didn't fold.
+  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S))
+    if (Mul->getOperand(0)->isAllOnesValue()) {
+      SmallVector<const SCEV *, 4> Ops(Mul->op_begin()+1, Mul->op_end());
+      const SCEV *NewMul = SE.getMulExpr(Ops);
+
+      SmallVector<const SCEV *, 4> MyGood;
+      SmallVector<const SCEV *, 4> MyBad;
+      DoInitialMatch(NewMul, L, MyGood, MyBad, SE);
+      const SCEV *NegOne = SE.getSCEV(ConstantInt::getAllOnesValue(
+          SE.getEffectiveSCEVType(NewMul->getType())));
+      for (const SCEV *S : MyGood)
+        Good.push_back(SE.getMulExpr(NegOne, S));
+      for (const SCEV *S : MyBad)
+        Bad.push_back(SE.getMulExpr(NegOne, S));
+      return;
+    }
+
+  // Ok, we can't do anything interesting. Just stuff the whole thing into a
+  // register and hope for the best.
+  Bad.push_back(S);
+}
+
+/// Incorporate loop-variant parts of S into this Formula, attempting to keep
+/// all loop-invariant and loop-computable values in a single base register.
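+///
+/// For example (hypothetical SCEVs), matching A + {0,+,1}<%L> where A is
+/// loop-invariant yields Good = {A} and Bad = {{0,+,1}<%L>}; both sums are
+/// pushed as base registers, and canonicalize() then moves the recurrent
+/// expression into ScaledReg.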
+void Formula::initialMatch(const SCEV *S, Loop *L, ScalarEvolution &SE) {
+  SmallVector<const SCEV *, 4> Good;
+  SmallVector<const SCEV *, 4> Bad;
+  DoInitialMatch(S, L, Good, Bad, SE);
+  if (!Good.empty()) {
+    const SCEV *Sum = SE.getAddExpr(Good);
+    if (!Sum->isZero())
+      BaseRegs.push_back(Sum);
+    HasBaseReg = true;
+  }
+  if (!Bad.empty()) {
+    const SCEV *Sum = SE.getAddExpr(Bad);
+    if (!Sum->isZero())
+      BaseRegs.push_back(Sum);
+    HasBaseReg = true;
+  }
+  canonicalize(*L);
+}
+
+/// \brief Check whether or not this formula satisfies the canonical
+/// representation.
+/// \see Formula::BaseRegs.
+bool Formula::isCanonical(const Loop &L) const {
+  if (!ScaledReg)
+    return BaseRegs.size() <= 1;
+
+  if (Scale != 1)
+    return true;
+
+  if (Scale == 1 && BaseRegs.empty())
+    return false;
+
+  const SCEVAddRecExpr *SAR = dyn_cast<const SCEVAddRecExpr>(ScaledReg);
+  if (SAR && SAR->getLoop() == &L)
+    return true;
+
+  // If ScaledReg is not a recurrent expr, or it is one whose loop is not the
+  // current loop, while BaseRegs contains a recurrent expr reg related to the
+  // current loop, we want to swap the reg in BaseRegs with ScaledReg.
+  auto I =
+      find_if(make_range(BaseRegs.begin(), BaseRegs.end()), [&](const SCEV *S) {
+        return isa<const SCEVAddRecExpr>(S) &&
+               (cast<SCEVAddRecExpr>(S)->getLoop() == &L);
+      });
+  return I == BaseRegs.end();
+}
+
+/// \brief Helper method to morph a formula into its canonical representation.
+/// \see Formula::BaseRegs.
+/// Every formula having more than one base register must use the ScaledReg
+/// field. Otherwise, we would have to do special cases everywhere in LSR
+/// to treat reg1 + reg2 + ... the same way as reg1 + 1*reg2 + ...
+/// On the other hand, 1*reg should be canonicalized into reg.
+void Formula::canonicalize(const Loop &L) {
+  if (isCanonical(L))
+    return;
+  // So far we did not need this case. This is easy to implement but it is
+  // useless to maintain dead code. Besides, it could hurt compile time.
+  assert(!BaseRegs.empty() && "1*reg => reg, should not be needed.");
+
+  // Keep the invariant sum in BaseRegs and one of the variant sums in
+  // ScaledReg.
+  if (!ScaledReg) {
+    ScaledReg = BaseRegs.back();
+    BaseRegs.pop_back();
+    Scale = 1;
+  }
+
+  // If ScaledReg is an invariant with respect to L, find the reg from
+  // BaseRegs containing the recurrent expr related to Loop L. Swap the
+  // reg with ScaledReg.
+  const SCEVAddRecExpr *SAR = dyn_cast<const SCEVAddRecExpr>(ScaledReg);
+  if (!SAR || SAR->getLoop() != &L) {
+    auto I = find_if(make_range(BaseRegs.begin(), BaseRegs.end()),
+                     [&](const SCEV *S) {
+                       return isa<const SCEVAddRecExpr>(S) &&
+                              (cast<SCEVAddRecExpr>(S)->getLoop() == &L);
+                     });
+    if (I != BaseRegs.end())
+      std::swap(ScaledReg, *I);
+  }
+}
+
+/// \brief Get rid of the scale in the formula.
+/// In other words, this method morphs reg1 + 1*reg2 into reg1 + reg2.
+/// \return true if it was possible to get rid of the scale, false otherwise.
+/// \note After this operation the formula may not be in the canonical form.
+bool Formula::unscale() {
+  if (Scale != 1)
+    return false;
+  Scale = 0;
+  BaseRegs.push_back(ScaledReg);
+  ScaledReg = nullptr;
+  return true;
+}
+
+bool Formula::hasZeroEnd() const {
+  if (UnfoldedOffset || BaseOffset)
+    return false;
+  if (BaseRegs.size() != 1 || ScaledReg)
+    return false;
+  return true;
+}
+
+/// Return the total number of register operands used by this formula. This
+/// does not include register uses implied by non-constant addrec strides.
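+/// For example, a formula printed as reg(A) + 2*reg(B) uses one base register
+/// and one scaled register, so getNumRegs() returns 2.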
+size_t Formula::getNumRegs() const { + return !!ScaledReg + BaseRegs.size(); +} + +/// Return the type of this formula, if it has one, or null otherwise. This type +/// is meaningless except for the bit size. +Type *Formula::getType() const { + return !BaseRegs.empty() ? BaseRegs.front()->getType() : + ScaledReg ? ScaledReg->getType() : + BaseGV ? BaseGV->getType() : + nullptr; +} + +/// Delete the given base reg from the BaseRegs list. +void Formula::deleteBaseReg(const SCEV *&S) { + if (&S != &BaseRegs.back()) + std::swap(S, BaseRegs.back()); + BaseRegs.pop_back(); +} + +/// Test if this formula references the given register. +bool Formula::referencesReg(const SCEV *S) const { + return S == ScaledReg || is_contained(BaseRegs, S); +} + +/// Test whether this formula uses registers which are used by uses other than +/// the use with the given index. +bool Formula::hasRegsUsedByUsesOtherThan(size_t LUIdx, + const RegUseTracker &RegUses) const { + if (ScaledReg) + if (RegUses.isRegUsedByUsesOtherThan(ScaledReg, LUIdx)) + return true; + for (const SCEV *BaseReg : BaseRegs) + if (RegUses.isRegUsedByUsesOtherThan(BaseReg, LUIdx)) + return true; + return false; +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void Formula::print(raw_ostream &OS) const { + bool First = true; + if (BaseGV) { + if (!First) OS << " + "; else First = false; + BaseGV->printAsOperand(OS, /*PrintType=*/false); + } + if (BaseOffset != 0) { + if (!First) OS << " + "; else First = false; + OS << BaseOffset; + } + for (const SCEV *BaseReg : BaseRegs) { + if (!First) OS << " + "; else First = false; + OS << "reg(" << *BaseReg << ')'; + } + if (HasBaseReg && BaseRegs.empty()) { + if (!First) OS << " + "; else First = false; + OS << "**error: HasBaseReg**"; + } else if (!HasBaseReg && !BaseRegs.empty()) { + if (!First) OS << " + "; else First = false; + OS << "**error: !HasBaseReg**"; + } + if (Scale != 0) { + if (!First) OS << " + "; else First = false; + OS << Scale << "*reg("; + if (ScaledReg) + OS << *ScaledReg; + else + OS << "<unknown>"; + OS << ')'; + } + if (UnfoldedOffset != 0) { + if (!First) OS << " + "; + OS << "imm(" << UnfoldedOffset << ')'; + } +} + +LLVM_DUMP_METHOD void Formula::dump() const { + print(errs()); errs() << '\n'; +} +#endif + +/// Return true if the given addrec can be sign-extended without changing its +/// value. +static bool isAddRecSExtable(const SCEVAddRecExpr *AR, ScalarEvolution &SE) { + Type *WideTy = + IntegerType::get(SE.getContext(), SE.getTypeSizeInBits(AR->getType()) + 1); + return isa<SCEVAddRecExpr>(SE.getSignExtendExpr(AR, WideTy)); +} + +/// Return true if the given add can be sign-extended without changing its +/// value. +static bool isAddSExtable(const SCEVAddExpr *A, ScalarEvolution &SE) { + Type *WideTy = + IntegerType::get(SE.getContext(), SE.getTypeSizeInBits(A->getType()) + 1); + return isa<SCEVAddExpr>(SE.getSignExtendExpr(A, WideTy)); +} + +/// Return true if the given mul can be sign-extended without changing its +/// value. +static bool isMulSExtable(const SCEVMulExpr *M, ScalarEvolution &SE) { + Type *WideTy = + IntegerType::get(SE.getContext(), + SE.getTypeSizeInBits(M->getType()) * M->getNumOperands()); + return isa<SCEVMulExpr>(SE.getSignExtendExpr(M, WideTy)); +} + +/// Return an expression for LHS /s RHS, if it can be determined and if the +/// remainder is known to be zero, or null otherwise. 
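+/// For example, 8 /s 4 yields 2, while 7 /s 2 yields null because the
+/// remainder is nonzero.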
+/// If IgnoreSignificantBits is true, expressions like (X * Y) /s Y are
+/// simplified to X, ignoring that the multiplication may overflow, which is
+/// useful when the result will be used in a context where the most significant
+/// bits are ignored.
+static const SCEV *getExactSDiv(const SCEV *LHS, const SCEV *RHS,
+                                ScalarEvolution &SE,
+                                bool IgnoreSignificantBits = false) {
+  // Handle the trivial case, which works for any SCEV type.
+  if (LHS == RHS)
+    return SE.getConstant(LHS->getType(), 1);
+
+  // Handle a few RHS special cases.
+  const SCEVConstant *RC = dyn_cast<SCEVConstant>(RHS);
+  if (RC) {
+    const APInt &RA = RC->getAPInt();
+    // Handle x /s -1 as x * -1, to give ScalarEvolution a chance to do
+    // some folding.
+    if (RA.isAllOnesValue())
+      return SE.getMulExpr(LHS, RC);
+    // Handle x /s 1 as x.
+    if (RA == 1)
+      return LHS;
+  }
+
+  // Check for a division of a constant by a constant.
+  if (const SCEVConstant *C = dyn_cast<SCEVConstant>(LHS)) {
+    if (!RC)
+      return nullptr;
+    const APInt &LA = C->getAPInt();
+    const APInt &RA = RC->getAPInt();
+    if (LA.srem(RA) != 0)
+      return nullptr;
+    return SE.getConstant(LA.sdiv(RA));
+  }
+
+  // Distribute the sdiv over addrec operands, if the addrec doesn't overflow.
+  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(LHS)) {
+    if ((IgnoreSignificantBits || isAddRecSExtable(AR, SE)) && AR->isAffine()) {
+      const SCEV *Step = getExactSDiv(AR->getStepRecurrence(SE), RHS, SE,
+                                      IgnoreSignificantBits);
+      if (!Step) return nullptr;
+      const SCEV *Start = getExactSDiv(AR->getStart(), RHS, SE,
+                                       IgnoreSignificantBits);
+      if (!Start) return nullptr;
+      // FlagNW is independent of the start value, step direction, and is
+      // preserved with smaller magnitude steps.
+      // FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
+      return SE.getAddRecExpr(Start, Step, AR->getLoop(), SCEV::FlagAnyWrap);
+    }
+    return nullptr;
+  }
+
+  // Distribute the sdiv over add operands, if the add doesn't overflow.
+  if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(LHS)) {
+    if (IgnoreSignificantBits || isAddSExtable(Add, SE)) {
+      SmallVector<const SCEV *, 8> Ops;
+      for (const SCEV *S : Add->operands()) {
+        const SCEV *Op = getExactSDiv(S, RHS, SE, IgnoreSignificantBits);
+        if (!Op) return nullptr;
+        Ops.push_back(Op);
+      }
+      return SE.getAddExpr(Ops);
+    }
+    return nullptr;
+  }
+
+  // Check for a multiply operand that we can pull RHS out of.
+  if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(LHS)) {
+    if (IgnoreSignificantBits || isMulSExtable(Mul, SE)) {
+      SmallVector<const SCEV *, 4> Ops;
+      bool Found = false;
+      for (const SCEV *S : Mul->operands()) {
+        if (!Found)
+          if (const SCEV *Q = getExactSDiv(S, RHS, SE,
+                                           IgnoreSignificantBits)) {
+            S = Q;
+            Found = true;
+          }
+        Ops.push_back(S);
+      }
+      return Found ? SE.getMulExpr(Ops) : nullptr;
+    }
+    return nullptr;
+  }
+
+  // Otherwise we don't know.
+  return nullptr;
+}
+
+/// If S involves the addition of a constant integer value, return that integer
+/// value, and mutate S to point to a new SCEV with that value excluded.
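+///
+/// For example, given S = (4 + %x) this returns 4 and rewrites S to (0 + %x),
+/// which ScalarEvolution folds back to %x; given S = {4,+,1}<%L> it returns 4
+/// and rewrites S to {0,+,1}<%L>.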
+static int64_t ExtractImmediate(const SCEV *&S, ScalarEvolution &SE) { + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(S)) { + if (C->getAPInt().getMinSignedBits() <= 64) { + S = SE.getConstant(C->getType(), 0); + return C->getValue()->getSExtValue(); + } + } else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { + SmallVector<const SCEV *, 8> NewOps(Add->op_begin(), Add->op_end()); + int64_t Result = ExtractImmediate(NewOps.front(), SE); + if (Result != 0) + S = SE.getAddExpr(NewOps); + return Result; + } else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { + SmallVector<const SCEV *, 8> NewOps(AR->op_begin(), AR->op_end()); + int64_t Result = ExtractImmediate(NewOps.front(), SE); + if (Result != 0) + S = SE.getAddRecExpr(NewOps, AR->getLoop(), + // FIXME: AR->getNoWrapFlags(SCEV::FlagNW) + SCEV::FlagAnyWrap); + return Result; + } + return 0; +} + +/// If S involves the addition of a GlobalValue address, return that symbol, and +/// mutate S to point to a new SCEV with that value excluded. +static GlobalValue *ExtractSymbol(const SCEV *&S, ScalarEvolution &SE) { + if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) { + if (GlobalValue *GV = dyn_cast<GlobalValue>(U->getValue())) { + S = SE.getConstant(GV->getType(), 0); + return GV; + } + } else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { + SmallVector<const SCEV *, 8> NewOps(Add->op_begin(), Add->op_end()); + GlobalValue *Result = ExtractSymbol(NewOps.back(), SE); + if (Result) + S = SE.getAddExpr(NewOps); + return Result; + } else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { + SmallVector<const SCEV *, 8> NewOps(AR->op_begin(), AR->op_end()); + GlobalValue *Result = ExtractSymbol(NewOps.front(), SE); + if (Result) + S = SE.getAddRecExpr(NewOps, AR->getLoop(), + // FIXME: AR->getNoWrapFlags(SCEV::FlagNW) + SCEV::FlagAnyWrap); + return Result; + } + return nullptr; +} + +/// Returns true if the specified instruction is using the specified value as an +/// address. +static bool isAddressUse(const TargetTransformInfo &TTI, + Instruction *Inst, Value *OperandVal) { + bool isAddress = isa<LoadInst>(Inst); + if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + if (SI->getPointerOperand() == OperandVal) + isAddress = true; + } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + // Addressing modes can also be folded into prefetches and a variety + // of intrinsics. + switch (II->getIntrinsicID()) { + case Intrinsic::memset: + case Intrinsic::prefetch: + if (II->getArgOperand(0) == OperandVal) + isAddress = true; + break; + case Intrinsic::memmove: + case Intrinsic::memcpy: + if (II->getArgOperand(0) == OperandVal || + II->getArgOperand(1) == OperandVal) + isAddress = true; + break; + default: { + MemIntrinsicInfo IntrInfo; + if (TTI.getTgtMemIntrinsic(II, IntrInfo)) { + if (IntrInfo.PtrVal == OperandVal) + isAddress = true; + } + } + } + } else if (AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(Inst)) { + if (RMW->getPointerOperand() == OperandVal) + isAddress = true; + } else if (AtomicCmpXchgInst *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) { + if (CmpX->getPointerOperand() == OperandVal) + isAddress = true; + } + return isAddress; +} + +/// Return the type of the memory being accessed. 
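+///
+/// For example, for "store i32 %v, i32* %p" the access type is i32 in the
+/// store's pointer address space. Pointer-typed accesses are then
+/// canonicalized to an arbitrary fixed pointer type so that equivalent
+/// addressing modes compare equal.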
+static MemAccessTy getAccessType(const TargetTransformInfo &TTI, + Instruction *Inst) { + MemAccessTy AccessTy(Inst->getType(), MemAccessTy::UnknownAddressSpace); + if (const StoreInst *SI = dyn_cast<StoreInst>(Inst)) { + AccessTy.MemTy = SI->getOperand(0)->getType(); + AccessTy.AddrSpace = SI->getPointerAddressSpace(); + } else if (const LoadInst *LI = dyn_cast<LoadInst>(Inst)) { + AccessTy.AddrSpace = LI->getPointerAddressSpace(); + } else if (const AtomicRMWInst *RMW = dyn_cast<AtomicRMWInst>(Inst)) { + AccessTy.AddrSpace = RMW->getPointerAddressSpace(); + } else if (const AtomicCmpXchgInst *CmpX = dyn_cast<AtomicCmpXchgInst>(Inst)) { + AccessTy.AddrSpace = CmpX->getPointerAddressSpace(); + } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) { + switch (II->getIntrinsicID()) { + case Intrinsic::prefetch: + AccessTy.AddrSpace = II->getArgOperand(0)->getType()->getPointerAddressSpace(); + break; + default: { + MemIntrinsicInfo IntrInfo; + if (TTI.getTgtMemIntrinsic(II, IntrInfo) && IntrInfo.PtrVal) { + AccessTy.AddrSpace + = IntrInfo.PtrVal->getType()->getPointerAddressSpace(); + } + + break; + } + } + } + + // All pointers have the same requirements, so canonicalize them to an + // arbitrary pointer type to minimize variation. + if (PointerType *PTy = dyn_cast<PointerType>(AccessTy.MemTy)) + AccessTy.MemTy = PointerType::get(IntegerType::get(PTy->getContext(), 1), + PTy->getAddressSpace()); + + return AccessTy; +} + +/// Return true if this AddRec is already a phi in its loop. +static bool isExistingPhi(const SCEVAddRecExpr *AR, ScalarEvolution &SE) { + for (PHINode &PN : AR->getLoop()->getHeader()->phis()) { + if (SE.isSCEVable(PN.getType()) && + (SE.getEffectiveSCEVType(PN.getType()) == + SE.getEffectiveSCEVType(AR->getType())) && + SE.getSCEV(&PN) == AR) + return true; + } + return false; +} + +/// Check if expanding this expression is likely to incur significant cost. This +/// is tricky because SCEV doesn't track which expressions are actually computed +/// by the current IR. +/// +/// We currently allow expansion of IV increments that involve adds, +/// multiplication by constants, and AddRecs from existing phis. +/// +/// TODO: Allow UDivExpr if we can find an existing IV increment that is an +/// obvious multiple of the UDivExpr. +static bool isHighCostExpansion(const SCEV *S, + SmallPtrSetImpl<const SCEV*> &Processed, + ScalarEvolution &SE) { + // Zero/One operand expressions + switch (S->getSCEVType()) { + case scUnknown: + case scConstant: + return false; + case scTruncate: + return isHighCostExpansion(cast<SCEVTruncateExpr>(S)->getOperand(), + Processed, SE); + case scZeroExtend: + return isHighCostExpansion(cast<SCEVZeroExtendExpr>(S)->getOperand(), + Processed, SE); + case scSignExtend: + return isHighCostExpansion(cast<SCEVSignExtendExpr>(S)->getOperand(), + Processed, SE); + } + + if (!Processed.insert(S).second) + return false; + + if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { + for (const SCEV *S : Add->operands()) { + if (isHighCostExpansion(S, Processed, SE)) + return true; + } + return false; + } + + if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) { + if (Mul->getNumOperands() == 2) { + // Multiplication by a constant is ok + if (isa<SCEVConstant>(Mul->getOperand(0))) + return isHighCostExpansion(Mul->getOperand(1), Processed, SE); + + // If we have the value of one operand, check if an existing + // multiplication already generates this expression. 
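+      // For example, if the IR already contains "%m = mul i64 %x, %y" and
+      // this SCEV is (%x * %y), reusing %m makes the expansion effectively
+      // free.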
+      if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(Mul->getOperand(1))) {
+        Value *UVal = U->getValue();
+        for (User *UR : UVal->users()) {
+          // If U is a constant, it may be used by a ConstantExpr.
+          Instruction *UI = dyn_cast<Instruction>(UR);
+          if (UI && UI->getOpcode() == Instruction::Mul &&
+              SE.isSCEVable(UI->getType())) {
+            return SE.getSCEV(UI) == Mul;
+          }
+        }
+      }
+    }
+  }
+
+  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
+    if (isExistingPhi(AR, SE))
+      return false;
+  }
+
+  // For now, consider any other type of expression (div/mul/min/max) high cost.
+  return true;
+}
+
+/// If any of the instructions in the specified set are trivially dead, delete
+/// them and see if this makes any of their operands subsequently dead.
+static bool
+DeleteTriviallyDeadInstructions(SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
+  bool Changed = false;
+
+  while (!DeadInsts.empty()) {
+    Value *V = DeadInsts.pop_back_val();
+    Instruction *I = dyn_cast_or_null<Instruction>(V);
+
+    if (!I || !isInstructionTriviallyDead(I))
+      continue;
+
+    for (Use &O : I->operands())
+      if (Instruction *U = dyn_cast<Instruction>(O)) {
+        O = nullptr;
+        if (U->use_empty())
+          DeadInsts.emplace_back(U);
+      }
+
+    I->eraseFromParent();
+    Changed = true;
+  }
+
+  return Changed;
+}
+
+namespace {
+
+class LSRUse;
+
+} // end anonymous namespace
+
+/// \brief Check if the addressing mode defined by \p F is completely
+/// folded in \p LU at isel time.
+/// This includes address-mode folding and special icmp tricks.
+/// This function returns true if \p LU can accommodate what \p F
+/// defines and up to 1 base + 1 scaled + offset.
+/// In other words, if \p F has several base registers, this function may
+/// still return true. Therefore, users still need to account for
+/// additional base registers and/or unfolded offsets to derive an
+/// accurate cost model.
+static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
+                                 const LSRUse &LU, const Formula &F);
+
+// Get the cost of the scaling factor used in F for LU.
+static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
+                                     const LSRUse &LU, const Formula &F,
+                                     const Loop &L);
+
+namespace {
+
+/// This class is used to measure and compare candidate formulae.
+class Cost {
+  TargetTransformInfo::LSRCost C;
+
+public:
+  Cost() {
+    C.Insns = 0;
+    C.NumRegs = 0;
+    C.AddRecCost = 0;
+    C.NumIVMuls = 0;
+    C.NumBaseAdds = 0;
+    C.ImmCost = 0;
+    C.SetupCost = 0;
+    C.ScaleCost = 0;
+  }
+
+  bool isLess(Cost &Other, const TargetTransformInfo &TTI);
+
+  void Lose();
+
+#ifndef NDEBUG
+  // Once any of the metrics loses, they must all remain losers.
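+  // The check below relies on any single saturated (~0u) metric forcing the
+  // bitwise OR to be all-ones, which in turn requires the bitwise AND (i.e.
+  // every metric) to be all-ones as well for the cost to remain valid.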
+ bool isValid() { + return ((C.Insns | C.NumRegs | C.AddRecCost | C.NumIVMuls | C.NumBaseAdds + | C.ImmCost | C.SetupCost | C.ScaleCost) != ~0u) + || ((C.Insns & C.NumRegs & C.AddRecCost & C.NumIVMuls & C.NumBaseAdds + & C.ImmCost & C.SetupCost & C.ScaleCost) == ~0u); + } +#endif + + bool isLoser() { + assert(isValid() && "invalid cost"); + return C.NumRegs == ~0u; + } + + void RateFormula(const TargetTransformInfo &TTI, + const Formula &F, + SmallPtrSetImpl<const SCEV *> &Regs, + const DenseSet<const SCEV *> &VisitedRegs, + const Loop *L, + ScalarEvolution &SE, DominatorTree &DT, + const LSRUse &LU, + SmallPtrSetImpl<const SCEV *> *LoserRegs = nullptr); + + void print(raw_ostream &OS) const; + void dump() const; + +private: + void RateRegister(const SCEV *Reg, + SmallPtrSetImpl<const SCEV *> &Regs, + const Loop *L, + ScalarEvolution &SE, DominatorTree &DT); + void RatePrimaryRegister(const SCEV *Reg, + SmallPtrSetImpl<const SCEV *> &Regs, + const Loop *L, + ScalarEvolution &SE, DominatorTree &DT, + SmallPtrSetImpl<const SCEV *> *LoserRegs); +}; + +/// An operand value in an instruction which is to be replaced with some +/// equivalent, possibly strength-reduced, replacement. +struct LSRFixup { + /// The instruction which will be updated. + Instruction *UserInst = nullptr; + + /// The operand of the instruction which will be replaced. The operand may be + /// used more than once; every instance will be replaced. + Value *OperandValToReplace = nullptr; + + /// If this user is to use the post-incremented value of an induction + /// variable, this set is non-empty and holds the loops associated with the + /// induction variable. + PostIncLoopSet PostIncLoops; + + /// A constant offset to be added to the LSRUse expression. This allows + /// multiple fixups to share the same LSRUse with different offsets, for + /// example in an unrolled loop. + int64_t Offset = 0; + + LSRFixup() = default; + + bool isUseFullyOutsideLoop(const Loop *L) const; + + void print(raw_ostream &OS) const; + void dump() const; +}; + +/// A DenseMapInfo implementation for holding DenseMaps and DenseSets of sorted +/// SmallVectors of const SCEV*. +struct UniquifierDenseMapInfo { + static SmallVector<const SCEV *, 4> getEmptyKey() { + SmallVector<const SCEV *, 4> V; + V.push_back(reinterpret_cast<const SCEV *>(-1)); + return V; + } + + static SmallVector<const SCEV *, 4> getTombstoneKey() { + SmallVector<const SCEV *, 4> V; + V.push_back(reinterpret_cast<const SCEV *>(-2)); + return V; + } + + static unsigned getHashValue(const SmallVector<const SCEV *, 4> &V) { + return static_cast<unsigned>(hash_combine_range(V.begin(), V.end())); + } + + static bool isEqual(const SmallVector<const SCEV *, 4> &LHS, + const SmallVector<const SCEV *, 4> &RHS) { + return LHS == RHS; + } +}; + +/// This class holds the state that LSR keeps for each use in IVUsers, as well +/// as uses invented by LSR itself. It includes information about what kinds of +/// things can be folded into the user, information about the user itself, and +/// information about how the use may be satisfied. TODO: Represent multiple +/// users of the same expression in common? +class LSRUse { + DenseSet<SmallVector<const SCEV *, 4>, UniquifierDenseMapInfo> Uniquifier; + +public: + /// An enum for a kind of use, indicating what types of scaled and immediate + /// operands it might support. + enum KindType { + Basic, ///< A normal use, with no folding. + Special, ///< A special case of basic, allowing -1 scales. 
+    Address,  ///< An address use; folding according to TargetLowering.
+    ICmpZero  ///< An equality icmp with both operands folded into one.
+    // TODO: Add a generic icmp too?
+  };
+
+  using SCEVUseKindPair = PointerIntPair<const SCEV *, 2, KindType>;
+
+  KindType Kind;
+  MemAccessTy AccessTy;
+
+  /// The list of operands which are to be replaced.
+  SmallVector<LSRFixup, 8> Fixups;
+
+  /// Keep track of the min and max offsets of the fixups.
+  int64_t MinOffset = std::numeric_limits<int64_t>::max();
+  int64_t MaxOffset = std::numeric_limits<int64_t>::min();
+
+  /// This records whether all of the fixups using this LSRUse are outside of
+  /// the loop, in which case some special-case heuristics may be used.
+  bool AllFixupsOutsideLoop = true;
+
+  /// RigidFormula is set to true to guarantee that this use will be associated
+  /// with a single formula--the one that initially matched. Some SCEV
+  /// expressions cannot be expanded. This allows LSR to consider the registers
+  /// used by those expressions without the need to expand them later after
+  /// changing the formula.
+  bool RigidFormula = false;
+
+  /// This records the widest use type for any fixup using this
+  /// LSRUse. FindUseWithSimilarFormula can't consider uses with different max
+  /// fixup widths to be equivalent, because the narrower one may be relying on
+  /// the implicit truncation to truncate away bogus bits.
+  Type *WidestFixupType = nullptr;
+
+  /// A list of ways to build a value that can satisfy this user. After the
+  /// list is populated, one of these is selected heuristically and used to
+  /// formulate a replacement for OperandValToReplace in UserInst.
+  SmallVector<Formula, 12> Formulae;
+
+  /// The set of register candidates used by all formulae in this LSRUse.
+  SmallPtrSet<const SCEV *, 4> Regs;
+
+  LSRUse(KindType K, MemAccessTy AT) : Kind(K), AccessTy(AT) {}
+
+  LSRFixup &getNewFixup() {
+    Fixups.push_back(LSRFixup());
+    return Fixups.back();
+  }
+
+  void pushFixup(LSRFixup &f) {
+    Fixups.push_back(f);
+    if (f.Offset > MaxOffset)
+      MaxOffset = f.Offset;
+    if (f.Offset < MinOffset)
+      MinOffset = f.Offset;
+  }
+
+  bool HasFormulaWithSameRegs(const Formula &F) const;
+  float getNotSelectedProbability(const SCEV *Reg) const;
+  bool InsertFormula(const Formula &F, const Loop &L);
+  void DeleteFormula(Formula &F);
+  void RecomputeRegs(size_t LUIdx, RegUseTracker &Reguses);
+
+  void print(raw_ostream &OS) const;
+  void dump() const;
+};
+
+} // end anonymous namespace
+
+static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
+                                 LSRUse::KindType Kind, MemAccessTy AccessTy,
+                                 GlobalValue *BaseGV, int64_t BaseOffset,
+                                 bool HasBaseReg, int64_t Scale,
+                                 Instruction *Fixup = nullptr);
+
+/// Tally up interesting quantities from the given register.
+void Cost::RateRegister(const SCEV *Reg,
+                        SmallPtrSetImpl<const SCEV *> &Regs,
+                        const Loop *L,
+                        ScalarEvolution &SE, DominatorTree &DT) {
+  if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Reg)) {
+    // If this is an addrec for another loop, it should be an invariant
+    // with respect to L since L is the innermost loop (at least
+    // for now LSR only handles innermost loops).
+    if (AR->getLoop() != L) {
+      // If the AddRec exists, consider it register-free and leave it alone.
+      if (isExistingPhi(AR, SE))
+        return;
+
+      // It is bad to allow LSR for the current loop to add induction variables
+      // for its sibling loops.
+      if (!AR->getLoop()->contains(L)) {
+        Lose();
+        return;
+      }
+
+      // Otherwise, it will be an invariant with respect to Loop L.
+ ++C.NumRegs; + return; + } + C.AddRecCost += 1; /// TODO: This should be a function of the stride. + + // Add the step value register, if it needs one. + // TODO: The non-affine case isn't precisely modeled here. + if (!AR->isAffine() || !isa<SCEVConstant>(AR->getOperand(1))) { + if (!Regs.count(AR->getOperand(1))) { + RateRegister(AR->getOperand(1), Regs, L, SE, DT); + if (isLoser()) + return; + } + } + } + ++C.NumRegs; + + // Rough heuristic; favor registers which don't require extra setup + // instructions in the preheader. + if (!isa<SCEVUnknown>(Reg) && + !isa<SCEVConstant>(Reg) && + !(isa<SCEVAddRecExpr>(Reg) && + (isa<SCEVUnknown>(cast<SCEVAddRecExpr>(Reg)->getStart()) || + isa<SCEVConstant>(cast<SCEVAddRecExpr>(Reg)->getStart())))) + ++C.SetupCost; + + C.NumIVMuls += isa<SCEVMulExpr>(Reg) && + SE.hasComputableLoopEvolution(Reg, L); +} + +/// Record this register in the set. If we haven't seen it before, rate +/// it. Optional LoserRegs provides a way to declare any formula that refers to +/// one of those regs an instant loser. +void Cost::RatePrimaryRegister(const SCEV *Reg, + SmallPtrSetImpl<const SCEV *> &Regs, + const Loop *L, + ScalarEvolution &SE, DominatorTree &DT, + SmallPtrSetImpl<const SCEV *> *LoserRegs) { + if (LoserRegs && LoserRegs->count(Reg)) { + Lose(); + return; + } + if (Regs.insert(Reg).second) { + RateRegister(Reg, Regs, L, SE, DT); + if (LoserRegs && isLoser()) + LoserRegs->insert(Reg); + } +} + +void Cost::RateFormula(const TargetTransformInfo &TTI, + const Formula &F, + SmallPtrSetImpl<const SCEV *> &Regs, + const DenseSet<const SCEV *> &VisitedRegs, + const Loop *L, + ScalarEvolution &SE, DominatorTree &DT, + const LSRUse &LU, + SmallPtrSetImpl<const SCEV *> *LoserRegs) { + assert(F.isCanonical(*L) && "Cost is accurate only for canonical formula"); + // Tally up the registers. + unsigned PrevAddRecCost = C.AddRecCost; + unsigned PrevNumRegs = C.NumRegs; + unsigned PrevNumBaseAdds = C.NumBaseAdds; + if (const SCEV *ScaledReg = F.ScaledReg) { + if (VisitedRegs.count(ScaledReg)) { + Lose(); + return; + } + RatePrimaryRegister(ScaledReg, Regs, L, SE, DT, LoserRegs); + if (isLoser()) + return; + } + for (const SCEV *BaseReg : F.BaseRegs) { + if (VisitedRegs.count(BaseReg)) { + Lose(); + return; + } + RatePrimaryRegister(BaseReg, Regs, L, SE, DT, LoserRegs); + if (isLoser()) + return; + } + + // Determine how many (unfolded) adds we'll need inside the loop. + size_t NumBaseParts = F.getNumRegs(); + if (NumBaseParts > 1) + // Do not count the base and a possible second register if the target + // allows to fold 2 registers. + C.NumBaseAdds += + NumBaseParts - (1 + (F.Scale && isAMCompletelyFolded(TTI, LU, F))); + C.NumBaseAdds += (F.UnfoldedOffset != 0); + + // Accumulate non-free scaling amounts. + C.ScaleCost += getScalingFactorCost(TTI, LU, F, *L); + + // Tally up the non-zero immediates. + for (const LSRFixup &Fixup : LU.Fixups) { + int64_t O = Fixup.Offset; + int64_t Offset = (uint64_t)O + F.BaseOffset; + if (F.BaseGV) + C.ImmCost += 64; // Handle symbolic values conservatively. + // TODO: This should probably be the pointer size. + else if (Offset != 0) + C.ImmCost += APInt(64, Offset, true).getMinSignedBits(); + + // Check with target if this offset with this instruction is + // specifically not supported. + if (LU.Kind == LSRUse::Address && Offset != 0 && + !isAMCompletelyFolded(TTI, LSRUse::Address, LU.AccessTy, F.BaseGV, + Offset, F.HasBaseReg, F.Scale, Fixup.UserInst)) + C.NumBaseAdds++; + } + + // If we don't count instruction cost exit here. 
+  if (!InsnsCost) {
+    assert(isValid() && "invalid cost");
+    return;
+  }
+
+  // Treat every new register that exceeds TTI.getNumberOfRegisters() - 1 as
+  // an additional instruction (at least a fill).
+  unsigned TTIRegNum = TTI.getNumberOfRegisters(false) - 1;
+  if (C.NumRegs > TTIRegNum) {
+    // If the cost already exceeded TTIRegNum, then only newly added registers
+    // can add new instructions.
+    if (PrevNumRegs > TTIRegNum)
+      C.Insns += (C.NumRegs - PrevNumRegs);
+    else
+      C.Insns += (C.NumRegs - TTIRegNum);
+  }
+
+  // If an ICmpZero formula does not end in 0, it cannot be replaced by just an
+  // add or sub; we will need to compare the final result of the AddRec, which
+  // takes an additional instruction.
+  // For -10 + {0, +, 1}:
+  //   i = i + 1;
+  //   cmp i, 10
+  //
+  // For {-10, +, 1}:
+  //   i = i + 1;
+  if (LU.Kind == LSRUse::ICmpZero && !F.hasZeroEnd())
+    C.Insns++;
+  // Each new AddRec adds 1 instruction to the calculation.
+  C.Insns += (C.AddRecCost - PrevAddRecCost);
+
+  // BaseAdds adds instructions for unfolded registers.
+  if (LU.Kind != LSRUse::ICmpZero)
+    C.Insns += C.NumBaseAdds - PrevNumBaseAdds;
+  assert(isValid() && "invalid cost");
+}
+
+/// Set this cost to a losing value.
+void Cost::Lose() {
+  C.Insns = std::numeric_limits<unsigned>::max();
+  C.NumRegs = std::numeric_limits<unsigned>::max();
+  C.AddRecCost = std::numeric_limits<unsigned>::max();
+  C.NumIVMuls = std::numeric_limits<unsigned>::max();
+  C.NumBaseAdds = std::numeric_limits<unsigned>::max();
+  C.ImmCost = std::numeric_limits<unsigned>::max();
+  C.SetupCost = std::numeric_limits<unsigned>::max();
+  C.ScaleCost = std::numeric_limits<unsigned>::max();
+}
+
+/// Choose the lower cost.
+bool Cost::isLess(Cost &Other, const TargetTransformInfo &TTI) {
+  if (InsnsCost.getNumOccurrences() > 0 && InsnsCost &&
+      C.Insns != Other.C.Insns)
+    return C.Insns < Other.C.Insns;
+  return TTI.isLSRCostLess(C, Other.C);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void Cost::print(raw_ostream &OS) const {
+  if (InsnsCost)
+    OS << C.Insns << " instruction" << (C.Insns == 1 ? " " : "s ");
+  OS << C.NumRegs << " reg" << (C.NumRegs == 1 ? "" : "s");
+  if (C.AddRecCost != 0)
+    OS << ", with addrec cost " << C.AddRecCost;
+  if (C.NumIVMuls != 0)
+    OS << ", plus " << C.NumIVMuls << " IV mul"
+       << (C.NumIVMuls == 1 ? "" : "s");
+  if (C.NumBaseAdds != 0)
+    OS << ", plus " << C.NumBaseAdds << " base add"
+       << (C.NumBaseAdds == 1 ? "" : "s");
+  if (C.ScaleCost != 0)
+    OS << ", plus " << C.ScaleCost << " scale cost";
+  if (C.ImmCost != 0)
+    OS << ", plus " << C.ImmCost << " imm cost";
+  if (C.SetupCost != 0)
+    OS << ", plus " << C.SetupCost << " setup cost";
+}
+
+LLVM_DUMP_METHOD void Cost::dump() const {
+  print(errs()); errs() << '\n';
+}
+#endif
+
+/// Test whether this fixup always uses its value outside of the given loop.
+bool LSRFixup::isUseFullyOutsideLoop(const Loop *L) const {
+  // PHI nodes use their value in their incoming blocks.
+  if (const PHINode *PN = dyn_cast<PHINode>(UserInst)) {
+    for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+      if (PN->getIncomingValue(i) == OperandValToReplace &&
+          L->contains(PN->getIncomingBlock(i)))
+        return false;
+    return true;
+  }
+
+  return !L->contains(UserInst);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void LSRFixup::print(raw_ostream &OS) const {
+  OS << "UserInst=";
+  // Store is common and interesting enough to be worth special-casing.
+  if (StoreInst *Store = dyn_cast<StoreInst>(UserInst)) {
+    OS << "store ";
+    Store->getOperand(0)->printAsOperand(OS, /*PrintType=*/false);
+  } else if (UserInst->getType()->isVoidTy())
+    OS << UserInst->getOpcodeName();
+  else
+    UserInst->printAsOperand(OS, /*PrintType=*/false);
+
+  OS << ", OperandValToReplace=";
+  OperandValToReplace->printAsOperand(OS, /*PrintType=*/false);
+
+  for (const Loop *PIL : PostIncLoops) {
+    OS << ", PostIncLoop=";
+    PIL->getHeader()->printAsOperand(OS, /*PrintType=*/false);
+  }
+
+  if (Offset != 0)
+    OS << ", Offset=" << Offset;
+}
+
+LLVM_DUMP_METHOD void LSRFixup::dump() const {
+  print(errs()); errs() << '\n';
+}
+#endif
+
+/// Test whether this use has a formula with the same registers as the given
+/// formula.
+bool LSRUse::HasFormulaWithSameRegs(const Formula &F) const {
+  SmallVector<const SCEV *, 4> Key = F.BaseRegs;
+  if (F.ScaledReg) Key.push_back(F.ScaledReg);
+  // Unstable sort by host order ok, because this is only used for uniquifying.
+  std::sort(Key.begin(), Key.end());
+  return Uniquifier.count(Key);
+}
+
+/// Return the probability of selecting a formula that does not reference Reg.
+float LSRUse::getNotSelectedProbability(const SCEV *Reg) const {
+  unsigned FNum = 0;
+  for (const Formula &F : Formulae)
+    if (F.referencesReg(Reg))
+      FNum++;
+  return ((float)(Formulae.size() - FNum)) / Formulae.size();
+}
+
+/// If the given formula has not yet been inserted, add it to the list, and
+/// return true. Return false otherwise. The formula must be in canonical form.
+bool LSRUse::InsertFormula(const Formula &F, const Loop &L) {
+  assert(F.isCanonical(L) && "Invalid canonical representation");
+
+  if (!Formulae.empty() && RigidFormula)
+    return false;
+
+  SmallVector<const SCEV *, 4> Key = F.BaseRegs;
+  if (F.ScaledReg) Key.push_back(F.ScaledReg);
+  // Unstable sort by host order ok, because this is only used for uniquifying.
+  std::sort(Key.begin(), Key.end());
+
+  if (!Uniquifier.insert(Key).second)
+    return false;
+
+  // Using a register to hold the value of 0 is not profitable.
+  assert((!F.ScaledReg || !F.ScaledReg->isZero()) &&
+         "Zero allocated in a scaled register!");
+#ifndef NDEBUG
+  for (const SCEV *BaseReg : F.BaseRegs)
+    assert(!BaseReg->isZero() && "Zero allocated in a base register!");
+#endif
+
+  // Add the formula to the list.
+  Formulae.push_back(F);
+
+  // Record registers now being used by this use.
+  Regs.insert(F.BaseRegs.begin(), F.BaseRegs.end());
+  if (F.ScaledReg)
+    Regs.insert(F.ScaledReg);
+
+  return true;
+}
+
+/// Remove the given formula from this use's list.
+void LSRUse::DeleteFormula(Formula &F) {
+  if (&F != &Formulae.back())
+    std::swap(F, Formulae.back());
+  Formulae.pop_back();
+}
+
+/// Recompute the Regs field, and update RegUses.
+void LSRUse::RecomputeRegs(size_t LUIdx, RegUseTracker &RegUses) {
+  // Now that we've filtered out some formulae, recompute the Regs set.
+  SmallPtrSet<const SCEV *, 4> OldRegs = std::move(Regs);
+  Regs.clear();
+  for (const Formula &F : Formulae) {
+    if (F.ScaledReg) Regs.insert(F.ScaledReg);
+    Regs.insert(F.BaseRegs.begin(), F.BaseRegs.end());
+  }
+
+  // Update the RegTracker.
+ for (const SCEV *S : OldRegs) + if (!Regs.count(S)) + RegUses.dropRegister(S, LUIdx); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void LSRUse::print(raw_ostream &OS) const { + OS << "LSR Use: Kind="; + switch (Kind) { + case Basic: OS << "Basic"; break; + case Special: OS << "Special"; break; + case ICmpZero: OS << "ICmpZero"; break; + case Address: + OS << "Address of "; + if (AccessTy.MemTy->isPointerTy()) + OS << "pointer"; // the full pointer type could be really verbose + else { + OS << *AccessTy.MemTy; + } + + OS << " in addrspace(" << AccessTy.AddrSpace << ')'; + } + + OS << ", Offsets={"; + bool NeedComma = false; + for (const LSRFixup &Fixup : Fixups) { + if (NeedComma) OS << ','; + OS << Fixup.Offset; + NeedComma = true; + } + OS << '}'; + + if (AllFixupsOutsideLoop) + OS << ", all-fixups-outside-loop"; + + if (WidestFixupType) + OS << ", widest fixup type: " << *WidestFixupType; +} + +LLVM_DUMP_METHOD void LSRUse::dump() const { + print(errs()); errs() << '\n'; +} +#endif + +static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, + LSRUse::KindType Kind, MemAccessTy AccessTy, + GlobalValue *BaseGV, int64_t BaseOffset, + bool HasBaseReg, int64_t Scale, + Instruction *Fixup/*= nullptr*/) { + switch (Kind) { + case LSRUse::Address: + return TTI.isLegalAddressingMode(AccessTy.MemTy, BaseGV, BaseOffset, + HasBaseReg, Scale, AccessTy.AddrSpace, Fixup); + + case LSRUse::ICmpZero: + // There's not even a target hook for querying whether it would be legal to + // fold a GV into an ICmp. + if (BaseGV) + return false; + + // ICmp only has two operands; don't allow more than two non-trivial parts. + if (Scale != 0 && HasBaseReg && BaseOffset != 0) + return false; + + // ICmp only supports no scale or a -1 scale, as we can "fold" a -1 scale by + // putting the scaled register in the other operand of the icmp. + if (Scale != 0 && Scale != -1) + return false; + + // If we have low-level target information, ask the target if it can fold an + // integer immediate on an icmp. + if (BaseOffset != 0) { + // We have one of: + // ICmpZero BaseReg + BaseOffset => ICmp BaseReg, -BaseOffset + // ICmpZero -1*ScaleReg + BaseOffset => ICmp ScaleReg, BaseOffset + // Offs is the ICmp immediate. + if (Scale == 0) + // The cast does the right thing with + // std::numeric_limits<int64_t>::min(). + BaseOffset = -(uint64_t)BaseOffset; + return TTI.isLegalICmpImmediate(BaseOffset); + } + + // ICmpZero BaseReg + -1*ScaleReg => ICmp BaseReg, ScaleReg + return true; + + case LSRUse::Basic: + // Only handle single-register values. + return !BaseGV && Scale == 0 && BaseOffset == 0; + + case LSRUse::Special: + // Special case Basic to handle -1 scales. + return !BaseGV && (Scale == 0 || Scale == -1) && BaseOffset == 0; + } + + llvm_unreachable("Invalid LSRUse Kind!"); +} + +static bool isAMCompletelyFolded(const TargetTransformInfo &TTI, + int64_t MinOffset, int64_t MaxOffset, + LSRUse::KindType Kind, MemAccessTy AccessTy, + GlobalValue *BaseGV, int64_t BaseOffset, + bool HasBaseReg, int64_t Scale) { + // Check for overflow. 
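+  // The addition wraps iff the sum fails to move in the direction of
+  // MinOffset/MaxOffset's sign. For example, BaseOffset = INT64_MAX with
+  // MinOffset = 1 wraps to a negative sum: (Sum > BaseOffset) is false while
+  // (MinOffset > 0) is true, so we conservatively refuse to fold.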
+  if (((int64_t)((uint64_t)BaseOffset + MinOffset) > BaseOffset) !=
+      (MinOffset > 0))
+    return false;
+  MinOffset = (uint64_t)BaseOffset + MinOffset;
+  if (((int64_t)((uint64_t)BaseOffset + MaxOffset) > BaseOffset) !=
+      (MaxOffset > 0))
+    return false;
+  MaxOffset = (uint64_t)BaseOffset + MaxOffset;
+
+  return isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, MinOffset,
+                              HasBaseReg, Scale) &&
+         isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, MaxOffset,
+                              HasBaseReg, Scale);
+}
+
+static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
+                                 int64_t MinOffset, int64_t MaxOffset,
+                                 LSRUse::KindType Kind, MemAccessTy AccessTy,
+                                 const Formula &F, const Loop &L) {
+  // For the purpose of isAMCompletelyFolded, either having a canonical formula
+  // or a scale not equal to zero is correct.
+  // Problems may arise from non-canonical formulae having a scale == 0.
+  // Strictly speaking it would be best to just rely on canonical formulae.
+  // However, when we generate the scaled formulae, we first check that the
+  // scaling factor is profitable before computing the actual ScaledReg for
+  // the sake of compile time.
+  assert((F.isCanonical(L) || F.Scale != 0));
+  return isAMCompletelyFolded(TTI, MinOffset, MaxOffset, Kind, AccessTy,
+                              F.BaseGV, F.BaseOffset, F.HasBaseReg, F.Scale);
+}
+
+/// Test whether we know how to expand the current formula.
+static bool isLegalUse(const TargetTransformInfo &TTI, int64_t MinOffset,
+                       int64_t MaxOffset, LSRUse::KindType Kind,
+                       MemAccessTy AccessTy, GlobalValue *BaseGV,
+                       int64_t BaseOffset, bool HasBaseReg, int64_t Scale) {
+  // We know how to expand completely foldable formulae.
+  return isAMCompletelyFolded(TTI, MinOffset, MaxOffset, Kind, AccessTy, BaseGV,
+                              BaseOffset, HasBaseReg, Scale) ||
+         // Or formulae that use a base register produced by a sum of base
+         // registers.
+         (Scale == 1 &&
+          isAMCompletelyFolded(TTI, MinOffset, MaxOffset, Kind, AccessTy,
+                               BaseGV, BaseOffset, true, 0));
+}
+
+static bool isLegalUse(const TargetTransformInfo &TTI, int64_t MinOffset,
+                       int64_t MaxOffset, LSRUse::KindType Kind,
+                       MemAccessTy AccessTy, const Formula &F) {
+  return isLegalUse(TTI, MinOffset, MaxOffset, Kind, AccessTy, F.BaseGV,
+                    F.BaseOffset, F.HasBaseReg, F.Scale);
+}
+
+static bool isAMCompletelyFolded(const TargetTransformInfo &TTI,
+                                 const LSRUse &LU, const Formula &F) {
+  // Target may want to look at the user instructions.
+  if (LU.Kind == LSRUse::Address && TTI.LSRWithInstrQueries()) {
+    for (const LSRFixup &Fixup : LU.Fixups)
+      if (!isAMCompletelyFolded(TTI, LSRUse::Address, LU.AccessTy, F.BaseGV,
+                                (F.BaseOffset + Fixup.Offset), F.HasBaseReg,
+                                F.Scale, Fixup.UserInst))
+        return false;
+    return true;
+  }
+
+  return isAMCompletelyFolded(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind,
+                              LU.AccessTy, F.BaseGV, F.BaseOffset, F.HasBaseReg,
+                              F.Scale);
+}
+
+static unsigned getScalingFactorCost(const TargetTransformInfo &TTI,
+                                     const LSRUse &LU, const Formula &F,
+                                     const Loop &L) {
+  if (!F.Scale)
+    return 0;
+
+  // If the use is not completely folded in that instruction, we will have to
+  // pay an extra cost only for scale != 1.
+  if (!isAMCompletelyFolded(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind,
+                            LU.AccessTy, F, L))
+    return F.Scale != 1;
+
+  switch (LU.Kind) {
+  case LSRUse::Address: {
+    // Check the scaling factor cost with both the min and max offsets.
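+    // The formula is shared by every fixup in [MinOffset, MaxOffset], so take
+    // the worse of the two endpoint costs; e.g. a target might fold
+    // [reg + 4*reg'] for free at offset 0 but not at the far end of the range.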
+ int ScaleCostMinOffset = TTI.getScalingFactorCost( + LU.AccessTy.MemTy, F.BaseGV, F.BaseOffset + LU.MinOffset, F.HasBaseReg, + F.Scale, LU.AccessTy.AddrSpace); + int ScaleCostMaxOffset = TTI.getScalingFactorCost( + LU.AccessTy.MemTy, F.BaseGV, F.BaseOffset + LU.MaxOffset, F.HasBaseReg, + F.Scale, LU.AccessTy.AddrSpace); + + assert(ScaleCostMinOffset >= 0 && ScaleCostMaxOffset >= 0 && + "Legal addressing mode has an illegal cost!"); + return std::max(ScaleCostMinOffset, ScaleCostMaxOffset); + } + case LSRUse::ICmpZero: + case LSRUse::Basic: + case LSRUse::Special: + // The use is completely folded, i.e., everything is folded into the + // instruction. + return 0; + } + + llvm_unreachable("Invalid LSRUse Kind!"); +} + +static bool isAlwaysFoldable(const TargetTransformInfo &TTI, + LSRUse::KindType Kind, MemAccessTy AccessTy, + GlobalValue *BaseGV, int64_t BaseOffset, + bool HasBaseReg) { + // Fast-path: zero is always foldable. + if (BaseOffset == 0 && !BaseGV) return true; + + // Conservatively, create an address with an immediate and a + // base and a scale. + int64_t Scale = Kind == LSRUse::ICmpZero ? -1 : 1; + + // Canonicalize a scale of 1 to a base register if the formula doesn't + // already have a base register. + if (!HasBaseReg && Scale == 1) { + Scale = 0; + HasBaseReg = true; + } + + return isAMCompletelyFolded(TTI, Kind, AccessTy, BaseGV, BaseOffset, + HasBaseReg, Scale); +} + +static bool isAlwaysFoldable(const TargetTransformInfo &TTI, + ScalarEvolution &SE, int64_t MinOffset, + int64_t MaxOffset, LSRUse::KindType Kind, + MemAccessTy AccessTy, const SCEV *S, + bool HasBaseReg) { + // Fast-path: zero is always foldable. + if (S->isZero()) return true; + + // Conservatively, create an address with an immediate and a + // base and a scale. + int64_t BaseOffset = ExtractImmediate(S, SE); + GlobalValue *BaseGV = ExtractSymbol(S, SE); + + // If there's anything else involved, it's not foldable. + if (!S->isZero()) return false; + + // Fast-path: zero is always foldable. + if (BaseOffset == 0 && !BaseGV) return true; + + // Conservatively, create an address with an immediate and a + // base and a scale. + int64_t Scale = Kind == LSRUse::ICmpZero ? -1 : 1; + + return isAMCompletelyFolded(TTI, MinOffset, MaxOffset, Kind, AccessTy, BaseGV, + BaseOffset, HasBaseReg, Scale); +} + +namespace { + +/// An individual increment in a Chain of IV increments. Relate an IV user to +/// an expression that computes the IV it uses from the IV used by the previous +/// link in the Chain. +/// +/// For the head of a chain, IncExpr holds the absolute SCEV expression for the +/// original IVOperand. The head of the chain's IVOperand is only valid during +/// chain collection, before LSR replaces IV users. During chain generation, +/// IncExpr can be used to find the new IVOperand that computes the same +/// expression. +struct IVInc { + Instruction *UserInst; + Value* IVOperand; + const SCEV *IncExpr; + + IVInc(Instruction *U, Value *O, const SCEV *E) + : UserInst(U), IVOperand(O), IncExpr(E) {} +}; + +// The list of IV increments in program order. We typically add the head of a +// chain without finding subsequent links. +struct IVChain { + SmallVector<IVInc, 1> Incs; + const SCEV *ExprBase = nullptr; + + IVChain() = default; + IVChain(const IVInc &Head, const SCEV *Base) + : Incs(1, Head), ExprBase(Base) {} + + using const_iterator = SmallVectorImpl<IVInc>::const_iterator; + + // Return the first increment in the chain. 
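+  // (Incs[0] is the chain head rather than an increment proper, which is why
+  // begin() below skips it with std::next.)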
+  const_iterator begin() const {
+    assert(!Incs.empty());
+    return std::next(Incs.begin());
+  }
+  const_iterator end() const {
+    return Incs.end();
+  }
+
+  // Returns true if this chain contains any increments.
+  bool hasIncs() const { return Incs.size() >= 2; }
+
+  // Add an IVInc to the end of this chain.
+  void add(const IVInc &X) { Incs.push_back(X); }
+
+  // Returns the last UserInst in the chain.
+  Instruction *tailUserInst() const { return Incs.back().UserInst; }
+
+  // Returns true if IncExpr can be profitably added to this chain.
+  bool isProfitableIncrement(const SCEV *OperExpr,
+                             const SCEV *IncExpr,
+                             ScalarEvolution&);
+};
+
+/// Helper for CollectChains to track multiple IV increment uses. Distinguish
+/// between FarUsers that definitely cross IV increments and NearUsers that may
+/// be used between IV increments.
+struct ChainUsers {
+  SmallPtrSet<Instruction*, 4> FarUsers;
+  SmallPtrSet<Instruction*, 4> NearUsers;
+};
+
+/// This class holds state for the main loop strength reduction logic.
+class LSRInstance {
+  IVUsers &IU;
+  ScalarEvolution &SE;
+  DominatorTree &DT;
+  LoopInfo &LI;
+  const TargetTransformInfo &TTI;
+  Loop *const L;
+  bool Changed = false;
+
+  /// This is the insert position at which the current loop's induction
+  /// variable increment should be placed. In simple loops, this is the latch
+  /// block's terminator. But in more complicated cases, this is a position
+  /// which will dominate all the in-loop post-increment users.
+  Instruction *IVIncInsertPos = nullptr;
+
+  /// Interesting factors between use strides.
+  ///
+  /// We explicitly use a SetVector which contains a SmallSet, instead of the
+  /// default, a SmallDenseSet, because we need to use the full range of
+  /// int64_ts, and there's currently no good way of doing that with
+  /// SmallDenseSet.
+  SetVector<int64_t, SmallVector<int64_t, 8>, SmallSet<int64_t, 8>> Factors;
+
+  /// Interesting use types, to facilitate truncation reuse.
+  SmallSetVector<Type *, 4> Types;
+
+  /// The list of interesting uses.
+  SmallVector<LSRUse, 16> Uses;
+
+  /// Track which uses use which register candidates.
+  RegUseTracker RegUses;
+
+  // Limit the number of chains to avoid quadratic behavior. We don't expect to
+  // have more than a few IV increment chains in a loop. Missing a Chain falls
+  // back to normal LSR behavior for those uses.
+  static const unsigned MaxChains = 8;
+
+  /// IV users can form a chain of IV increments.
+  SmallVector<IVChain, MaxChains> IVChainVec;
+
+  /// IV users that belong to profitable IVChains.
+  SmallPtrSet<Use*, MaxChains> IVIncSet;
+
+  void OptimizeShadowIV();
+  bool FindIVUserForCond(ICmpInst *Cond, IVStrideUse *&CondUse);
+  ICmpInst *OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse);
+  void OptimizeLoopTermCond();
+
+  void ChainInstruction(Instruction *UserInst, Instruction *IVOper,
+                        SmallVectorImpl<ChainUsers> &ChainUsersVec);
+  void FinalizeChain(IVChain &Chain);
+  void CollectChains();
+  void GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter,
+                       SmallVectorImpl<WeakTrackingVH> &DeadInsts);
+
+  void CollectInterestingTypesAndFactors();
+  void CollectFixupsAndInitialFormulae();
+
+  // Support for sharing of LSRUses between LSRFixups.
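+  // The map is keyed on (expression, kind) so that fixups whose expressions
+  // differ only by a constant offset can share a single LSRUse; see getUse().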
+ using UseMapTy = DenseMap<LSRUse::SCEVUseKindPair, size_t>; + UseMapTy UseMap; + + bool reconcileNewOffset(LSRUse &LU, int64_t NewOffset, bool HasBaseReg, + LSRUse::KindType Kind, MemAccessTy AccessTy); + + std::pair<size_t, int64_t> getUse(const SCEV *&Expr, LSRUse::KindType Kind, + MemAccessTy AccessTy); + + void DeleteUse(LSRUse &LU, size_t LUIdx); + + LSRUse *FindUseWithSimilarFormula(const Formula &F, const LSRUse &OrigLU); + + void InsertInitialFormula(const SCEV *S, LSRUse &LU, size_t LUIdx); + void InsertSupplementalFormula(const SCEV *S, LSRUse &LU, size_t LUIdx); + void CountRegisters(const Formula &F, size_t LUIdx); + bool InsertFormula(LSRUse &LU, unsigned LUIdx, const Formula &F); + + void CollectLoopInvariantFixupsAndFormulae(); + + void GenerateReassociations(LSRUse &LU, unsigned LUIdx, Formula Base, + unsigned Depth = 0); + + void GenerateReassociationsImpl(LSRUse &LU, unsigned LUIdx, + const Formula &Base, unsigned Depth, + size_t Idx, bool IsScaledReg = false); + void GenerateCombinations(LSRUse &LU, unsigned LUIdx, Formula Base); + void GenerateSymbolicOffsetsImpl(LSRUse &LU, unsigned LUIdx, + const Formula &Base, size_t Idx, + bool IsScaledReg = false); + void GenerateSymbolicOffsets(LSRUse &LU, unsigned LUIdx, Formula Base); + void GenerateConstantOffsetsImpl(LSRUse &LU, unsigned LUIdx, + const Formula &Base, + const SmallVectorImpl<int64_t> &Worklist, + size_t Idx, bool IsScaledReg = false); + void GenerateConstantOffsets(LSRUse &LU, unsigned LUIdx, Formula Base); + void GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx, Formula Base); + void GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base); + void GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base); + void GenerateCrossUseConstantOffsets(); + void GenerateAllReuseFormulae(); + + void FilterOutUndesirableDedicatedRegisters(); + + size_t EstimateSearchSpaceComplexity() const; + void NarrowSearchSpaceByDetectingSupersets(); + void NarrowSearchSpaceByCollapsingUnrolledCode(); + void NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(); + void NarrowSearchSpaceByFilterFormulaWithSameScaledReg(); + void NarrowSearchSpaceByDeletingCostlyFormulas(); + void NarrowSearchSpaceByPickingWinnerRegs(); + void NarrowSearchSpaceUsingHeuristics(); + + void SolveRecurse(SmallVectorImpl<const Formula *> &Solution, + Cost &SolutionCost, + SmallVectorImpl<const Formula *> &Workspace, + const Cost &CurCost, + const SmallPtrSet<const SCEV *, 16> &CurRegs, + DenseSet<const SCEV *> &VisitedRegs) const; + void Solve(SmallVectorImpl<const Formula *> &Solution) const; + + BasicBlock::iterator + HoistInsertPosition(BasicBlock::iterator IP, + const SmallVectorImpl<Instruction *> &Inputs) const; + BasicBlock::iterator + AdjustInsertPositionForExpand(BasicBlock::iterator IP, + const LSRFixup &LF, + const LSRUse &LU, + SCEVExpander &Rewriter) const; + + Value *Expand(const LSRUse &LU, const LSRFixup &LF, const Formula &F, + BasicBlock::iterator IP, SCEVExpander &Rewriter, + SmallVectorImpl<WeakTrackingVH> &DeadInsts) const; + void RewriteForPHI(PHINode *PN, const LSRUse &LU, const LSRFixup &LF, + const Formula &F, SCEVExpander &Rewriter, + SmallVectorImpl<WeakTrackingVH> &DeadInsts) const; + void Rewrite(const LSRUse &LU, const LSRFixup &LF, const Formula &F, + SCEVExpander &Rewriter, + SmallVectorImpl<WeakTrackingVH> &DeadInsts) const; + void ImplementSolution(const SmallVectorImpl<const Formula *> &Solution); + +public: + LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, DominatorTree &DT, + LoopInfo &LI, const 
TargetTransformInfo &TTI);
+
+  bool getChanged() const { return Changed; }
+
+  void print_factors_and_types(raw_ostream &OS) const;
+  void print_fixups(raw_ostream &OS) const;
+  void print_uses(raw_ostream &OS) const;
+  void print(raw_ostream &OS) const;
+  void dump() const;
+};
+
+} // end anonymous namespace
+
+/// If the IV is used in an int-to-float cast inside the loop then try to
+/// eliminate the cast operation.
+void LSRInstance::OptimizeShadowIV() {
+  const SCEV *BackedgeTakenCount = SE.getBackedgeTakenCount(L);
+  if (isa<SCEVCouldNotCompute>(BackedgeTakenCount))
+    return;
+
+  for (IVUsers::const_iterator UI = IU.begin(), E = IU.end();
+       UI != E; /* empty */) {
+    IVUsers::const_iterator CandidateUI = UI;
+    ++UI;
+    Instruction *ShadowUse = CandidateUI->getUser();
+    Type *DestTy = nullptr;
+    bool IsSigned = false;
+
+    /* If shadow use is an int->float cast then insert a second IV
+       to eliminate this cast.
+
+         for (unsigned i = 0; i < n; ++i)
+           foo((double)i);
+
+       is transformed into
+
+         double d = 0.0;
+         for (unsigned i = 0; i < n; ++i, ++d)
+           foo(d);
+    */
+    if (UIToFPInst *UCast = dyn_cast<UIToFPInst>(CandidateUI->getUser())) {
+      IsSigned = false;
+      DestTy = UCast->getDestTy();
+    }
+    else if (SIToFPInst *SCast = dyn_cast<SIToFPInst>(CandidateUI->getUser())) {
+      IsSigned = true;
+      DestTy = SCast->getDestTy();
+    }
+    if (!DestTy) continue;
+
+    // If the target does not support DestTy natively then do not apply
+    // this transformation.
+    if (!TTI.isTypeLegal(DestTy)) continue;
+
+    PHINode *PH = dyn_cast<PHINode>(ShadowUse->getOperand(0));
+    if (!PH) continue;
+    if (PH->getNumIncomingValues() != 2) continue;
+
+    // If the calculation in integers overflows, the result in FP type will
+    // differ. So we can only do this transformation if we are guaranteed not
+    // to deal with overflowing values.
+    const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(SE.getSCEV(PH));
+    if (!AR) continue;
+    if (IsSigned && !AR->hasNoSignedWrap()) continue;
+    if (!IsSigned && !AR->hasNoUnsignedWrap()) continue;
+
+    Type *SrcTy = PH->getType();
+    int Mantissa = DestTy->getFPMantissaWidth();
+    if (Mantissa == -1) continue;
+    if ((int)SE.getTypeSizeInBits(SrcTy) > Mantissa)
+      continue;
+
+    unsigned Entry, Latch;
+    if (PH->getIncomingBlock(0) == L->getLoopPreheader()) {
+      Entry = 0;
+      Latch = 1;
+    } else {
+      Entry = 1;
+      Latch = 0;
+    }
+
+    ConstantInt *Init = dyn_cast<ConstantInt>(PH->getIncomingValue(Entry));
+    if (!Init) continue;
+    Constant *NewInit = ConstantFP::get(DestTy, IsSigned ?
+                                        (double)Init->getSExtValue() :
+                                        (double)Init->getZExtValue());
+
+    BinaryOperator *Incr =
+      dyn_cast<BinaryOperator>(PH->getIncomingValue(Latch));
+    if (!Incr) continue;
+    if (Incr->getOpcode() != Instruction::Add
+        && Incr->getOpcode() != Instruction::Sub)
+      continue;
+
+    /* Initialize new IV, double d = 0.0 in above example. */
+    ConstantInt *C = nullptr;
+    if (Incr->getOperand(0) == PH)
+      C = dyn_cast<ConstantInt>(Incr->getOperand(1));
+    else if (Incr->getOperand(1) == PH)
+      C = dyn_cast<ConstantInt>(Incr->getOperand(0));
+    else
+      continue;
+
+    if (!C) continue;
+
+    // Ignore negative constants, as the code below doesn't handle them
+    // correctly. TODO: Remove this restriction.
+    if (!C->getValue().isStrictlyPositive()) continue;
+
+    /* Add new PHINode. */
+    PHINode *NewPH = PHINode::Create(DestTy, 2, "IV.S.", PH);
+
+    /* Create new increment. '++d' in above example. */
+    Constant *CFP = ConstantFP::get(DestTy, C->getZExtValue());
+    BinaryOperator *NewIncr =
+      BinaryOperator::Create(Incr->getOpcode() == Instruction::Add ?
+                             Instruction::FAdd : Instruction::FSub,
+                             NewPH, CFP, "IV.S.next.", Incr);
+
+    NewPH->addIncoming(NewInit, PH->getIncomingBlock(Entry));
+    NewPH->addIncoming(NewIncr, PH->getIncomingBlock(Latch));
+
+    /* Remove cast operation */
+    ShadowUse->replaceAllUsesWith(NewPH);
+    ShadowUse->eraseFromParent();
+    Changed = true;
+    break;
+  }
+}
+
+/// If Cond has an operand that is an expression of an IV, set the IV user and
+/// stride information and return true, otherwise return false.
+bool LSRInstance::FindIVUserForCond(ICmpInst *Cond, IVStrideUse *&CondUse) {
+  for (IVStrideUse &U : IU)
+    if (U.getUser() == Cond) {
+      // NOTE: we could handle setcc instructions with multiple uses here, but
+      // InstCombine does it as well for simple uses, and it's not clear that
+      // it occurs often enough in real life to be worth handling.
+      CondUse = &U;
+      return true;
+    }
+  return false;
+}
+
+/// Rewrite the loop's terminating condition if it uses a max computation.
+///
+/// This is a narrow solution to a specific, but acute, problem. For loops
+/// like this:
+///
+///   i = 0;
+///   do {
+///     p[i] = 0.0;
+///   } while (++i < n);
+///
+/// the trip count isn't just 'n', because 'n' might not be positive. And
+/// unfortunately this can come up even for loops where the user didn't use
+/// a C do-while loop. For example, seemingly well-behaved top-test loops
+/// will commonly be lowered like this:
+///
+///   if (n > 0) {
+///     i = 0;
+///     do {
+///       p[i] = 0.0;
+///     } while (++i < n);
+///   }
+///
+/// and then it's possible for subsequent optimization to obscure the if
+/// test in such a way that indvars can't find it.
+///
+/// When indvars can't find the if test in loops like this, it creates a
+/// max expression, which allows it to give the loop a canonical
+/// induction variable:
+///
+///   i = 0;
+///   max = n < 1 ? 1 : n;
+///   do {
+///     p[i] = 0.0;
+///   } while (++i != max);
+///
+/// Canonical induction variables are necessary because the loop passes
+/// are designed around them. The most obvious example of this is the
+/// LoopInfo analysis, which doesn't remember trip count values. It
+/// expects to be able to rediscover the trip count each time it is
+/// needed, and it does this using a simple analysis that only succeeds if
+/// the loop has a canonical induction variable.
+///
+/// However, when it comes time to generate code, the maximum operation
+/// can be quite costly, especially if it's inside of an outer loop.
+///
+/// This function solves this problem by detecting loops of this type and
+/// rewriting their conditions from ICMP_NE back to ICMP_SLT, and deleting
+/// the instructions for the maximum computation.
+ICmpInst *LSRInstance::OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse) {
+  // Check that the loop matches the pattern we're looking for.
+  if (Cond->getPredicate() != CmpInst::ICMP_EQ &&
+      Cond->getPredicate() != CmpInst::ICMP_NE)
+    return Cond;
+
+  SelectInst *Sel = dyn_cast<SelectInst>(Cond->getOperand(1));
+  if (!Sel || !Sel->hasOneUse()) return Cond;
+
+  const SCEV *BackedgeTakenCount = SE.getBackedgeTakenCount(L);
+  if (isa<SCEVCouldNotCompute>(BackedgeTakenCount))
+    return Cond;
+  const SCEV *One = SE.getConstant(BackedgeTakenCount->getType(), 1);
+
+  // Add one to the backedge-taken count to get the trip count.
+  const SCEV *IterationCount = SE.getAddExpr(One, BackedgeTakenCount);
+  if (IterationCount != SE.getSCEV(Sel)) return Cond;
+
+  // Check for a max calculation that matches the pattern.
There's no check + // for ICMP_ULE here because the comparison would be with zero, which + // isn't interesting. + CmpInst::Predicate Pred = ICmpInst::BAD_ICMP_PREDICATE; + const SCEVNAryExpr *Max = nullptr; + if (const SCEVSMaxExpr *S = dyn_cast<SCEVSMaxExpr>(BackedgeTakenCount)) { + Pred = ICmpInst::ICMP_SLE; + Max = S; + } else if (const SCEVSMaxExpr *S = dyn_cast<SCEVSMaxExpr>(IterationCount)) { + Pred = ICmpInst::ICMP_SLT; + Max = S; + } else if (const SCEVUMaxExpr *U = dyn_cast<SCEVUMaxExpr>(IterationCount)) { + Pred = ICmpInst::ICMP_ULT; + Max = U; + } else { + // No match; bail. + return Cond; + } + + // To handle a max with more than two operands, this optimization would + // require additional checking and setup. + if (Max->getNumOperands() != 2) + return Cond; + + const SCEV *MaxLHS = Max->getOperand(0); + const SCEV *MaxRHS = Max->getOperand(1); + + // ScalarEvolution canonicalizes constants to the left. For < and >, look + // for a comparison with 1. For <= and >=, a comparison with zero. + if (!MaxLHS || + (ICmpInst::isTrueWhenEqual(Pred) ? !MaxLHS->isZero() : (MaxLHS != One))) + return Cond; + + // Check the relevant induction variable for conformance to + // the pattern. + const SCEV *IV = SE.getSCEV(Cond->getOperand(0)); + const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(IV); + if (!AR || !AR->isAffine() || + AR->getStart() != One || + AR->getStepRecurrence(SE) != One) + return Cond; + + assert(AR->getLoop() == L && + "Loop condition operand is an addrec in a different loop!"); + + // Check the right operand of the select, and remember it, as it will + // be used in the new comparison instruction. + Value *NewRHS = nullptr; + if (ICmpInst::isTrueWhenEqual(Pred)) { + // Look for n+1, and grab n. + if (AddOperator *BO = dyn_cast<AddOperator>(Sel->getOperand(1))) + if (ConstantInt *BO1 = dyn_cast<ConstantInt>(BO->getOperand(1))) + if (BO1->isOne() && SE.getSCEV(BO->getOperand(0)) == MaxRHS) + NewRHS = BO->getOperand(0); + if (AddOperator *BO = dyn_cast<AddOperator>(Sel->getOperand(2))) + if (ConstantInt *BO1 = dyn_cast<ConstantInt>(BO->getOperand(1))) + if (BO1->isOne() && SE.getSCEV(BO->getOperand(0)) == MaxRHS) + NewRHS = BO->getOperand(0); + if (!NewRHS) + return Cond; + } else if (SE.getSCEV(Sel->getOperand(1)) == MaxRHS) + NewRHS = Sel->getOperand(1); + else if (SE.getSCEV(Sel->getOperand(2)) == MaxRHS) + NewRHS = Sel->getOperand(2); + else if (const SCEVUnknown *SU = dyn_cast<SCEVUnknown>(MaxRHS)) + NewRHS = SU->getValue(); + else + // Max doesn't match expected pattern. + return Cond; + + // Determine the new comparison opcode. It may be signed or unsigned, + // and the original comparison may be either equality or inequality. + if (Cond->getPredicate() == CmpInst::ICMP_EQ) + Pred = CmpInst::getInversePredicate(Pred); + + // Ok, everything looks ok to change the condition into an SLT or SGE and + // delete the max calculation. + ICmpInst *NewCond = + new ICmpInst(Cond, Pred, Cond->getOperand(0), NewRHS, "scmp"); + + // Delete the max calculation instructions. + Cond->replaceAllUsesWith(NewCond); + CondUse->setUser(NewCond); + Instruction *Cmp = cast<Instruction>(Sel->getOperand(0)); + Cond->eraseFromParent(); + Sel->eraseFromParent(); + if (Cmp->use_empty()) + Cmp->eraseFromParent(); + return NewCond; +} + +/// Change loop terminating condition to use the postinc iv when possible. +void +LSRInstance::OptimizeLoopTermCond() { + SmallPtrSet<Instruction *, 4> PostIncs; + + // We need a different set of heuristics for rotated and non-rotated loops. 
+ // If a loop is rotated then the latch is also the backedge, so inserting + // post-inc expressions just before the latch is ideal. To reduce live ranges + // it also makes sense to rewrite terminating conditions to use post-inc + // expressions. + // + // If the loop is not rotated then the latch is not a backedge; the latch + // check is done in the loop head. Adding post-inc expressions before the + // latch will cause overlapping live-ranges of pre-inc and post-inc expressions + // in the loop body. In this case we do *not* want to use post-inc expressions + // in the latch check, and we want to insert post-inc expressions before + // the backedge. + BasicBlock *LatchBlock = L->getLoopLatch(); + SmallVector<BasicBlock*, 8> ExitingBlocks; + L->getExitingBlocks(ExitingBlocks); + if (llvm::all_of(ExitingBlocks, [&LatchBlock](const BasicBlock *BB) { + return LatchBlock != BB; + })) { + // The backedge doesn't exit the loop; treat this as a head-tested loop. + IVIncInsertPos = LatchBlock->getTerminator(); + return; + } + + // Otherwise treat this as a rotated loop. + for (BasicBlock *ExitingBlock : ExitingBlocks) { + // Get the terminating condition for the loop if possible. If we + // can, we want to change it to use a post-incremented version of its + // induction variable, to allow coalescing the live ranges for the IV into + // one register value. + + BranchInst *TermBr = dyn_cast<BranchInst>(ExitingBlock->getTerminator()); + if (!TermBr) + continue; + // FIXME: Overly conservative, termination condition could be an 'or' etc.. + if (TermBr->isUnconditional() || !isa<ICmpInst>(TermBr->getCondition())) + continue; + + // Search IVUsesByStride to find Cond's IVUse if there is one. + IVStrideUse *CondUse = nullptr; + ICmpInst *Cond = cast<ICmpInst>(TermBr->getCondition()); + if (!FindIVUserForCond(Cond, CondUse)) + continue; + + // If the trip count is computed in terms of a max (due to ScalarEvolution + // being unable to find a sufficient guard, for example), change the loop + // comparison to use SLT or ULT instead of NE. + // One consequence of doing this now is that it disrupts the count-down + // optimization. That's not always a bad thing though, because in such + // cases it may still be worthwhile to avoid a max. + Cond = OptimizeMax(Cond, CondUse); + + // If this exiting block dominates the latch block, it may also use + // the post-inc value if it won't be shared with other uses. + // Check for dominance. + if (!DT.dominates(ExitingBlock, LatchBlock)) + continue; + + // Conservatively avoid trying to use the post-inc value in non-latch + // exits if there may be pre-inc users in intervening blocks. + if (LatchBlock != ExitingBlock) + for (IVUsers::const_iterator UI = IU.begin(), E = IU.end(); UI != E; ++UI) + // Test if the use is reachable from the exiting block. This dominator + // query is a conservative approximation of reachability. + if (&*UI != CondUse && + !DT.properlyDominates(UI->getUser()->getParent(), ExitingBlock)) { + // Conservatively assume there may be reuse if the quotient of their + // strides could be a legal scale. 
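+          // For illustration: if the condition's IV strides by 4 and this
+          // use strides by 8, the quotient 2 may be a legal scale (e.g.
+          // [reg + 2*reg']), so the pre-inc value could profitably be
+          // reused and we decline the post-inc transformation.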
+ const SCEV *A = IU.getStride(*CondUse, L); + const SCEV *B = IU.getStride(*UI, L); + if (!A || !B) continue; + if (SE.getTypeSizeInBits(A->getType()) != + SE.getTypeSizeInBits(B->getType())) { + if (SE.getTypeSizeInBits(A->getType()) > + SE.getTypeSizeInBits(B->getType())) + B = SE.getSignExtendExpr(B, A->getType()); + else + A = SE.getSignExtendExpr(A, B->getType()); + } + if (const SCEVConstant *D = + dyn_cast_or_null<SCEVConstant>(getExactSDiv(B, A, SE))) { + const ConstantInt *C = D->getValue(); + // Stride of one or negative one can have reuse with non-addresses. + if (C->isOne() || C->isMinusOne()) + goto decline_post_inc; + // Avoid weird situations. + if (C->getValue().getMinSignedBits() >= 64 || + C->getValue().isMinSignedValue()) + goto decline_post_inc; + // Check for possible scaled-address reuse. + MemAccessTy AccessTy = getAccessType(TTI, UI->getUser()); + int64_t Scale = C->getSExtValue(); + if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr, + /*BaseOffset=*/0, + /*HasBaseReg=*/false, Scale, + AccessTy.AddrSpace)) + goto decline_post_inc; + Scale = -Scale; + if (TTI.isLegalAddressingMode(AccessTy.MemTy, /*BaseGV=*/nullptr, + /*BaseOffset=*/0, + /*HasBaseReg=*/false, Scale, + AccessTy.AddrSpace)) + goto decline_post_inc; + } + } + + DEBUG(dbgs() << " Change loop exiting icmp to use postinc iv: " + << *Cond << '\n'); + + // It's possible for the setcc instruction to be anywhere in the loop, and + // possible for it to have multiple users. If it is not immediately before + // the exiting block branch, move it. + if (&*++BasicBlock::iterator(Cond) != TermBr) { + if (Cond->hasOneUse()) { + Cond->moveBefore(TermBr); + } else { + // Clone the terminating condition and insert into the loopend. + ICmpInst *OldCond = Cond; + Cond = cast<ICmpInst>(Cond->clone()); + Cond->setName(L->getHeader()->getName() + ".termcond"); + ExitingBlock->getInstList().insert(TermBr->getIterator(), Cond); + + // Clone the IVUse, as the old use still exists! + CondUse = &IU.AddUser(Cond, CondUse->getOperandValToReplace()); + TermBr->replaceUsesOfWith(OldCond, Cond); + } + } + + // If we get to here, we know that we can transform the setcc instruction to + // use the post-incremented version of the IV, allowing us to coalesce the + // live ranges for the IV correctly. + CondUse->transformToPostInc(L); + Changed = true; + + PostIncs.insert(Cond); + decline_post_inc:; + } + + // Determine an insertion point for the loop induction variable increment. It + // must dominate all the post-inc comparisons we just set up, and it must + // dominate the loop latch edge. + IVIncInsertPos = L->getLoopLatch()->getTerminator(); + for (Instruction *Inst : PostIncs) { + BasicBlock *BB = + DT.findNearestCommonDominator(IVIncInsertPos->getParent(), + Inst->getParent()); + if (BB == Inst->getParent()) + IVIncInsertPos = Inst; + else if (BB != IVIncInsertPos->getParent()) + IVIncInsertPos = BB->getTerminator(); + } +} + +/// Determine if the given use can accommodate a fixup at the given offset and +/// other details. If so, update the use and return true. +bool LSRInstance::reconcileNewOffset(LSRUse &LU, int64_t NewOffset, + bool HasBaseReg, LSRUse::KindType Kind, + MemAccessTy AccessTy) { + int64_t NewMinOffset = LU.MinOffset; + int64_t NewMaxOffset = LU.MaxOffset; + MemAccessTy NewAccessTy = AccessTy; + + // Check for a mismatched kind. 
It's tempting to collapse mismatched kinds to
+  // something conservative; however, this can pessimize in the case that one
+  // of the uses will have all its uses outside the loop, for example.
+  if (LU.Kind != Kind)
+    return false;
+
+  // Check for a mismatched access type, and fall back conservatively as needed.
+  // TODO: Be less conservative when the type is similar and can use the same
+  // addressing modes.
+  if (Kind == LSRUse::Address) {
+    if (AccessTy.MemTy != LU.AccessTy.MemTy) {
+      NewAccessTy = MemAccessTy::getUnknown(AccessTy.MemTy->getContext(),
+                                            AccessTy.AddrSpace);
+    }
+  }
+
+  // Conservatively assume HasBaseReg is true for now.
+  if (NewOffset < LU.MinOffset) {
+    if (!isAlwaysFoldable(TTI, Kind, NewAccessTy, /*BaseGV=*/nullptr,
+                          LU.MaxOffset - NewOffset, HasBaseReg))
+      return false;
+    NewMinOffset = NewOffset;
+  } else if (NewOffset > LU.MaxOffset) {
+    if (!isAlwaysFoldable(TTI, Kind, NewAccessTy, /*BaseGV=*/nullptr,
+                          NewOffset - LU.MinOffset, HasBaseReg))
+      return false;
+    NewMaxOffset = NewOffset;
+  }
+
+  // Update the use.
+  LU.MinOffset = NewMinOffset;
+  LU.MaxOffset = NewMaxOffset;
+  LU.AccessTy = NewAccessTy;
+  return true;
+}
+
+/// Return an LSRUse index and an offset value for a fixup which needs the
+/// given expression, with the given kind and optional access type. Either
+/// reuse an existing use or create a new one, as needed.
+std::pair<size_t, int64_t> LSRInstance::getUse(const SCEV *&Expr,
+                                               LSRUse::KindType Kind,
+                                               MemAccessTy AccessTy) {
+  const SCEV *Copy = Expr;
+  int64_t Offset = ExtractImmediate(Expr, SE);
+
+  // Basic uses can't accept any offset, for example.
+  if (!isAlwaysFoldable(TTI, Kind, AccessTy, /*BaseGV=*/ nullptr,
+                        Offset, /*HasBaseReg=*/ true)) {
+    Expr = Copy;
+    Offset = 0;
+  }
+
+  std::pair<UseMapTy::iterator, bool> P =
+    UseMap.insert(std::make_pair(LSRUse::SCEVUseKindPair(Expr, Kind), 0));
+  if (!P.second) {
+    // A use already existed with this base.
+    size_t LUIdx = P.first->second;
+    LSRUse &LU = Uses[LUIdx];
+    if (reconcileNewOffset(LU, Offset, /*HasBaseReg=*/true, Kind, AccessTy))
+      // Reuse this use.
+      return std::make_pair(LUIdx, Offset);
+  }
+
+  // Create a new use.
+  size_t LUIdx = Uses.size();
+  P.first->second = LUIdx;
+  Uses.push_back(LSRUse(Kind, AccessTy));
+  LSRUse &LU = Uses[LUIdx];
+
+  LU.MinOffset = Offset;
+  LU.MaxOffset = Offset;
+  return std::make_pair(LUIdx, Offset);
+}
+
+/// Delete the given use from the Uses list.
+void LSRInstance::DeleteUse(LSRUse &LU, size_t LUIdx) {
+  if (&LU != &Uses.back())
+    std::swap(LU, Uses.back());
+  Uses.pop_back();
+
+  // Update RegUses.
+  RegUses.swapAndDropUse(LUIdx, Uses.size());
+}
+
+/// Look for a use distinct from OrigLU which has a formula with the same
+/// registers as the given formula.
+LSRUse *
+LSRInstance::FindUseWithSimilarFormula(const Formula &OrigF,
+                                       const LSRUse &OrigLU) {
+  // Search all uses for the formula. This could be more clever.
+  for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) {
+    LSRUse &LU = Uses[LUIdx];
+    // Check whether this use is close enough to OrigLU, to see whether it's
+    // worthwhile looking through its formulae.
+    // Ignore ICmpZero uses because they may contain formulae generated by
+    // GenerateICmpZeroScales, in which case adding fixup offsets may
+    // be invalid.
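+    // (GenerateICmpZeroScales may multiply an ICmpZero formula through by a
+    // constant factor, which is only valid because the comparison is against
+    // zero; an unscaled fixup offset added to such a formula would change the
+    // compared value.)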
+ if (&LU != &OrigLU && + LU.Kind != LSRUse::ICmpZero && + LU.Kind == OrigLU.Kind && OrigLU.AccessTy == LU.AccessTy && + LU.WidestFixupType == OrigLU.WidestFixupType && + LU.HasFormulaWithSameRegs(OrigF)) { + // Scan through this use's formulae. + for (const Formula &F : LU.Formulae) { + // Check to see if this formula has the same registers and symbols + // as OrigF. + if (F.BaseRegs == OrigF.BaseRegs && + F.ScaledReg == OrigF.ScaledReg && + F.BaseGV == OrigF.BaseGV && + F.Scale == OrigF.Scale && + F.UnfoldedOffset == OrigF.UnfoldedOffset) { + if (F.BaseOffset == 0) + return &LU; + // This is the formula where all the registers and symbols matched; + // there aren't going to be any others. Since we declined it, we + // can skip the rest of the formulae and proceed to the next LSRUse. + break; + } + } + } + } + + // Nothing looked good. + return nullptr; +} + +void LSRInstance::CollectInterestingTypesAndFactors() { + SmallSetVector<const SCEV *, 4> Strides; + + // Collect interesting types and strides. + SmallVector<const SCEV *, 4> Worklist; + for (const IVStrideUse &U : IU) { + const SCEV *Expr = IU.getExpr(U); + + // Collect interesting types. + Types.insert(SE.getEffectiveSCEVType(Expr->getType())); + + // Add strides for mentioned loops. + Worklist.push_back(Expr); + do { + const SCEV *S = Worklist.pop_back_val(); + if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { + if (AR->getLoop() == L) + Strides.insert(AR->getStepRecurrence(SE)); + Worklist.push_back(AR->getStart()); + } else if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { + Worklist.append(Add->op_begin(), Add->op_end()); + } + } while (!Worklist.empty()); + } + + // Compute interesting factors from the set of interesting strides. + for (SmallSetVector<const SCEV *, 4>::const_iterator + I = Strides.begin(), E = Strides.end(); I != E; ++I) + for (SmallSetVector<const SCEV *, 4>::const_iterator NewStrideIter = + std::next(I); NewStrideIter != E; ++NewStrideIter) { + const SCEV *OldStride = *I; + const SCEV *NewStride = *NewStrideIter; + + if (SE.getTypeSizeInBits(OldStride->getType()) != + SE.getTypeSizeInBits(NewStride->getType())) { + if (SE.getTypeSizeInBits(OldStride->getType()) > + SE.getTypeSizeInBits(NewStride->getType())) + NewStride = SE.getSignExtendExpr(NewStride, OldStride->getType()); + else + OldStride = SE.getSignExtendExpr(OldStride, NewStride->getType()); + } + if (const SCEVConstant *Factor = + dyn_cast_or_null<SCEVConstant>(getExactSDiv(NewStride, OldStride, + SE, true))) { + if (Factor->getAPInt().getMinSignedBits() <= 64) + Factors.insert(Factor->getAPInt().getSExtValue()); + } else if (const SCEVConstant *Factor = + dyn_cast_or_null<SCEVConstant>(getExactSDiv(OldStride, + NewStride, + SE, true))) { + if (Factor->getAPInt().getMinSignedBits() <= 64) + Factors.insert(Factor->getAPInt().getSExtValue()); + } + } + + // If all uses use the same type, don't bother looking for truncation-based + // reuse. + if (Types.size() == 1) + Types.clear(); + + DEBUG(print_factors_and_types(dbgs())); +} + +/// Helper for CollectChains that finds an IV operand (computed by an AddRec in +/// this loop) within [OI,OE) or returns OE. If IVUsers mapped Instructions to +/// IVStrideUses, we could partially skip this. 
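+/// For example (illustrative): given a user "store i32 %v, i32* %p" where %p
+/// is an AddRec in L, the scan below stops at the %p operand.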
+static User::op_iterator
+findIVOperand(User::op_iterator OI, User::op_iterator OE,
+              Loop *L, ScalarEvolution &SE) {
+  for (; OI != OE; ++OI) {
+    if (Instruction *Oper = dyn_cast<Instruction>(*OI)) {
+      if (!SE.isSCEVable(Oper->getType()))
+        continue;
+
+      if (const SCEVAddRecExpr *AR =
+          dyn_cast<SCEVAddRecExpr>(SE.getSCEV(Oper))) {
+        if (AR->getLoop() == L)
+          break;
+      }
+    }
+  }
+  return OI;
+}
+
+/// IVChain logic must consistently look through TruncInst operands to the
+/// base value, so wrap that in a convenient helper.
+static Value *getWideOperand(Value *Oper) {
+  if (TruncInst *Trunc = dyn_cast<TruncInst>(Oper))
+    return Trunc->getOperand(0);
+  return Oper;
+}
+
+/// Return true if we allow an IV chain to include both types.
+static bool isCompatibleIVType(Value *LVal, Value *RVal) {
+  Type *LType = LVal->getType();
+  Type *RType = RVal->getType();
+  return (LType == RType) || (LType->isPointerTy() && RType->isPointerTy() &&
+                              // Different address spaces mean (possibly)
+                              // different pointer representations, e.g.
+                              // i16 vs i32, so disallow that.
+                              (LType->getPointerAddressSpace() ==
+                               RType->getPointerAddressSpace()));
+}
+
+/// Return an approximation of this SCEV expression's "base", or NULL for any
+/// constant. Returning the expression itself is conservative. Returning a
+/// deeper subexpression is more precise and valid as long as it isn't less
+/// complex than another subexpression. For expressions involving multiple
+/// unscaled values, we need to return the pointer-type SCEVUnknown. This
+/// avoids forming chains across objects, such as: PrevOper==a[i],
+/// IVOper==b[i], IVInc==b-a.
+///
+/// Since SCEVUnknown is the rightmost type, and pointers are the rightmost
+/// SCEVUnknown, we simply return the rightmost SCEV operand.
+static const SCEV *getExprBase(const SCEV *S) {
+  switch (S->getSCEVType()) {
+  default: // including scUnknown.
+    return S;
+  case scConstant:
+    return nullptr;
+  case scTruncate:
+    return getExprBase(cast<SCEVTruncateExpr>(S)->getOperand());
+  case scZeroExtend:
+    return getExprBase(cast<SCEVZeroExtendExpr>(S)->getOperand());
+  case scSignExtend:
+    return getExprBase(cast<SCEVSignExtendExpr>(S)->getOperand());
+  case scAddExpr: {
+    // Skip over scaled operands (scMulExpr) to follow add operands as long as
+    // there's nothing more complex.
+    // FIXME: not sure if we want to recognize negation.
+    const SCEVAddExpr *Add = cast<SCEVAddExpr>(S);
+    for (std::reverse_iterator<SCEVAddExpr::op_iterator> I(Add->op_end()),
+           E(Add->op_begin()); I != E; ++I) {
+      const SCEV *SubExpr = *I;
+      if (SubExpr->getSCEVType() == scAddExpr)
+        return getExprBase(SubExpr);
+
+      if (SubExpr->getSCEVType() != scMulExpr)
+        return SubExpr;
+    }
+    return S; // all operands are scaled, be conservative.
+  }
+  case scAddRecExpr:
+    return getExprBase(cast<SCEVAddRecExpr>(S)->getStart());
+  }
+}
+
+/// Return true if the chain increment is profitable to expand into a
+/// loop-invariant value, which may require its own register. A profitable
+/// chain increment will be an offset relative to the same base. We allow such
+/// offsets to potentially be used as a chain increment as long as it is not
+/// obviously expensive to expand using real instructions.
+bool IVChain::isProfitableIncrement(const SCEV *OperExpr,
+                                    const SCEV *IncExpr,
+                                    ScalarEvolution &SE) {
+  // Aggressively form chains when -stress-ivchain.
+  if (StressIVChain)
+    return true;
+
+  // Do not replace a constant offset from IV head with a nonconstant IV
+  // increment.
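+  // For illustration: if the chain head computes %base and this operand
+  // computes %base + 4, the constant +4 likely folds into an addressing mode
+  // for free, so chaining through a variable increment would only cost a
+  // register.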
+  if (!isa<SCEVConstant>(IncExpr)) {
+    const SCEV *HeadExpr = SE.getSCEV(getWideOperand(Incs[0].IVOperand));
+    if (isa<SCEVConstant>(SE.getMinusSCEV(OperExpr, HeadExpr)))
+      return false;
+  }
+
+  SmallPtrSet<const SCEV*, 8> Processed;
+  return !isHighCostExpansion(IncExpr, Processed, SE);
+}
+
+/// Return true if the number of registers needed for the chain is estimated
+/// to be less than the number required for the individual IV users. First
+/// prohibit any IV users that keep the IV live across increments (the Users
+/// set should be empty). Next count the number and type of increments in the
+/// chain.
+///
+/// Chaining IVs can lead to considerable code bloat if ISEL doesn't
+/// effectively use postinc addressing modes. Only consider it profitable if
+/// the increments can be computed in fewer registers when chained.
+///
+/// TODO: Consider IVInc free if it's already used in other chains.
+static bool
+isProfitableChain(IVChain &Chain, SmallPtrSetImpl<Instruction*> &Users,
+                  ScalarEvolution &SE, const TargetTransformInfo &TTI) {
+  if (StressIVChain)
+    return true;
+
+  if (!Chain.hasIncs())
+    return false;
+
+  if (!Users.empty()) {
+    DEBUG(dbgs() << "Chain: " << *Chain.Incs[0].UserInst << " users:\n";
+          for (Instruction *Inst : Users) {
+            dbgs() << " " << *Inst << "\n";
+          });
+    return false;
+  }
+  assert(!Chain.Incs.empty() && "empty IV chains are not allowed");
+
+  // The chain itself may require a register, so initialize cost to 1.
+  int cost = 1;
+
+  // A complete chain likely eliminates the need for keeping the original IV
+  // in a register. LSR does not currently know how to form a complete chain
+  // unless the header phi already exists.
+  if (isa<PHINode>(Chain.tailUserInst())
+      && SE.getSCEV(Chain.tailUserInst()) == Chain.Incs[0].IncExpr) {
+    --cost;
+  }
+  const SCEV *LastIncExpr = nullptr;
+  unsigned NumConstIncrements = 0;
+  unsigned NumVarIncrements = 0;
+  unsigned NumReusedIncrements = 0;
+  for (const IVInc &Inc : Chain) {
+    if (Inc.IncExpr->isZero())
+      continue;
+
+    // Incrementing by zero or some constant is neutral. We assume constants
+    // can be folded into an addressing mode or an add's immediate operand.
+    if (isa<SCEVConstant>(Inc.IncExpr)) {
+      ++NumConstIncrements;
+      continue;
+    }
+
+    if (Inc.IncExpr == LastIncExpr)
+      ++NumReusedIncrements;
+    else
+      ++NumVarIncrements;
+
+    LastIncExpr = Inc.IncExpr;
+  }
+  // An IV chain with a single increment is handled by LSR's postinc
+  // uses. However, a chain with multiple increments requires keeping the IV's
+  // value live longer than it needs to be if chained.
+  if (NumConstIncrements > 1)
+    --cost;
+
+  // Materializing increment expressions in the preheader that didn't exist in
+  // the original code may cost a register. For example, sign-extended array
+  // indices can produce ridiculous increments like this:
+  //   IV + ((sext i32 (2 * %s) to i64) + (-1 * (sext i32 %s to i64)))
+  cost += NumVarIncrements;
+
+  // Reusing variable increments likely saves a register to hold the multiple
+  // of the stride.
+  cost -= NumReusedIncrements;
+
+  DEBUG(dbgs() << "Chain: " << *Chain.Incs[0].UserInst << " Cost: " << cost
+               << "\n");
+
+  return cost < 0;
+}
+
+/// Add this IV user to an existing chain or make it the head of a new chain.
+void LSRInstance::ChainInstruction(Instruction *UserInst, Instruction *IVOper,
+                                   SmallVectorImpl<ChainUsers> &ChainUsersVec) {
+  // When IVs are used as types of varying widths, they are generally converted
+  // to a wider type with some uses remaining narrow under a (free) trunc.
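+  // E.g. an i32 IV widened to i64: a narrow user sees "trunc i64 %iv to i32",
+  // and getWideOperand() looks through the trunc so that uses of both widths
+  // can chain on the same i64 value.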
+  Value *const NextIV = getWideOperand(IVOper);
+  const SCEV *const OperExpr = SE.getSCEV(NextIV);
+  const SCEV *const OperExprBase = getExprBase(OperExpr);
+
+  // Visit all existing chains to check whether this user's IVOper can be
+  // computed as a profitable loop-invariant increment from the last link in
+  // the chain.
+  unsigned ChainIdx = 0, NChains = IVChainVec.size();
+  const SCEV *LastIncExpr = nullptr;
+  for (; ChainIdx < NChains; ++ChainIdx) {
+    IVChain &Chain = IVChainVec[ChainIdx];
+
+    // Prune the solution space aggressively by checking that both IV operands
+    // are expressions that operate on the same unscaled SCEVUnknown. This
+    // "base" will be canceled by the subsequent getMinusSCEV call. Checking
+    // first avoids creating extra SCEV expressions.
+    if (!StressIVChain && Chain.ExprBase != OperExprBase)
+      continue;
+
+    Value *PrevIV = getWideOperand(Chain.Incs.back().IVOperand);
+    if (!isCompatibleIVType(PrevIV, NextIV))
+      continue;
+
+    // A phi node terminates a chain.
+    if (isa<PHINode>(UserInst) && isa<PHINode>(Chain.tailUserInst()))
+      continue;
+
+    // The increment must be loop-invariant so it can be kept in a register.
+    const SCEV *PrevExpr = SE.getSCEV(PrevIV);
+    const SCEV *IncExpr = SE.getMinusSCEV(OperExpr, PrevExpr);
+    if (!SE.isLoopInvariant(IncExpr, L))
+      continue;
+
+    if (Chain.isProfitableIncrement(OperExpr, IncExpr, SE)) {
+      LastIncExpr = IncExpr;
+      break;
+    }
+  }
+  // If we haven't found a chain, create a new one, unless we hit the max.
+  // Don't bother for phi nodes, because they must be last in the chain.
+  if (ChainIdx == NChains) {
+    if (isa<PHINode>(UserInst))
+      return;
+    if (NChains >= MaxChains && !StressIVChain) {
+      DEBUG(dbgs() << "IV Chain Limit\n");
+      return;
+    }
+    LastIncExpr = OperExpr;
+    // IVUsers may have skipped over sign/zero extensions. We don't currently
+    // attempt to form chains involving extensions unless they can be hoisted
+    // into this loop's AddRec.
+    if (!isa<SCEVAddRecExpr>(LastIncExpr))
+      return;
+    ++NChains;
+    IVChainVec.push_back(IVChain(IVInc(UserInst, IVOper, LastIncExpr),
+                                 OperExprBase));
+    ChainUsersVec.resize(NChains);
+    DEBUG(dbgs() << "IV Chain#" << ChainIdx << " Head: (" << *UserInst
+                 << ") IV=" << *LastIncExpr << "\n");
+  } else {
+    DEBUG(dbgs() << "IV Chain#" << ChainIdx << " Inc: (" << *UserInst
+                 << ") IV+" << *LastIncExpr << "\n");
+    // Add this IV user to the end of the chain.
+    IVChainVec[ChainIdx].add(IVInc(UserInst, IVOper, LastIncExpr));
+  }
+  IVChain &Chain = IVChainVec[ChainIdx];
+
+  SmallPtrSet<Instruction*,4> &NearUsers = ChainUsersVec[ChainIdx].NearUsers;
+  // This chain's NearUsers become FarUsers.
+  if (!LastIncExpr->isZero()) {
+    ChainUsersVec[ChainIdx].FarUsers.insert(NearUsers.begin(),
+                                            NearUsers.end());
+    NearUsers.clear();
+  }
+
+  // All other uses of IVOperand become near uses of the chain.
+  // We currently ignore intermediate values within SCEV expressions, assuming
+  // they will eventually be used by the current chain, or can be computed
+  // from one of the chain increments. To be more precise we could
+  // transitively follow its users and only add leaf IV users to the set.
+  for (User *U : IVOper->users()) {
+    Instruction *OtherUse = dyn_cast<Instruction>(U);
+    if (!OtherUse)
+      continue;
+    // Uses in the chain will no longer be uses if the chain is formed.
+    // Include the head of the chain in this iteration (not Chain.begin()).
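+    // (Chain.begin() deliberately skips the head at Incs[0], so walk Incs
+    // directly here.)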
+ IVChain::const_iterator IncIter = Chain.Incs.begin(); + IVChain::const_iterator IncEnd = Chain.Incs.end(); + for( ; IncIter != IncEnd; ++IncIter) { + if (IncIter->UserInst == OtherUse) + break; + } + if (IncIter != IncEnd) + continue; + + if (SE.isSCEVable(OtherUse->getType()) + && !isa<SCEVUnknown>(SE.getSCEV(OtherUse)) + && IU.isIVUserOrOperand(OtherUse)) { + continue; + } + NearUsers.insert(OtherUse); + } + + // Since this user is part of the chain, it's no longer considered a use + // of the chain. + ChainUsersVec[ChainIdx].FarUsers.erase(UserInst); +} + +/// Populate the vector of Chains. +/// +/// This decreases ILP at the architecture level. Targets with ample registers, +/// multiple memory ports, and no register renaming probably don't want +/// this. However, such targets should probably disable LSR altogether. +/// +/// The job of LSR is to make a reasonable choice of induction variables across +/// the loop. Subsequent passes can easily "unchain" computation exposing more +/// ILP *within the loop* if the target wants it. +/// +/// Finding the best IV chain is potentially a scheduling problem. Since LSR +/// will not reorder memory operations, it will recognize this as a chain, but +/// will generate redundant IV increments. Ideally this would be corrected later +/// by a smart scheduler: +/// = A[i] +/// = A[i+x] +/// A[i] = +/// A[i+x] = +/// +/// TODO: Walk the entire domtree within this loop, not just the path to the +/// loop latch. This will discover chains on side paths, but requires +/// maintaining multiple copies of the Chains state. +void LSRInstance::CollectChains() { + DEBUG(dbgs() << "Collecting IV Chains.\n"); + SmallVector<ChainUsers, 8> ChainUsersVec; + + SmallVector<BasicBlock *,8> LatchPath; + BasicBlock *LoopHeader = L->getHeader(); + for (DomTreeNode *Rung = DT.getNode(L->getLoopLatch()); + Rung->getBlock() != LoopHeader; Rung = Rung->getIDom()) { + LatchPath.push_back(Rung->getBlock()); + } + LatchPath.push_back(LoopHeader); + + // Walk the instruction stream from the loop header to the loop latch. + for (BasicBlock *BB : reverse(LatchPath)) { + for (Instruction &I : *BB) { + // Skip instructions that weren't seen by IVUsers analysis. + if (isa<PHINode>(I) || !IU.isIVUserOrOperand(&I)) + continue; + + // Ignore users that are part of a SCEV expression. This way we only + // consider leaf IV Users. This effectively rediscovers a portion of + // IVUsers analysis but in program order this time. + if (SE.isSCEVable(I.getType()) && !isa<SCEVUnknown>(SE.getSCEV(&I))) + continue; + + // Remove this instruction from any NearUsers set it may be in. + for (unsigned ChainIdx = 0, NChains = IVChainVec.size(); + ChainIdx < NChains; ++ChainIdx) { + ChainUsersVec[ChainIdx].NearUsers.erase(&I); + } + // Search for operands that can be chained. + SmallPtrSet<Instruction*, 4> UniqueOperands; + User::op_iterator IVOpEnd = I.op_end(); + User::op_iterator IVOpIter = findIVOperand(I.op_begin(), IVOpEnd, L, SE); + while (IVOpIter != IVOpEnd) { + Instruction *IVOpInst = cast<Instruction>(*IVOpIter); + if (UniqueOperands.insert(IVOpInst).second) + ChainInstruction(&I, IVOpInst, ChainUsersVec); + IVOpIter = findIVOperand(std::next(IVOpIter), IVOpEnd, L, SE); + } + } // Continue walking down the instructions. + } // Continue walking down the domtree. + // Visit phi backedges to determine if the chain can generate the IV postinc. 
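+  // For illustration: if the value feeding a header phi from the latch is the
+  // chain's last increment (e.g. "%iv.next = add i64 %iv, 4"), chaining the
+  // phi lets that final add double as the post-increment.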
+  for (PHINode &PN : L->getHeader()->phis()) {
+    if (!SE.isSCEVable(PN.getType()))
+      continue;
+
+    Instruction *IncV =
+      dyn_cast<Instruction>(PN.getIncomingValueForBlock(L->getLoopLatch()));
+    if (IncV)
+      ChainInstruction(&PN, IncV, ChainUsersVec);
+  }
+  // Remove any unprofitable chains.
+  unsigned ChainIdx = 0;
+  for (unsigned UsersIdx = 0, NChains = IVChainVec.size();
+       UsersIdx < NChains; ++UsersIdx) {
+    if (!isProfitableChain(IVChainVec[UsersIdx],
+                           ChainUsersVec[UsersIdx].FarUsers, SE, TTI))
+      continue;
+    // Preserve the chain at UsersIdx.
+    if (ChainIdx != UsersIdx)
+      IVChainVec[ChainIdx] = IVChainVec[UsersIdx];
+    FinalizeChain(IVChainVec[ChainIdx]);
+    ++ChainIdx;
+  }
+  IVChainVec.resize(ChainIdx);
+}
+
+void LSRInstance::FinalizeChain(IVChain &Chain) {
+  assert(!Chain.Incs.empty() && "empty IV chains are not allowed");
+  DEBUG(dbgs() << "Final Chain: " << *Chain.Incs[0].UserInst << "\n");
+
+  for (const IVInc &Inc : Chain) {
+    DEBUG(dbgs() << " Inc: " << *Inc.UserInst << "\n");
+    auto UseI = find(Inc.UserInst->operands(), Inc.IVOperand);
+    assert(UseI != Inc.UserInst->op_end() && "cannot find IV operand");
+    IVIncSet.insert(UseI);
+  }
+}
+
+/// Return true if the IVInc can be folded into an addressing mode.
+static bool canFoldIVIncExpr(const SCEV *IncExpr, Instruction *UserInst,
+                             Value *Operand, const TargetTransformInfo &TTI) {
+  const SCEVConstant *IncConst = dyn_cast<SCEVConstant>(IncExpr);
+  if (!IncConst || !isAddressUse(TTI, UserInst, Operand))
+    return false;
+
+  if (IncConst->getAPInt().getMinSignedBits() > 64)
+    return false;
+
+  MemAccessTy AccessTy = getAccessType(TTI, UserInst);
+  int64_t IncOffset = IncConst->getValue()->getSExtValue();
+  if (!isAlwaysFoldable(TTI, LSRUse::Address, AccessTy, /*BaseGV=*/nullptr,
+                        IncOffset, /*HasBaseReg=*/false))
+    return false;
+
+  return true;
+}
+
+/// Generate an add or subtract for each IVInc in a chain to materialize the
+/// IV user's operand from the previous IV user's operand.
+void LSRInstance::GenerateIVChain(const IVChain &Chain, SCEVExpander &Rewriter,
+                                  SmallVectorImpl<WeakTrackingVH> &DeadInsts) {
+  // Find the new IVOperand for the head of the chain. It may have been
+  // replaced by LSR.
+  const IVInc &Head = Chain.Incs[0];
+  User::op_iterator IVOpEnd = Head.UserInst->op_end();
+  // findIVOperand returns IVOpEnd if it can no longer find a valid IV user.
+  User::op_iterator IVOpIter = findIVOperand(Head.UserInst->op_begin(),
+                                             IVOpEnd, L, SE);
+  Value *IVSrc = nullptr;
+  while (IVOpIter != IVOpEnd) {
+    IVSrc = getWideOperand(*IVOpIter);
+
+    // If this operand computes the expression that the chain needs, we may
+    // use it. (Check this after setting IVSrc which is used below.)
+    //
+    // Note that if Head.IncExpr is wider than IVSrc, then this phi is too
+    // narrow for the chain, so we can no longer use it. We do allow using a
+    // wider phi, assuming LSR checked for free truncation. In that case we
+    // should already have a truncate on this operand such that
+    // getSCEV(IVSrc) == IncExpr.
+    if (SE.getSCEV(*IVOpIter) == Head.IncExpr
+        || SE.getSCEV(IVSrc) == Head.IncExpr) {
+      break;
+    }
+    IVOpIter = findIVOperand(std::next(IVOpIter), IVOpEnd, L, SE);
+  }
+  if (IVOpIter == IVOpEnd) {
+    // Gracefully give up on this chain.
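+    // (LSR may have rewritten the head's operands so that none of them still
+    // computes the expression this chain was built around.)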
+ DEBUG(dbgs() << "Concealed chain head: " << *Head.UserInst << "\n"); + return; + } + + DEBUG(dbgs() << "Generate chain at: " << *IVSrc << "\n"); + Type *IVTy = IVSrc->getType(); + Type *IntTy = SE.getEffectiveSCEVType(IVTy); + const SCEV *LeftOverExpr = nullptr; + for (const IVInc &Inc : Chain) { + Instruction *InsertPt = Inc.UserInst; + if (isa<PHINode>(InsertPt)) + InsertPt = L->getLoopLatch()->getTerminator(); + + // IVOper will replace the current IV User's operand. IVSrc is the IV + // value currently held in a register. + Value *IVOper = IVSrc; + if (!Inc.IncExpr->isZero()) { + // IncExpr was the result of subtraction of two narrow values, so must + // be signed. + const SCEV *IncExpr = SE.getNoopOrSignExtend(Inc.IncExpr, IntTy); + LeftOverExpr = LeftOverExpr ? + SE.getAddExpr(LeftOverExpr, IncExpr) : IncExpr; + } + if (LeftOverExpr && !LeftOverExpr->isZero()) { + // Expand the IV increment. + Rewriter.clearPostInc(); + Value *IncV = Rewriter.expandCodeFor(LeftOverExpr, IntTy, InsertPt); + const SCEV *IVOperExpr = SE.getAddExpr(SE.getUnknown(IVSrc), + SE.getUnknown(IncV)); + IVOper = Rewriter.expandCodeFor(IVOperExpr, IVTy, InsertPt); + + // If an IV increment can't be folded, use it as the next IV value. + if (!canFoldIVIncExpr(LeftOverExpr, Inc.UserInst, Inc.IVOperand, TTI)) { + assert(IVTy == IVOper->getType() && "inconsistent IV increment type"); + IVSrc = IVOper; + LeftOverExpr = nullptr; + } + } + Type *OperTy = Inc.IVOperand->getType(); + if (IVTy != OperTy) { + assert(SE.getTypeSizeInBits(IVTy) >= SE.getTypeSizeInBits(OperTy) && + "cannot extend a chained IV"); + IRBuilder<> Builder(InsertPt); + IVOper = Builder.CreateTruncOrBitCast(IVOper, OperTy, "lsr.chain"); + } + Inc.UserInst->replaceUsesOfWith(Inc.IVOperand, IVOper); + DeadInsts.emplace_back(Inc.IVOperand); + } + // If LSR created a new, wider phi, we may also replace its postinc. We only + // do this if we also found a wide value for the head of the chain. + if (isa<PHINode>(Chain.tailUserInst())) { + for (PHINode &Phi : L->getHeader()->phis()) { + if (!isCompatibleIVType(&Phi, IVSrc)) + continue; + Instruction *PostIncV = dyn_cast<Instruction>( + Phi.getIncomingValueForBlock(L->getLoopLatch())); + if (!PostIncV || (SE.getSCEV(PostIncV) != SE.getSCEV(IVSrc))) + continue; + Value *IVOper = IVSrc; + Type *PostIncTy = PostIncV->getType(); + if (IVTy != PostIncTy) { + assert(PostIncTy->isPointerTy() && "mixing int/ptr IV types"); + IRBuilder<> Builder(L->getLoopLatch()->getTerminator()); + Builder.SetCurrentDebugLocation(PostIncV->getDebugLoc()); + IVOper = Builder.CreatePointerCast(IVSrc, PostIncTy, "lsr.chain"); + } + Phi.replaceUsesOfWith(PostIncV, IVOper); + DeadInsts.emplace_back(PostIncV); + } + } +} + +void LSRInstance::CollectFixupsAndInitialFormulae() { + for (const IVStrideUse &U : IU) { + Instruction *UserInst = U.getUser(); + // Skip IV users that are part of profitable IV Chains. + User::op_iterator UseI = + find(UserInst->operands(), U.getOperandValToReplace()); + assert(UseI != UserInst->op_end() && "cannot find IV operand"); + if (IVIncSet.count(UseI)) { + DEBUG(dbgs() << "Use is in profitable chain: " << **UseI << '\n'); + continue; + } + + LSRUse::KindType Kind = LSRUse::Basic; + MemAccessTy AccessTy; + if (isAddressUse(TTI, UserInst, U.getOperandValToReplace())) { + Kind = LSRUse::Address; + AccessTy = getAccessType(TTI, UserInst); + } + + const SCEV *S = IU.getExpr(U); + PostIncLoopSet TmpPostIncLoops = U.getPostIncLoops(); + + // Equality (== and !=) ICmps are special. 
We can rewrite (i == N) as + // (N - i == 0), and this allows (N - i) to be the expression that we work + // with rather than just N or i, so we can consider the register + // requirements for both N and i at the same time. Limiting this code to + // equality icmps is not a problem because all interesting loops use + // equality icmps, thanks to IndVarSimplify. + if (ICmpInst *CI = dyn_cast<ICmpInst>(UserInst)) + if (CI->isEquality()) { + // Swap the operands if needed to put the OperandValToReplace on the + // left, for consistency. + Value *NV = CI->getOperand(1); + if (NV == U.getOperandValToReplace()) { + CI->setOperand(1, CI->getOperand(0)); + CI->setOperand(0, NV); + NV = CI->getOperand(1); + Changed = true; + } + + // x == y --> x - y == 0 + const SCEV *N = SE.getSCEV(NV); + if (SE.isLoopInvariant(N, L) && isSafeToExpand(N, SE)) { + // S is normalized, so normalize N before folding it into S + // to keep the result normalized. + N = normalizeForPostIncUse(N, TmpPostIncLoops, SE); + Kind = LSRUse::ICmpZero; + S = SE.getMinusSCEV(N, S); + } + + // -1 and the negations of all interesting strides (except the negation + // of -1) are now also interesting. + for (size_t i = 0, e = Factors.size(); i != e; ++i) + if (Factors[i] != -1) + Factors.insert(-(uint64_t)Factors[i]); + Factors.insert(-1); + } + + // Get or create an LSRUse. + std::pair<size_t, int64_t> P = getUse(S, Kind, AccessTy); + size_t LUIdx = P.first; + int64_t Offset = P.second; + LSRUse &LU = Uses[LUIdx]; + + // Record the fixup. + LSRFixup &LF = LU.getNewFixup(); + LF.UserInst = UserInst; + LF.OperandValToReplace = U.getOperandValToReplace(); + LF.PostIncLoops = TmpPostIncLoops; + LF.Offset = Offset; + LU.AllFixupsOutsideLoop &= LF.isUseFullyOutsideLoop(L); + + if (!LU.WidestFixupType || + SE.getTypeSizeInBits(LU.WidestFixupType) < + SE.getTypeSizeInBits(LF.OperandValToReplace->getType())) + LU.WidestFixupType = LF.OperandValToReplace->getType(); + + // If this is the first use of this LSRUse, give it a formula. + if (LU.Formulae.empty()) { + InsertInitialFormula(S, LU, LUIdx); + CountRegisters(LU.Formulae.back(), LUIdx); + } + } + + DEBUG(print_fixups(dbgs())); +} + +/// Insert a formula for the given expression into the given use, separating out +/// loop-variant portions from loop-invariant and loop-computable portions. +void +LSRInstance::InsertInitialFormula(const SCEV *S, LSRUse &LU, size_t LUIdx) { + // Mark uses whose expressions cannot be expanded. + if (!isSafeToExpand(S, SE)) + LU.RigidFormula = true; + + Formula F; + F.initialMatch(S, L, SE); + bool Inserted = InsertFormula(LU, LUIdx, F); + assert(Inserted && "Initial formula already exists!"); (void)Inserted; +} + +/// Insert a simple single-register formula for the given expression into the +/// given use. +void +LSRInstance::InsertSupplementalFormula(const SCEV *S, + LSRUse &LU, size_t LUIdx) { + Formula F; + F.BaseRegs.push_back(S); + F.HasBaseReg = true; + bool Inserted = InsertFormula(LU, LUIdx, F); + assert(Inserted && "Supplemental formula already exists!"); (void)Inserted; +} + +/// Note which registers are used by the given formula, updating RegUses. +void LSRInstance::CountRegisters(const Formula &F, size_t LUIdx) { + if (F.ScaledReg) + RegUses.countRegister(F.ScaledReg, LUIdx); + for (const SCEV *BaseReg : F.BaseRegs) + RegUses.countRegister(BaseReg, LUIdx); +} + +/// If the given formula has not yet been inserted, add it to the list, and +/// return true. Return false otherwise. 
+bool LSRInstance::InsertFormula(LSRUse &LU, unsigned LUIdx, const Formula &F) { + // Do not insert formula that we will not be able to expand. + assert(isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F) && + "Formula is illegal"); + + if (!LU.InsertFormula(F, *L)) + return false; + + CountRegisters(F, LUIdx); + return true; +} + +/// Check for other uses of loop-invariant values which we're tracking. These +/// other uses will pin these values in registers, making them less profitable +/// for elimination. +/// TODO: This currently misses non-constant addrec step registers. +/// TODO: Should this give more weight to users inside the loop? +void +LSRInstance::CollectLoopInvariantFixupsAndFormulae() { + SmallVector<const SCEV *, 8> Worklist(RegUses.begin(), RegUses.end()); + SmallPtrSet<const SCEV *, 32> Visited; + + while (!Worklist.empty()) { + const SCEV *S = Worklist.pop_back_val(); + + // Don't process the same SCEV twice + if (!Visited.insert(S).second) + continue; + + if (const SCEVNAryExpr *N = dyn_cast<SCEVNAryExpr>(S)) + Worklist.append(N->op_begin(), N->op_end()); + else if (const SCEVCastExpr *C = dyn_cast<SCEVCastExpr>(S)) + Worklist.push_back(C->getOperand()); + else if (const SCEVUDivExpr *D = dyn_cast<SCEVUDivExpr>(S)) { + Worklist.push_back(D->getLHS()); + Worklist.push_back(D->getRHS()); + } else if (const SCEVUnknown *US = dyn_cast<SCEVUnknown>(S)) { + const Value *V = US->getValue(); + if (const Instruction *Inst = dyn_cast<Instruction>(V)) { + // Look for instructions defined outside the loop. + if (L->contains(Inst)) continue; + } else if (isa<UndefValue>(V)) + // Undef doesn't have a live range, so it doesn't matter. + continue; + for (const Use &U : V->uses()) { + const Instruction *UserInst = dyn_cast<Instruction>(U.getUser()); + // Ignore non-instructions. + if (!UserInst) + continue; + // Ignore instructions in other functions (as can happen with + // Constants). + if (UserInst->getParent()->getParent() != L->getHeader()->getParent()) + continue; + // Ignore instructions not dominated by the loop. + const BasicBlock *UseBB = !isa<PHINode>(UserInst) ? + UserInst->getParent() : + cast<PHINode>(UserInst)->getIncomingBlock( + PHINode::getIncomingValueNumForOperand(U.getOperandNo())); + if (!DT.dominates(L->getHeader(), UseBB)) + continue; + // Don't bother if the instruction is in a BB which ends in an EHPad. + if (UseBB->getTerminator()->isEHPad()) + continue; + // Don't bother rewriting PHIs in catchswitch blocks. + if (isa<CatchSwitchInst>(UserInst->getParent()->getTerminator())) + continue; + // Ignore uses which are part of other SCEV expressions, to avoid + // analyzing them multiple times. + if (SE.isSCEVable(UserInst->getType())) { + const SCEV *UserS = SE.getSCEV(const_cast<Instruction *>(UserInst)); + // If the user is a no-op, look through to its uses. + if (!isa<SCEVUnknown>(UserS)) + continue; + if (UserS == US) { + Worklist.push_back( + SE.getUnknown(const_cast<Instruction *>(UserInst))); + continue; + } + } + // Ignore icmp instructions which are already being analyzed. 
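+        // For illustration (hypothetical values): if %n is loop-invariant and
+        // appears in (icmp eq i64 %iv.next, %n), the other operand %iv.next
+        // has a computable evolution in L, so this use of %n is already
+        // covered by the IV-user analysis and is skipped here.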
+ if (const ICmpInst *ICI = dyn_cast<ICmpInst>(UserInst)) { + unsigned OtherIdx = !U.getOperandNo(); + Value *OtherOp = const_cast<Value *>(ICI->getOperand(OtherIdx)); + if (SE.hasComputableLoopEvolution(SE.getSCEV(OtherOp), L)) + continue; + } + + std::pair<size_t, int64_t> P = getUse( + S, LSRUse::Basic, MemAccessTy()); + size_t LUIdx = P.first; + int64_t Offset = P.second; + LSRUse &LU = Uses[LUIdx]; + LSRFixup &LF = LU.getNewFixup(); + LF.UserInst = const_cast<Instruction *>(UserInst); + LF.OperandValToReplace = U; + LF.Offset = Offset; + LU.AllFixupsOutsideLoop &= LF.isUseFullyOutsideLoop(L); + if (!LU.WidestFixupType || + SE.getTypeSizeInBits(LU.WidestFixupType) < + SE.getTypeSizeInBits(LF.OperandValToReplace->getType())) + LU.WidestFixupType = LF.OperandValToReplace->getType(); + InsertSupplementalFormula(US, LU, LUIdx); + CountRegisters(LU.Formulae.back(), Uses.size() - 1); + break; + } + } + } +} + +/// Split S into subexpressions which can be pulled out into separate +/// registers. If C is non-null, multiply each subexpression by C. +/// +/// Return remainder expression after factoring the subexpressions captured by +/// Ops. If Ops is complete, return NULL. +static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C, + SmallVectorImpl<const SCEV *> &Ops, + const Loop *L, + ScalarEvolution &SE, + unsigned Depth = 0) { + // Arbitrarily cap recursion to protect compile time. + if (Depth >= 3) + return S; + + if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(S)) { + // Break out add operands. + for (const SCEV *S : Add->operands()) { + const SCEV *Remainder = CollectSubexprs(S, C, Ops, L, SE, Depth+1); + if (Remainder) + Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder); + } + return nullptr; + } else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) { + // Split a non-zero base out of an addrec. + if (AR->getStart()->isZero() || !AR->isAffine()) + return S; + + const SCEV *Remainder = CollectSubexprs(AR->getStart(), + C, Ops, L, SE, Depth+1); + // Split the non-zero AddRec unless it is part of a nested recurrence that + // does not pertain to this loop. + if (Remainder && (AR->getLoop() == L || !isa<SCEVAddRecExpr>(Remainder))) { + Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder); + Remainder = nullptr; + } + if (Remainder != AR->getStart()) { + if (!Remainder) + Remainder = SE.getConstant(AR->getType(), 0); + return SE.getAddRecExpr(Remainder, + AR->getStepRecurrence(SE), + AR->getLoop(), + //FIXME: AR->getNoWrapFlags(SCEV::FlagNW) + SCEV::FlagAnyWrap); + } + } else if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) { + // Break (C * (a + b + c)) into C*a + C*b + C*c. + if (Mul->getNumOperands() != 2) + return S; + if (const SCEVConstant *Op0 = + dyn_cast<SCEVConstant>(Mul->getOperand(0))) { + C = C ? cast<SCEVConstant>(SE.getMulExpr(C, Op0)) : Op0; + const SCEV *Remainder = + CollectSubexprs(Mul->getOperand(1), C, Ops, L, SE, Depth+1); + if (Remainder) + Ops.push_back(SE.getMulExpr(C, Remainder)); + return nullptr; + } + } + return S; +} + +/// \brief Helper function for LSRInstance::GenerateReassociations. +void LSRInstance::GenerateReassociationsImpl(LSRUse &LU, unsigned LUIdx, + const Formula &Base, + unsigned Depth, size_t Idx, + bool IsScaledReg) { + const SCEV *BaseReg = IsScaledReg ? 
Base.ScaledReg : Base.BaseRegs[Idx]; + SmallVector<const SCEV *, 8> AddOps; + const SCEV *Remainder = CollectSubexprs(BaseReg, nullptr, AddOps, L, SE); + if (Remainder) + AddOps.push_back(Remainder); + + if (AddOps.size() == 1) + return; + + for (SmallVectorImpl<const SCEV *>::const_iterator J = AddOps.begin(), + JE = AddOps.end(); + J != JE; ++J) { + // Loop-variant "unknown" values are uninteresting; we won't be able to + // do anything meaningful with them. + if (isa<SCEVUnknown>(*J) && !SE.isLoopInvariant(*J, L)) + continue; + + // Don't pull a constant into a register if the constant could be folded + // into an immediate field. + if (isAlwaysFoldable(TTI, SE, LU.MinOffset, LU.MaxOffset, LU.Kind, + LU.AccessTy, *J, Base.getNumRegs() > 1)) + continue; + + // Collect all operands except *J. + SmallVector<const SCEV *, 8> InnerAddOps( + ((const SmallVector<const SCEV *, 8> &)AddOps).begin(), J); + InnerAddOps.append(std::next(J), + ((const SmallVector<const SCEV *, 8> &)AddOps).end()); + + // Don't leave just a constant behind in a register if the constant could + // be folded into an immediate field. + if (InnerAddOps.size() == 1 && + isAlwaysFoldable(TTI, SE, LU.MinOffset, LU.MaxOffset, LU.Kind, + LU.AccessTy, InnerAddOps[0], Base.getNumRegs() > 1)) + continue; + + const SCEV *InnerSum = SE.getAddExpr(InnerAddOps); + if (InnerSum->isZero()) + continue; + Formula F = Base; + + // Add the remaining pieces of the add back into the new formula. + const SCEVConstant *InnerSumSC = dyn_cast<SCEVConstant>(InnerSum); + if (InnerSumSC && SE.getTypeSizeInBits(InnerSumSC->getType()) <= 64 && + TTI.isLegalAddImmediate((uint64_t)F.UnfoldedOffset + + InnerSumSC->getValue()->getZExtValue())) { + F.UnfoldedOffset = + (uint64_t)F.UnfoldedOffset + InnerSumSC->getValue()->getZExtValue(); + if (IsScaledReg) + F.ScaledReg = nullptr; + else + F.BaseRegs.erase(F.BaseRegs.begin() + Idx); + } else if (IsScaledReg) + F.ScaledReg = InnerSum; + else + F.BaseRegs[Idx] = InnerSum; + + // Add J as its own register, or an unfolded immediate. + const SCEVConstant *SC = dyn_cast<SCEVConstant>(*J); + if (SC && SE.getTypeSizeInBits(SC->getType()) <= 64 && + TTI.isLegalAddImmediate((uint64_t)F.UnfoldedOffset + + SC->getValue()->getZExtValue())) + F.UnfoldedOffset = + (uint64_t)F.UnfoldedOffset + SC->getValue()->getZExtValue(); + else + F.BaseRegs.push_back(*J); + // We may have changed the number of register in base regs, adjust the + // formula accordingly. + F.canonicalize(*L); + + if (InsertFormula(LU, LUIdx, F)) + // If that formula hadn't been seen before, recurse to find more like + // it. + GenerateReassociations(LU, LUIdx, LU.Formulae.back(), Depth + 1); + } +} + +/// Split out subexpressions from adds and the bases of addrecs. +void LSRInstance::GenerateReassociations(LSRUse &LU, unsigned LUIdx, + Formula Base, unsigned Depth) { + assert(Base.isCanonical(*L) && "Input must be in the canonical form"); + // Arbitrarily cap recursion to protect compile time. + if (Depth >= 3) + return; + + for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i) + GenerateReassociationsImpl(LU, LUIdx, Base, Depth, i); + + if (Base.Scale == 1) + GenerateReassociationsImpl(LU, LUIdx, Base, Depth, + /* Idx */ -1, /* IsScaledReg */ true); +} + +/// Generate a formula consisting of all of the loop-dominating registers added +/// into a single register. +void LSRInstance::GenerateCombinations(LSRUse &LU, unsigned LUIdx, + Formula Base) { + // This method is only interesting on a plurality of registers. 
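+  // For illustration, suppose a use has the formula
+  //   reg(a) + reg(b) + reg({0,+,1})
+  // where a and b are loop-invariant values dominating the loop header. The
+  // combined formula below would be reg(a + b) + reg({0,+,1}), which needs
+  // one register fewer for the loop-invariant part.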
+  if (Base.BaseRegs.size() + (Base.Scale == 1) <= 1)
+    return;
+
+  // Flatten the representation, i.e., reg1 + 1*reg2 => reg1 + reg2, before
+  // processing the formula.
+  Base.unscale();
+  Formula F = Base;
+  F.BaseRegs.clear();
+  SmallVector<const SCEV *, 4> Ops;
+  for (const SCEV *BaseReg : Base.BaseRegs) {
+    if (SE.properlyDominates(BaseReg, L->getHeader()) &&
+        !SE.hasComputableLoopEvolution(BaseReg, L))
+      Ops.push_back(BaseReg);
+    else
+      F.BaseRegs.push_back(BaseReg);
+  }
+  if (Ops.size() > 1) {
+    const SCEV *Sum = SE.getAddExpr(Ops);
+    // TODO: If Sum is zero, it probably means ScalarEvolution missed an
+    // opportunity to fold something. For now, just ignore such cases
+    // rather than proceed with zero in a register.
+    if (!Sum->isZero()) {
+      F.BaseRegs.push_back(Sum);
+      F.canonicalize(*L);
+      (void)InsertFormula(LU, LUIdx, F);
+    }
+  }
+}
+
+/// \brief Helper function for LSRInstance::GenerateSymbolicOffsets.
+void LSRInstance::GenerateSymbolicOffsetsImpl(LSRUse &LU, unsigned LUIdx,
+                                              const Formula &Base, size_t Idx,
+                                              bool IsScaledReg) {
+  const SCEV *G = IsScaledReg ? Base.ScaledReg : Base.BaseRegs[Idx];
+  GlobalValue *GV = ExtractSymbol(G, SE);
+  if (G->isZero() || !GV)
+    return;
+  Formula F = Base;
+  F.BaseGV = GV;
+  if (!isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F))
+    return;
+  if (IsScaledReg)
+    F.ScaledReg = G;
+  else
+    F.BaseRegs[Idx] = G;
+  (void)InsertFormula(LU, LUIdx, F);
+}
+
+/// Generate reuse formulae using symbolic offsets.
+void LSRInstance::GenerateSymbolicOffsets(LSRUse &LU, unsigned LUIdx,
+                                          Formula Base) {
+  // We can't add a symbolic offset if the address already contains one.
+  if (Base.BaseGV) return;
+
+  for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i)
+    GenerateSymbolicOffsetsImpl(LU, LUIdx, Base, i);
+  if (Base.Scale == 1)
+    GenerateSymbolicOffsetsImpl(LU, LUIdx, Base, /* Idx */ -1,
+                                /* IsScaledReg */ true);
+}
+
+/// \brief Helper function for LSRInstance::GenerateConstantOffsets.
+void LSRInstance::GenerateConstantOffsetsImpl(
+    LSRUse &LU, unsigned LUIdx, const Formula &Base,
+    const SmallVectorImpl<int64_t> &Worklist, size_t Idx, bool IsScaledReg) {
+  const SCEV *G = IsScaledReg ? Base.ScaledReg : Base.BaseRegs[Idx];
+  for (int64_t Offset : Worklist) {
+    Formula F = Base;
+    F.BaseOffset = (uint64_t)Base.BaseOffset - Offset;
+    if (isLegalUse(TTI, LU.MinOffset - Offset, LU.MaxOffset - Offset, LU.Kind,
+                   LU.AccessTy, F)) {
+      // Add the offset to the base register.
+      const SCEV *NewG = SE.getAddExpr(SE.getConstant(G->getType(), Offset), G);
+      // If it cancelled out, drop the base register, otherwise update it.
+      if (NewG->isZero()) {
+        if (IsScaledReg) {
+          F.Scale = 0;
+          F.ScaledReg = nullptr;
+        } else
+          F.deleteBaseReg(F.BaseRegs[Idx]);
+        F.canonicalize(*L);
+      } else if (IsScaledReg)
+        F.ScaledReg = NewG;
+      else
+        F.BaseRegs[Idx] = NewG;
+
+      (void)InsertFormula(LU, LUIdx, F);
+    }
+  }
+
+  int64_t Imm = ExtractImmediate(G, SE);
+  if (G->isZero() || Imm == 0)
+    return;
+  Formula F = Base;
+  F.BaseOffset = (uint64_t)F.BaseOffset + Imm;
+  if (!isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, F))
+    return;
+  if (IsScaledReg)
+    F.ScaledReg = G;
+  else
+    F.BaseRegs[Idx] = G;
+  (void)InsertFormula(LU, LUIdx, F);
+}
+
+/// GenerateConstantOffsets - Generate reuse formulae using constant offsets.
+void LSRInstance::GenerateConstantOffsets(LSRUse &LU, unsigned LUIdx,
+                                          Formula Base) {
+  // TODO: For now, just add the min and max offset, because it usually isn't
+  // worthwhile looking at everything in between.
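+  // For illustration (hypothetical offsets): suppose MinOffset is 0 and
+  // MaxOffset is 8 for a base register G. The helper above then tries a
+  // formula with base register (G + 8) and BaseOffset reduced by 8; the net
+  // value is unchanged, but the new register may already be live for another
+  // use whose fixup offset is 8.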
+ SmallVector<int64_t, 2> Worklist; + Worklist.push_back(LU.MinOffset); + if (LU.MaxOffset != LU.MinOffset) + Worklist.push_back(LU.MaxOffset); + + for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i) + GenerateConstantOffsetsImpl(LU, LUIdx, Base, Worklist, i); + if (Base.Scale == 1) + GenerateConstantOffsetsImpl(LU, LUIdx, Base, Worklist, /* Idx */ -1, + /* IsScaledReg */ true); +} + +/// For ICmpZero, check to see if we can scale up the comparison. For example, x +/// == y -> x*c == y*c. +void LSRInstance::GenerateICmpZeroScales(LSRUse &LU, unsigned LUIdx, + Formula Base) { + if (LU.Kind != LSRUse::ICmpZero) return; + + // Determine the integer type for the base formula. + Type *IntTy = Base.getType(); + if (!IntTy) return; + if (SE.getTypeSizeInBits(IntTy) > 64) return; + + // Don't do this if there is more than one offset. + if (LU.MinOffset != LU.MaxOffset) return; + + // Check if transformation is valid. It is illegal to multiply pointer. + if (Base.ScaledReg && Base.ScaledReg->getType()->isPointerTy()) + return; + for (const SCEV *BaseReg : Base.BaseRegs) + if (BaseReg->getType()->isPointerTy()) + return; + assert(!Base.BaseGV && "ICmpZero use is not legal!"); + + // Check each interesting stride. + for (int64_t Factor : Factors) { + // Check that the multiplication doesn't overflow. + if (Base.BaseOffset == std::numeric_limits<int64_t>::min() && Factor == -1) + continue; + int64_t NewBaseOffset = (uint64_t)Base.BaseOffset * Factor; + if (NewBaseOffset / Factor != Base.BaseOffset) + continue; + // If the offset will be truncated at this use, check that it is in bounds. + if (!IntTy->isPointerTy() && + !ConstantInt::isValueValidForType(IntTy, NewBaseOffset)) + continue; + + // Check that multiplying with the use offset doesn't overflow. + int64_t Offset = LU.MinOffset; + if (Offset == std::numeric_limits<int64_t>::min() && Factor == -1) + continue; + Offset = (uint64_t)Offset * Factor; + if (Offset / Factor != LU.MinOffset) + continue; + // If the offset will be truncated at this use, check that it is in bounds. + if (!IntTy->isPointerTy() && + !ConstantInt::isValueValidForType(IntTy, Offset)) + continue; + + Formula F = Base; + F.BaseOffset = NewBaseOffset; + + // Check that this scale is legal. + if (!isLegalUse(TTI, Offset, Offset, LU.Kind, LU.AccessTy, F)) + continue; + + // Compensate for the use having MinOffset built into it. + F.BaseOffset = (uint64_t)F.BaseOffset + Offset - LU.MinOffset; + + const SCEV *FactorS = SE.getConstant(IntTy, Factor); + + // Check that multiplying with each base register doesn't overflow. + for (size_t i = 0, e = F.BaseRegs.size(); i != e; ++i) { + F.BaseRegs[i] = SE.getMulExpr(F.BaseRegs[i], FactorS); + if (getExactSDiv(F.BaseRegs[i], FactorS, SE) != Base.BaseRegs[i]) + goto next; + } + + // Check that multiplying with the scaled register doesn't overflow. + if (F.ScaledReg) { + F.ScaledReg = SE.getMulExpr(F.ScaledReg, FactorS); + if (getExactSDiv(F.ScaledReg, FactorS, SE) != Base.ScaledReg) + continue; + } + + // Check that multiplying with the unfolded offset doesn't overflow. + if (F.UnfoldedOffset != 0) { + if (F.UnfoldedOffset == std::numeric_limits<int64_t>::min() && + Factor == -1) + continue; + F.UnfoldedOffset = (uint64_t)F.UnfoldedOffset * Factor; + if (F.UnfoldedOffset / Factor != Base.UnfoldedOffset) + continue; + // If the offset will be truncated, check that it is in bounds. 
+      if (!IntTy->isPointerTy() &&
+          !ConstantInt::isValueValidForType(IntTy, F.UnfoldedOffset))
+        continue;
+    }
+
+    // If we make it here and it's legal, add it.
+    (void)InsertFormula(LU, LUIdx, F);
+  next:;
+  }
+}
+
+/// Generate stride factor reuse formulae by making use of scaled-offset
+/// address modes, for example.
+void LSRInstance::GenerateScales(LSRUse &LU, unsigned LUIdx, Formula Base) {
+  // Determine the integer type for the base formula.
+  Type *IntTy = Base.getType();
+  if (!IntTy) return;
+
+  // If this Formula already has a scaled register, we can't add another one.
+  // Try to unscale the formula to generate a better scale.
+  if (Base.Scale != 0 && !Base.unscale())
+    return;
+
+  assert(Base.Scale == 0 && "unscale did not do its job!");
+
+  // Check each interesting stride.
+  for (int64_t Factor : Factors) {
+    Base.Scale = Factor;
+    Base.HasBaseReg = Base.BaseRegs.size() > 1;
+    // Check whether this scale is going to be legal.
+    if (!isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy,
+                    Base)) {
+      // As a special case, handle out-of-loop Basic users specially.
+      // TODO: Reconsider this special case.
+      if (LU.Kind == LSRUse::Basic &&
+          isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LSRUse::Special,
+                     LU.AccessTy, Base) &&
+          LU.AllFixupsOutsideLoop)
+        LU.Kind = LSRUse::Special;
+      else
+        continue;
+    }
+    // For an ICmpZero, negating a solitary base register won't lead to
+    // new solutions.
+    if (LU.Kind == LSRUse::ICmpZero &&
+        !Base.HasBaseReg && Base.BaseOffset == 0 && !Base.BaseGV)
+      continue;
+    // For each addrec base reg, if its loop is current loop, apply the scale.
+    for (size_t i = 0, e = Base.BaseRegs.size(); i != e; ++i) {
+      const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Base.BaseRegs[i]);
+      if (AR && (AR->getLoop() == L || LU.AllFixupsOutsideLoop)) {
+        const SCEV *FactorS = SE.getConstant(IntTy, Factor);
+        if (FactorS->isZero())
+          continue;
+        // Divide out the factor, ignoring high bits, since we'll be
+        // scaling the value back up in the end.
+        if (const SCEV *Quotient = getExactSDiv(AR, FactorS, SE, true)) {
+          // TODO: This could be optimized to avoid all the copying.
+          Formula F = Base;
+          F.ScaledReg = Quotient;
+          F.deleteBaseReg(F.BaseRegs[i]);
+          // The canonical representation of 1*reg is reg, which is already in
+          // Base. In that case, do not try to insert the formula, it will be
+          // rejected anyway.
+          if (F.Scale == 1 && (F.BaseRegs.empty() ||
+                               (AR->getLoop() != L && LU.AllFixupsOutsideLoop)))
+            continue;
+          // If AllFixupsOutsideLoop is true and F.Scale is 1, we may generate
+          // a non-canonical Formula with ScaledReg's loop not being L.
+          if (F.Scale == 1 && LU.AllFixupsOutsideLoop)
+            F.canonicalize(*L);
+          (void)InsertFormula(LU, LUIdx, F);
+        }
+      }
+    }
+  }
+}
+
+/// Generate reuse formulae from different IV types.
+void LSRInstance::GenerateTruncates(LSRUse &LU, unsigned LUIdx, Formula Base) {
+  // Don't bother truncating symbolic values.
+  if (Base.BaseGV) return;
+
+  // Determine the integer type for the base formula.
+  Type *DstTy = Base.getType();
+  if (!DstTy) return;
+  DstTy = SE.getEffectiveSCEVType(DstTy);
+
+  for (Type *SrcTy : Types) {
+    if (SrcTy != DstTy && TTI.isTruncateFree(SrcTy, DstTy)) {
+      Formula F = Base;
+
+      if (F.ScaledReg) F.ScaledReg = SE.getAnyExtendExpr(F.ScaledReg, SrcTy);
+      for (const SCEV *&BaseReg : F.BaseRegs)
+        BaseReg = SE.getAnyExtendExpr(BaseReg, SrcTy);
+
+      // TODO: This assumes we've done basic processing on all uses and
+      // have an idea what the register usage is.
+ if (!F.hasRegsUsedByUsesOtherThan(LUIdx, RegUses)) + continue; + + F.canonicalize(*L); + (void)InsertFormula(LU, LUIdx, F); + } + } +} + +namespace { + +/// Helper class for GenerateCrossUseConstantOffsets. It's used to defer +/// modifications so that the search phase doesn't have to worry about the data +/// structures moving underneath it. +struct WorkItem { + size_t LUIdx; + int64_t Imm; + const SCEV *OrigReg; + + WorkItem(size_t LI, int64_t I, const SCEV *R) + : LUIdx(LI), Imm(I), OrigReg(R) {} + + void print(raw_ostream &OS) const; + void dump() const; +}; + +} // end anonymous namespace + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) +void WorkItem::print(raw_ostream &OS) const { + OS << "in formulae referencing " << *OrigReg << " in use " << LUIdx + << " , add offset " << Imm; +} + +LLVM_DUMP_METHOD void WorkItem::dump() const { + print(errs()); errs() << '\n'; +} +#endif + +/// Look for registers which are a constant distance apart and try to form reuse +/// opportunities between them. +void LSRInstance::GenerateCrossUseConstantOffsets() { + // Group the registers by their value without any added constant offset. + using ImmMapTy = std::map<int64_t, const SCEV *>; + + DenseMap<const SCEV *, ImmMapTy> Map; + DenseMap<const SCEV *, SmallBitVector> UsedByIndicesMap; + SmallVector<const SCEV *, 8> Sequence; + for (const SCEV *Use : RegUses) { + const SCEV *Reg = Use; // Make a copy for ExtractImmediate to modify. + int64_t Imm = ExtractImmediate(Reg, SE); + auto Pair = Map.insert(std::make_pair(Reg, ImmMapTy())); + if (Pair.second) + Sequence.push_back(Reg); + Pair.first->second.insert(std::make_pair(Imm, Use)); + UsedByIndicesMap[Reg] |= RegUses.getUsedByIndices(Use); + } + + // Now examine each set of registers with the same base value. Build up + // a list of work to do and do the work in a separate step so that we're + // not adding formulae and register counts while we're searching. + SmallVector<WorkItem, 32> WorkItems; + SmallSet<std::pair<size_t, int64_t>, 32> UniqueItems; + for (const SCEV *Reg : Sequence) { + const ImmMapTy &Imms = Map.find(Reg)->second; + + // It's not worthwhile looking for reuse if there's only one offset. + if (Imms.size() == 1) + continue; + + DEBUG(dbgs() << "Generating cross-use offsets for " << *Reg << ':'; + for (const auto &Entry : Imms) + dbgs() << ' ' << Entry.first; + dbgs() << '\n'); + + // Examine each offset. + for (ImmMapTy::const_iterator J = Imms.begin(), JE = Imms.end(); + J != JE; ++J) { + const SCEV *OrigReg = J->second; + + int64_t JImm = J->first; + const SmallBitVector &UsedByIndices = RegUses.getUsedByIndices(OrigReg); + + if (!isa<SCEVConstant>(OrigReg) && + UsedByIndicesMap[Reg].count() == 1) { + DEBUG(dbgs() << "Skipping cross-use reuse for " << *OrigReg << '\n'); + continue; + } + + // Conservatively examine offsets between this orig reg a few selected + // other orig regs. + ImmMapTy::const_iterator OtherImms[] = { + Imms.begin(), std::prev(Imms.end()), + Imms.lower_bound((Imms.begin()->first + std::prev(Imms.end())->first) / + 2) + }; + for (size_t i = 0, e = array_lengthof(OtherImms); i != e; ++i) { + ImmMapTy::const_iterator M = OtherImms[i]; + if (M == J || M == JE) continue; + + // Compute the difference between the two. + int64_t Imm = (uint64_t)JImm - M->first; + for (unsigned LUIdx : UsedByIndices.set_bits()) + // Make a memo of this use, offset, and register tuple. 
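+          // For illustration (hypothetical registers): if both a and a+8
+          // appear in RegUses, then for OrigReg == a+8 and M pointing at a,
+          // Imm is 8, and the work item records that formulae referencing
+          // a+8 may instead use reg(a) plus an immediate 8.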
+ if (UniqueItems.insert(std::make_pair(LUIdx, Imm)).second) + WorkItems.push_back(WorkItem(LUIdx, Imm, OrigReg)); + } + } + } + + Map.clear(); + Sequence.clear(); + UsedByIndicesMap.clear(); + UniqueItems.clear(); + + // Now iterate through the worklist and add new formulae. + for (const WorkItem &WI : WorkItems) { + size_t LUIdx = WI.LUIdx; + LSRUse &LU = Uses[LUIdx]; + int64_t Imm = WI.Imm; + const SCEV *OrigReg = WI.OrigReg; + + Type *IntTy = SE.getEffectiveSCEVType(OrigReg->getType()); + const SCEV *NegImmS = SE.getSCEV(ConstantInt::get(IntTy, -(uint64_t)Imm)); + unsigned BitWidth = SE.getTypeSizeInBits(IntTy); + + // TODO: Use a more targeted data structure. + for (size_t L = 0, LE = LU.Formulae.size(); L != LE; ++L) { + Formula F = LU.Formulae[L]; + // FIXME: The code for the scaled and unscaled registers looks + // very similar but slightly different. Investigate if they + // could be merged. That way, we would not have to unscale the + // Formula. + F.unscale(); + // Use the immediate in the scaled register. + if (F.ScaledReg == OrigReg) { + int64_t Offset = (uint64_t)F.BaseOffset + Imm * (uint64_t)F.Scale; + // Don't create 50 + reg(-50). + if (F.referencesReg(SE.getSCEV( + ConstantInt::get(IntTy, -(uint64_t)Offset)))) + continue; + Formula NewF = F; + NewF.BaseOffset = Offset; + if (!isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy, + NewF)) + continue; + NewF.ScaledReg = SE.getAddExpr(NegImmS, NewF.ScaledReg); + + // If the new scale is a constant in a register, and adding the constant + // value to the immediate would produce a value closer to zero than the + // immediate itself, then the formula isn't worthwhile. + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(NewF.ScaledReg)) + if (C->getValue()->isNegative() != (NewF.BaseOffset < 0) && + (C->getAPInt().abs() * APInt(BitWidth, F.Scale)) + .ule(std::abs(NewF.BaseOffset))) + continue; + + // OK, looks good. + NewF.canonicalize(*this->L); + (void)InsertFormula(LU, LUIdx, NewF); + } else { + // Use the immediate in a base register. + for (size_t N = 0, NE = F.BaseRegs.size(); N != NE; ++N) { + const SCEV *BaseReg = F.BaseRegs[N]; + if (BaseReg != OrigReg) + continue; + Formula NewF = F; + NewF.BaseOffset = (uint64_t)NewF.BaseOffset + Imm; + if (!isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, + LU.Kind, LU.AccessTy, NewF)) { + if (!TTI.isLegalAddImmediate((uint64_t)NewF.UnfoldedOffset + Imm)) + continue; + NewF = F; + NewF.UnfoldedOffset = (uint64_t)NewF.UnfoldedOffset + Imm; + } + NewF.BaseRegs[N] = SE.getAddExpr(NegImmS, BaseReg); + + // If the new formula has a constant in a register, and adding the + // constant value to the immediate would produce a value closer to + // zero than the immediate itself, then the formula isn't worthwhile. + for (const SCEV *NewReg : NewF.BaseRegs) + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(NewReg)) + if ((C->getAPInt() + NewF.BaseOffset) + .abs() + .slt(std::abs(NewF.BaseOffset)) && + (C->getAPInt() + NewF.BaseOffset).countTrailingZeros() >= + countTrailingZeros<uint64_t>(NewF.BaseOffset)) + goto skip_formula; + + // Ok, looks good. + NewF.canonicalize(*this->L); + (void)InsertFormula(LU, LUIdx, NewF); + break; + skip_formula:; + } + } + } + } +} + +/// Generate formulae for each use. +void +LSRInstance::GenerateAllReuseFormulae() { + // This is split into multiple loops so that hasRegsUsedByUsesOtherThan + // queries are more precise. 
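+  // For example, GenerateTruncates (in the last group below) rejects a
+  // truncated formula unless hasRegsUsedByUsesOtherThan shows cross-use
+  // reuse, so it benefits from running after the offset and scale passes
+  // have populated RegUses for all uses.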
+ for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { + LSRUse &LU = Uses[LUIdx]; + for (size_t i = 0, f = LU.Formulae.size(); i != f; ++i) + GenerateReassociations(LU, LUIdx, LU.Formulae[i]); + for (size_t i = 0, f = LU.Formulae.size(); i != f; ++i) + GenerateCombinations(LU, LUIdx, LU.Formulae[i]); + } + for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { + LSRUse &LU = Uses[LUIdx]; + for (size_t i = 0, f = LU.Formulae.size(); i != f; ++i) + GenerateSymbolicOffsets(LU, LUIdx, LU.Formulae[i]); + for (size_t i = 0, f = LU.Formulae.size(); i != f; ++i) + GenerateConstantOffsets(LU, LUIdx, LU.Formulae[i]); + for (size_t i = 0, f = LU.Formulae.size(); i != f; ++i) + GenerateICmpZeroScales(LU, LUIdx, LU.Formulae[i]); + for (size_t i = 0, f = LU.Formulae.size(); i != f; ++i) + GenerateScales(LU, LUIdx, LU.Formulae[i]); + } + for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { + LSRUse &LU = Uses[LUIdx]; + for (size_t i = 0, f = LU.Formulae.size(); i != f; ++i) + GenerateTruncates(LU, LUIdx, LU.Formulae[i]); + } + + GenerateCrossUseConstantOffsets(); + + DEBUG(dbgs() << "\n" + "After generating reuse formulae:\n"; + print_uses(dbgs())); +} + +/// If there are multiple formulae with the same set of registers used +/// by other uses, pick the best one and delete the others. +void LSRInstance::FilterOutUndesirableDedicatedRegisters() { + DenseSet<const SCEV *> VisitedRegs; + SmallPtrSet<const SCEV *, 16> Regs; + SmallPtrSet<const SCEV *, 16> LoserRegs; +#ifndef NDEBUG + bool ChangedFormulae = false; +#endif + + // Collect the best formula for each unique set of shared registers. This + // is reset for each use. + using BestFormulaeTy = + DenseMap<SmallVector<const SCEV *, 4>, size_t, UniquifierDenseMapInfo>; + + BestFormulaeTy BestFormulae; + + for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { + LSRUse &LU = Uses[LUIdx]; + DEBUG(dbgs() << "Filtering for use "; LU.print(dbgs()); dbgs() << '\n'); + + bool Any = false; + for (size_t FIdx = 0, NumForms = LU.Formulae.size(); + FIdx != NumForms; ++FIdx) { + Formula &F = LU.Formulae[FIdx]; + + // Some formulas are instant losers. For example, they may depend on + // nonexistent AddRecs from other loops. These need to be filtered + // immediately, otherwise heuristics could choose them over others leading + // to an unsatisfactory solution. Passing LoserRegs into RateFormula here + // avoids the need to recompute this information across formulae using the + // same bad AddRec. Passing LoserRegs is also essential unless we remove + // the corresponding bad register from the Regs set. + Cost CostF; + Regs.clear(); + CostF.RateFormula(TTI, F, Regs, VisitedRegs, L, SE, DT, LU, &LoserRegs); + if (CostF.isLoser()) { + // During initial formula generation, undesirable formulae are generated + // by uses within other loops that have some non-trivial address mode or + // use the postinc form of the IV. LSR needs to provide these formulae + // as the basis of rediscovering the desired formula that uses an AddRec + // corresponding to the existing phi. Once all formulae have been + // generated, these initial losers may be pruned. 
+ DEBUG(dbgs() << " Filtering loser "; F.print(dbgs()); + dbgs() << "\n"); + } + else { + SmallVector<const SCEV *, 4> Key; + for (const SCEV *Reg : F.BaseRegs) { + if (RegUses.isRegUsedByUsesOtherThan(Reg, LUIdx)) + Key.push_back(Reg); + } + if (F.ScaledReg && + RegUses.isRegUsedByUsesOtherThan(F.ScaledReg, LUIdx)) + Key.push_back(F.ScaledReg); + // Unstable sort by host order ok, because this is only used for + // uniquifying. + std::sort(Key.begin(), Key.end()); + + std::pair<BestFormulaeTy::const_iterator, bool> P = + BestFormulae.insert(std::make_pair(Key, FIdx)); + if (P.second) + continue; + + Formula &Best = LU.Formulae[P.first->second]; + + Cost CostBest; + Regs.clear(); + CostBest.RateFormula(TTI, Best, Regs, VisitedRegs, L, SE, DT, LU); + if (CostF.isLess(CostBest, TTI)) + std::swap(F, Best); + DEBUG(dbgs() << " Filtering out formula "; F.print(dbgs()); + dbgs() << "\n" + " in favor of formula "; Best.print(dbgs()); + dbgs() << '\n'); + } +#ifndef NDEBUG + ChangedFormulae = true; +#endif + LU.DeleteFormula(F); + --FIdx; + --NumForms; + Any = true; + } + + // Now that we've filtered out some formulae, recompute the Regs set. + if (Any) + LU.RecomputeRegs(LUIdx, RegUses); + + // Reset this to prepare for the next use. + BestFormulae.clear(); + } + + DEBUG(if (ChangedFormulae) { + dbgs() << "\n" + "After filtering out undesirable candidates:\n"; + print_uses(dbgs()); + }); +} + +// This is a rough guess that seems to work fairly well. +static const size_t ComplexityLimit = std::numeric_limits<uint16_t>::max(); + +/// Estimate the worst-case number of solutions the solver might have to +/// consider. It almost never considers this many solutions because it prune the +/// search space, but the pruning isn't always sufficient. +size_t LSRInstance::EstimateSearchSpaceComplexity() const { + size_t Power = 1; + for (const LSRUse &LU : Uses) { + size_t FSize = LU.Formulae.size(); + if (FSize >= ComplexityLimit) { + Power = ComplexityLimit; + break; + } + Power *= FSize; + if (Power >= ComplexityLimit) + break; + } + return Power; +} + +/// When one formula uses a superset of the registers of another formula, it +/// won't help reduce register pressure (though it may not necessarily hurt +/// register pressure); remove it to simplify the system. +void LSRInstance::NarrowSearchSpaceByDetectingSupersets() { + if (EstimateSearchSpaceComplexity() >= ComplexityLimit) { + DEBUG(dbgs() << "The search space is too complex.\n"); + + DEBUG(dbgs() << "Narrowing the search space by eliminating formulae " + "which use a superset of registers used by other " + "formulae.\n"); + + for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { + LSRUse &LU = Uses[LUIdx]; + bool Any = false; + for (size_t i = 0, e = LU.Formulae.size(); i != e; ++i) { + Formula &F = LU.Formulae[i]; + // Look for a formula with a constant or GV in a register. If the use + // also has a formula with that same value in an immediate field, + // delete the one that uses a register. 
+ for (SmallVectorImpl<const SCEV *>::const_iterator + I = F.BaseRegs.begin(), E = F.BaseRegs.end(); I != E; ++I) { + if (const SCEVConstant *C = dyn_cast<SCEVConstant>(*I)) { + Formula NewF = F; + NewF.BaseOffset += C->getValue()->getSExtValue(); + NewF.BaseRegs.erase(NewF.BaseRegs.begin() + + (I - F.BaseRegs.begin())); + if (LU.HasFormulaWithSameRegs(NewF)) { + DEBUG(dbgs() << " Deleting "; F.print(dbgs()); dbgs() << '\n'); + LU.DeleteFormula(F); + --i; + --e; + Any = true; + break; + } + } else if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(*I)) { + if (GlobalValue *GV = dyn_cast<GlobalValue>(U->getValue())) + if (!F.BaseGV) { + Formula NewF = F; + NewF.BaseGV = GV; + NewF.BaseRegs.erase(NewF.BaseRegs.begin() + + (I - F.BaseRegs.begin())); + if (LU.HasFormulaWithSameRegs(NewF)) { + DEBUG(dbgs() << " Deleting "; F.print(dbgs()); + dbgs() << '\n'); + LU.DeleteFormula(F); + --i; + --e; + Any = true; + break; + } + } + } + } + } + if (Any) + LU.RecomputeRegs(LUIdx, RegUses); + } + + DEBUG(dbgs() << "After pre-selection:\n"; + print_uses(dbgs())); + } +} + +/// When there are many registers for expressions like A, A+1, A+2, etc., +/// allocate a single register for them. +void LSRInstance::NarrowSearchSpaceByCollapsingUnrolledCode() { + if (EstimateSearchSpaceComplexity() < ComplexityLimit) + return; + + DEBUG(dbgs() << "The search space is too complex.\n" + "Narrowing the search space by assuming that uses separated " + "by a constant offset will use the same registers.\n"); + + // This is especially useful for unrolled loops. + + for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { + LSRUse &LU = Uses[LUIdx]; + for (const Formula &F : LU.Formulae) { + if (F.BaseOffset == 0 || (F.Scale != 0 && F.Scale != 1)) + continue; + + LSRUse *LUThatHas = FindUseWithSimilarFormula(F, LU); + if (!LUThatHas) + continue; + + if (!reconcileNewOffset(*LUThatHas, F.BaseOffset, /*HasBaseReg=*/ false, + LU.Kind, LU.AccessTy)) + continue; + + DEBUG(dbgs() << " Deleting use "; LU.print(dbgs()); dbgs() << '\n'); + + LUThatHas->AllFixupsOutsideLoop &= LU.AllFixupsOutsideLoop; + + // Transfer the fixups of LU to LUThatHas. + for (LSRFixup &Fixup : LU.Fixups) { + Fixup.Offset += F.BaseOffset; + LUThatHas->pushFixup(Fixup); + DEBUG(dbgs() << "New fixup has offset " << Fixup.Offset << '\n'); + } + + // Delete formulae from the new use which are no longer legal. + bool Any = false; + for (size_t i = 0, e = LUThatHas->Formulae.size(); i != e; ++i) { + Formula &F = LUThatHas->Formulae[i]; + if (!isLegalUse(TTI, LUThatHas->MinOffset, LUThatHas->MaxOffset, + LUThatHas->Kind, LUThatHas->AccessTy, F)) { + DEBUG(dbgs() << " Deleting "; F.print(dbgs()); + dbgs() << '\n'); + LUThatHas->DeleteFormula(F); + --i; + --e; + Any = true; + } + } + + if (Any) + LUThatHas->RecomputeRegs(LUThatHas - &Uses.front(), RegUses); + + // Delete the old use. + DeleteUse(LU, LUIdx); + --LUIdx; + --NumUses; + break; + } + } + + DEBUG(dbgs() << "After pre-selection:\n"; print_uses(dbgs())); +} + +/// Call FilterOutUndesirableDedicatedRegisters again, if necessary, now that +/// we've done more filtering, as it may be able to find more formulae to +/// eliminate. 
+void LSRInstance::NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(){
+  if (EstimateSearchSpaceComplexity() >= ComplexityLimit) {
+    DEBUG(dbgs() << "The search space is too complex.\n");
+
+    DEBUG(dbgs() << "Narrowing the search space by re-filtering out "
+                    "undesirable dedicated registers.\n");
+
+    FilterOutUndesirableDedicatedRegisters();
+
+    DEBUG(dbgs() << "After pre-selection:\n";
+          print_uses(dbgs()));
+  }
+}
+
+/// If an LSRUse has multiple formulae with the same ScaledReg and Scale, pick
+/// the best one and delete the others.
+/// This narrowing heuristic aims to keep as many formulae with different
+/// Scale and ScaledReg pairs as possible while narrowing the search space.
+/// The benefit is that it is more likely to find a better solution in a set
+/// of formulae with more Scale and ScaledReg variations than in a set of
+/// formulae that all share the same Scale and ScaledReg. The winner-reg
+/// picking heuristic will often keep the formulae with the same Scale and
+/// ScaledReg and filter the others, and we want to avoid that if possible.
+void LSRInstance::NarrowSearchSpaceByFilterFormulaWithSameScaledReg() {
+  if (EstimateSearchSpaceComplexity() < ComplexityLimit)
+    return;
+
+  DEBUG(dbgs() << "The search space is too complex.\n"
+                  "Narrowing the search space by choosing the best Formula "
+                  "from the Formulae with the same Scale and ScaledReg.\n");
+
+  // Map the "Scale * ScaledReg" pair to the best formula of current LSRUse.
+  using BestFormulaeTy = DenseMap<std::pair<const SCEV *, int64_t>, size_t>;
+
+  BestFormulaeTy BestFormulae;
+#ifndef NDEBUG
+  bool ChangedFormulae = false;
+#endif
+  DenseSet<const SCEV *> VisitedRegs;
+  SmallPtrSet<const SCEV *, 16> Regs;
+
+  for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) {
+    LSRUse &LU = Uses[LUIdx];
+    DEBUG(dbgs() << "Filtering for use "; LU.print(dbgs()); dbgs() << '\n');
+
+    // Return true if Formula FA is better than Formula FB.
+    auto IsBetterThan = [&](Formula &FA, Formula &FB) {
+      // First we will try to choose the Formula with fewer new registers.
+      // For a register used by the current Formula, the more the register is
+      // shared among LSRUses, the less we increase the register number
+      // counter of the formula.
+      size_t FARegNum = 0;
+      for (const SCEV *Reg : FA.BaseRegs) {
+        const SmallBitVector &UsedByIndices = RegUses.getUsedByIndices(Reg);
+        FARegNum += (NumUses - UsedByIndices.count() + 1);
+      }
+      size_t FBRegNum = 0;
+      for (const SCEV *Reg : FB.BaseRegs) {
+        const SmallBitVector &UsedByIndices = RegUses.getUsedByIndices(Reg);
+        FBRegNum += (NumUses - UsedByIndices.count() + 1);
+      }
+      if (FARegNum != FBRegNum)
+        return FARegNum < FBRegNum;
+
+      // If the new register numbers are the same, choose the Formula with
+      // the lower Cost.
+      Cost CostFA, CostFB;
+      Regs.clear();
+      CostFA.RateFormula(TTI, FA, Regs, VisitedRegs, L, SE, DT, LU);
+      Regs.clear();
+      CostFB.RateFormula(TTI, FB, Regs, VisitedRegs, L, SE, DT, LU);
+      return CostFA.isLess(CostFB, TTI);
+    };
+
+    bool Any = false;
+    for (size_t FIdx = 0, NumForms = LU.Formulae.size(); FIdx != NumForms;
+         ++FIdx) {
+      Formula &F = LU.Formulae[FIdx];
+      if (!F.ScaledReg)
+        continue;
+      auto P = BestFormulae.insert({{F.ScaledReg, F.Scale}, FIdx});
+      if (P.second)
+        continue;
+
+      Formula &Best = LU.Formulae[P.first->second];
+      if (IsBetterThan(F, Best))
+        std::swap(F, Best);
+      DEBUG(dbgs() << " Filtering out formula "; F.print(dbgs());
+            dbgs() << "\n"
+                      " in favor of formula ";
+            Best.print(dbgs()); dbgs() << '\n');
+#ifndef NDEBUG
+      ChangedFormulae = true;
+#endif
+      LU.DeleteFormula(F);
+      --FIdx;
+      --NumForms;
+      Any = true;
+    }
+    if (Any)
+      LU.RecomputeRegs(LUIdx, RegUses);
+
+    // Reset this to prepare for the next use.
+    BestFormulae.clear();
+  }
+
+  DEBUG(if (ChangedFormulae) {
+          dbgs() << "\n"
+                    "After filtering out undesirable candidates:\n";
+          print_uses(dbgs());
+        });
+}
+
+/// This function deletes formulas with a high expected register count.
+/// Assuming we don't know which formula each use will end up with (we have
+/// already deleted all the clearly inefficient ones), compute for each
+/// register the probability of it not being selected.
+/// For example,
+/// Use1:
+///  reg(a) + reg({0,+,1})
+///  reg(a) + reg({-1,+,1}) + 1
+///  reg({a,+,1})
+/// Use2:
+///  reg(b) + reg({0,+,1})
+///  reg(b) + reg({-1,+,1}) + 1
+///  reg({b,+,1})
+/// Use3:
+///  reg(c) + reg(b) + reg({0,+,1})
+///  reg(c) + reg({b,+,1})
+///
+/// Probability of not selecting
+///                 Use1   Use2    Use3
+/// reg(a)         (1/3) *   1   *   1
+/// reg(b)           1   * (1/3) * (1/2)
+/// reg({0,+,1})   (2/3) * (2/3) * (1/2)
+/// reg({-1,+,1})  (2/3) * (2/3) *   1
+/// reg({a,+,1})   (2/3) *   1   *   1
+/// reg({b,+,1})     1   * (2/3) * (2/3)
+/// reg(c)           1   *   1   *   0
+///
+/// Now compute the expected register count for each formula. Note that for
+/// each use we exclude the probability of not selecting for that use. For
+/// example, for Use1 the probability for reg(a) would be just 1 * 1
+/// (excluding the probability 1/3 of not selecting for Use1).
+/// Use1:
+///  reg(a) + reg({0,+,1})          1 + 1/3       -- to be deleted
+///  reg(a) + reg({-1,+,1}) + 1     1 + 4/9       -- to be deleted
+///  reg({a,+,1})                   1
+/// Use2:
+///  reg(b) + reg({0,+,1})          1/2 + 1/3     -- to be deleted
+///  reg(b) + reg({-1,+,1}) + 1     1/2 + 2/3     -- to be deleted
+///  reg({b,+,1})                   2/3
+/// Use3:
+///  reg(c) + reg(b) + reg({0,+,1}) 1 + 1/3 + 4/9 -- to be deleted
+///  reg(c) + reg({b,+,1})          1 + 2/3
+void LSRInstance::NarrowSearchSpaceByDeletingCostlyFormulas() {
+  if (EstimateSearchSpaceComplexity() < ComplexityLimit)
+    return;
+  // Ok, we have too many formulae on our hands to conveniently handle.
+  // Use a rough heuristic to thin out the list.
+
+  // Set of Regs which will be 100% used in the final solution.
+  // Used in each formula of a solution (in the example above this is reg(c)).
+  // We can skip them in calculations.
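+  // To make the example above concrete: reg(b) has an overall not-selected
+  // probability of 1 * (1/3) * (1/2) = 1/6, so in a Use2 formula its expected
+  // contribution is (1/6) / (1/3) = 1/2, which is where the "1/2 + ..." terms
+  // in the Use2 rows come from.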
+  SmallPtrSet<const SCEV *, 4> UniqRegs;
+  DEBUG(dbgs() << "The search space is too complex.\n");
+
+  // Map each register to the probability of it not being selected.
+  DenseMap<const SCEV *, float> RegNumMap;
+  for (const SCEV *Reg : RegUses) {
+    if (UniqRegs.count(Reg))
+      continue;
+    float PNotSel = 1;
+    for (const LSRUse &LU : Uses) {
+      if (!LU.Regs.count(Reg))
+        continue;
+      float P = LU.getNotSelectedProbability(Reg);
+      if (P != 0.0)
+        PNotSel *= P;
+      else
+        UniqRegs.insert(Reg);
+    }
+    RegNumMap.insert(std::make_pair(Reg, PNotSel));
+  }
+
+  DEBUG(dbgs() << "Narrowing the search space by deleting costly formulas\n");
+
+  // Delete formulas where the register count expectation is high.
+  for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) {
+    LSRUse &LU = Uses[LUIdx];
+    // If there is nothing to delete, continue.
+    if (LU.Formulae.size() < 2)
+      continue;
+    // This is a temporary solution to test performance. Float should be
+    // replaced with a rounding-independent type (based on integers) to avoid
+    // different results for different target builds.
+    float FMinRegNum = LU.Formulae[0].getNumRegs();
+    float FMinARegNum = LU.Formulae[0].getNumRegs();
+    size_t MinIdx = 0;
+    for (size_t i = 0, e = LU.Formulae.size(); i != e; ++i) {
+      Formula &F = LU.Formulae[i];
+      float FRegNum = 0;
+      float FARegNum = 0;
+      for (const SCEV *BaseReg : F.BaseRegs) {
+        if (UniqRegs.count(BaseReg))
+          continue;
+        FRegNum += RegNumMap[BaseReg] / LU.getNotSelectedProbability(BaseReg);
+        if (isa<SCEVAddRecExpr>(BaseReg))
+          FARegNum +=
+              RegNumMap[BaseReg] / LU.getNotSelectedProbability(BaseReg);
+      }
+      if (const SCEV *ScaledReg = F.ScaledReg) {
+        if (!UniqRegs.count(ScaledReg)) {
+          FRegNum +=
+              RegNumMap[ScaledReg] / LU.getNotSelectedProbability(ScaledReg);
+          if (isa<SCEVAddRecExpr>(ScaledReg))
+            FARegNum +=
+                RegNumMap[ScaledReg] / LU.getNotSelectedProbability(ScaledReg);
+        }
+      }
+      if (FMinRegNum > FRegNum ||
+          (FMinRegNum == FRegNum && FMinARegNum > FARegNum)) {
+        FMinRegNum = FRegNum;
+        FMinARegNum = FARegNum;
+        MinIdx = i;
+      }
+    }
+    DEBUG(dbgs() << " The formula "; LU.Formulae[MinIdx].print(dbgs());
+          dbgs() << " with min reg num " << FMinRegNum << '\n');
+    if (MinIdx != 0)
+      std::swap(LU.Formulae[MinIdx], LU.Formulae[0]);
+    while (LU.Formulae.size() != 1) {
+      DEBUG(dbgs() << " Deleting "; LU.Formulae.back().print(dbgs());
+            dbgs() << '\n');
+      LU.Formulae.pop_back();
+    }
+    LU.RecomputeRegs(LUIdx, RegUses);
+    assert(LU.Formulae.size() == 1 && "Should be exactly 1 min regs formula");
+    Formula &F = LU.Formulae[0];
+    DEBUG(dbgs() << " Leaving only "; F.print(dbgs()); dbgs() << '\n');
+    // When we choose the formula, the regs become unique.
+    UniqRegs.insert(F.BaseRegs.begin(), F.BaseRegs.end());
+    if (F.ScaledReg)
+      UniqRegs.insert(F.ScaledReg);
+  }
+  DEBUG(dbgs() << "After pre-selection:\n";
+        print_uses(dbgs()));
+}
+
+/// Pick a register which seems likely to be profitable, and then in any use
+/// which has any reference to that register, delete all formulae which do not
+/// reference that register.
+void LSRInstance::NarrowSearchSpaceByPickingWinnerRegs() {
+  // With all other options exhausted, loop until the system is simple
+  // enough to handle.
+  SmallPtrSet<const SCEV *, 4> Taken;
+  while (EstimateSearchSpaceComplexity() >= ComplexityLimit) {
+    // Ok, we have too many formulae on our hands to conveniently handle.
+    // Use a rough heuristic to thin out the list.
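+    // For illustration (hypothetical counts): if reg({0,+,1}) appears in the
+    // formulae of four of five uses while every other register appears in
+    // fewer, it is picked below, and those four uses then drop every formula
+    // that does not reference it.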
+ DEBUG(dbgs() << "The search space is too complex.\n"); + + // Pick the register which is used by the most LSRUses, which is likely + // to be a good reuse register candidate. + const SCEV *Best = nullptr; + unsigned BestNum = 0; + for (const SCEV *Reg : RegUses) { + if (Taken.count(Reg)) + continue; + if (!Best) { + Best = Reg; + BestNum = RegUses.getUsedByIndices(Reg).count(); + } else { + unsigned Count = RegUses.getUsedByIndices(Reg).count(); + if (Count > BestNum) { + Best = Reg; + BestNum = Count; + } + } + } + + DEBUG(dbgs() << "Narrowing the search space by assuming " << *Best + << " will yield profitable reuse.\n"); + Taken.insert(Best); + + // In any use with formulae which references this register, delete formulae + // which don't reference it. + for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) { + LSRUse &LU = Uses[LUIdx]; + if (!LU.Regs.count(Best)) continue; + + bool Any = false; + for (size_t i = 0, e = LU.Formulae.size(); i != e; ++i) { + Formula &F = LU.Formulae[i]; + if (!F.referencesReg(Best)) { + DEBUG(dbgs() << " Deleting "; F.print(dbgs()); dbgs() << '\n'); + LU.DeleteFormula(F); + --e; + --i; + Any = true; + assert(e != 0 && "Use has no formulae left! Is Regs inconsistent?"); + continue; + } + } + + if (Any) + LU.RecomputeRegs(LUIdx, RegUses); + } + + DEBUG(dbgs() << "After pre-selection:\n"; + print_uses(dbgs())); + } +} + +/// If there are an extraordinary number of formulae to choose from, use some +/// rough heuristics to prune down the number of formulae. This keeps the main +/// solver from taking an extraordinary amount of time in some worst-case +/// scenarios. +void LSRInstance::NarrowSearchSpaceUsingHeuristics() { + NarrowSearchSpaceByDetectingSupersets(); + NarrowSearchSpaceByCollapsingUnrolledCode(); + NarrowSearchSpaceByRefilteringUndesirableDedicatedRegisters(); + if (FilterSameScaledReg) + NarrowSearchSpaceByFilterFormulaWithSameScaledReg(); + if (LSRExpNarrow) + NarrowSearchSpaceByDeletingCostlyFormulas(); + else + NarrowSearchSpaceByPickingWinnerRegs(); +} + +/// This is the recursive solver. +void LSRInstance::SolveRecurse(SmallVectorImpl<const Formula *> &Solution, + Cost &SolutionCost, + SmallVectorImpl<const Formula *> &Workspace, + const Cost &CurCost, + const SmallPtrSet<const SCEV *, 16> &CurRegs, + DenseSet<const SCEV *> &VisitedRegs) const { + // Some ideas: + // - prune more: + // - use more aggressive filtering + // - sort the formula so that the most profitable solutions are found first + // - sort the uses too + // - search faster: + // - don't compute a cost, and then compare. compare while computing a cost + // and bail early. + // - track register sets with SmallBitVector + + const LSRUse &LU = Uses[Workspace.size()]; + + // If this use references any register that's already a part of the + // in-progress solution, consider it a requirement that a formula must + // reference that register in order to be considered. This prunes out + // unprofitable searching. + SmallSetVector<const SCEV *, 4> ReqRegs; + for (const SCEV *S : CurRegs) + if (LU.Regs.count(S)) + ReqRegs.insert(S); + + SmallPtrSet<const SCEV *, 16> NewRegs; + Cost NewCost; + for (const Formula &F : LU.Formulae) { + // Ignore formulae which may not be ideal in terms of register reuse of + // ReqRegs. The formula should use all required registers before + // introducing new ones. 
+ int NumReqRegsToFind = std::min(F.getNumRegs(), ReqRegs.size()); + for (const SCEV *Reg : ReqRegs) { + if ((F.ScaledReg && F.ScaledReg == Reg) || + is_contained(F.BaseRegs, Reg)) { + --NumReqRegsToFind; + if (NumReqRegsToFind == 0) + break; + } + } + if (NumReqRegsToFind != 0) { + // If none of the formulae satisfied the required registers, then we could + // clear ReqRegs and try again. Currently, we simply give up in this case. + continue; + } + + // Evaluate the cost of the current formula. If it's already worse than + // the current best, prune the search at that point. + NewCost = CurCost; + NewRegs = CurRegs; + NewCost.RateFormula(TTI, F, NewRegs, VisitedRegs, L, SE, DT, LU); + if (NewCost.isLess(SolutionCost, TTI)) { + Workspace.push_back(&F); + if (Workspace.size() != Uses.size()) { + SolveRecurse(Solution, SolutionCost, Workspace, NewCost, + NewRegs, VisitedRegs); + if (F.getNumRegs() == 1 && Workspace.size() == 1) + VisitedRegs.insert(F.ScaledReg ? F.ScaledReg : F.BaseRegs[0]); + } else { + DEBUG(dbgs() << "New best at "; NewCost.print(dbgs()); + dbgs() << ".\n Regs:"; + for (const SCEV *S : NewRegs) + dbgs() << ' ' << *S; + dbgs() << '\n'); + + SolutionCost = NewCost; + Solution = Workspace; + } + Workspace.pop_back(); + } + } +} + +/// Choose one formula from each use. Return the results in the given Solution +/// vector. +void LSRInstance::Solve(SmallVectorImpl<const Formula *> &Solution) const { + SmallVector<const Formula *, 8> Workspace; + Cost SolutionCost; + SolutionCost.Lose(); + Cost CurCost; + SmallPtrSet<const SCEV *, 16> CurRegs; + DenseSet<const SCEV *> VisitedRegs; + Workspace.reserve(Uses.size()); + + // SolveRecurse does all the work. + SolveRecurse(Solution, SolutionCost, Workspace, CurCost, + CurRegs, VisitedRegs); + if (Solution.empty()) { + DEBUG(dbgs() << "\nNo Satisfactory Solution\n"); + return; + } + + // Ok, we've now made all our decisions. + DEBUG(dbgs() << "\n" + "The chosen solution requires "; SolutionCost.print(dbgs()); + dbgs() << ":\n"; + for (size_t i = 0, e = Uses.size(); i != e; ++i) { + dbgs() << " "; + Uses[i].print(dbgs()); + dbgs() << "\n" + " "; + Solution[i]->print(dbgs()); + dbgs() << '\n'; + }); + + assert(Solution.size() == Uses.size() && "Malformed solution!"); +} + +/// Helper for AdjustInsertPositionForExpand. Climb up the dominator tree far as +/// we can go while still being dominated by the input positions. This helps +/// canonicalize the insert position, which encourages sharing. +BasicBlock::iterator +LSRInstance::HoistInsertPosition(BasicBlock::iterator IP, + const SmallVectorImpl<Instruction *> &Inputs) + const { + Instruction *Tentative = &*IP; + while (true) { + bool AllDominate = true; + Instruction *BetterPos = nullptr; + // Don't bother attempting to insert before a catchswitch, their basic block + // cannot have other non-PHI instructions. + if (isa<CatchSwitchInst>(Tentative)) + return IP; + + for (Instruction *Inst : Inputs) { + if (Inst == Tentative || !DT.dominates(Inst, Tentative)) { + AllDominate = false; + break; + } + // Attempt to find an insert position in the middle of the block, + // instead of at the end, so that it can be used for other expansions. 
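+      // For illustration, if the only input %x sits in the middle of the
+      // block that Tentative terminates, BetterPos becomes the instruction
+      // immediately after %x, so later expansions in the same block can
+      // reuse the expanded value.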
+ if (Tentative->getParent() == Inst->getParent() && + (!BetterPos || !DT.dominates(Inst, BetterPos))) + BetterPos = &*std::next(BasicBlock::iterator(Inst)); + } + if (!AllDominate) + break; + if (BetterPos) + IP = BetterPos->getIterator(); + else + IP = Tentative->getIterator(); + + const Loop *IPLoop = LI.getLoopFor(IP->getParent()); + unsigned IPLoopDepth = IPLoop ? IPLoop->getLoopDepth() : 0; + + BasicBlock *IDom; + for (DomTreeNode *Rung = DT.getNode(IP->getParent()); ; ) { + if (!Rung) return IP; + Rung = Rung->getIDom(); + if (!Rung) return IP; + IDom = Rung->getBlock(); + + // Don't climb into a loop though. + const Loop *IDomLoop = LI.getLoopFor(IDom); + unsigned IDomDepth = IDomLoop ? IDomLoop->getLoopDepth() : 0; + if (IDomDepth <= IPLoopDepth && + (IDomDepth != IPLoopDepth || IDomLoop == IPLoop)) + break; + } + + Tentative = IDom->getTerminator(); + } + + return IP; +} + +/// Determine an input position which will be dominated by the operands and +/// which will dominate the result. +BasicBlock::iterator +LSRInstance::AdjustInsertPositionForExpand(BasicBlock::iterator LowestIP, + const LSRFixup &LF, + const LSRUse &LU, + SCEVExpander &Rewriter) const { + // Collect some instructions which must be dominated by the + // expanding replacement. These must be dominated by any operands that + // will be required in the expansion. + SmallVector<Instruction *, 4> Inputs; + if (Instruction *I = dyn_cast<Instruction>(LF.OperandValToReplace)) + Inputs.push_back(I); + if (LU.Kind == LSRUse::ICmpZero) + if (Instruction *I = + dyn_cast<Instruction>(cast<ICmpInst>(LF.UserInst)->getOperand(1))) + Inputs.push_back(I); + if (LF.PostIncLoops.count(L)) { + if (LF.isUseFullyOutsideLoop(L)) + Inputs.push_back(L->getLoopLatch()->getTerminator()); + else + Inputs.push_back(IVIncInsertPos); + } + // The expansion must also be dominated by the increment positions of any + // loops it for which it is using post-inc mode. + for (const Loop *PIL : LF.PostIncLoops) { + if (PIL == L) continue; + + // Be dominated by the loop exit. + SmallVector<BasicBlock *, 4> ExitingBlocks; + PIL->getExitingBlocks(ExitingBlocks); + if (!ExitingBlocks.empty()) { + BasicBlock *BB = ExitingBlocks[0]; + for (unsigned i = 1, e = ExitingBlocks.size(); i != e; ++i) + BB = DT.findNearestCommonDominator(BB, ExitingBlocks[i]); + Inputs.push_back(BB->getTerminator()); + } + } + + assert(!isa<PHINode>(LowestIP) && !LowestIP->isEHPad() + && !isa<DbgInfoIntrinsic>(LowestIP) && + "Insertion point must be a normal instruction"); + + // Then, climb up the immediate dominator tree as far as we can go while + // still being dominated by the input positions. + BasicBlock::iterator IP = HoistInsertPosition(LowestIP, Inputs); + + // Don't insert instructions before PHI nodes. + while (isa<PHINode>(IP)) ++IP; + + // Ignore landingpad instructions. + while (IP->isEHPad()) ++IP; + + // Ignore debug intrinsics. + while (isa<DbgInfoIntrinsic>(IP)) ++IP; + + // Set IP below instructions recently inserted by SCEVExpander. This keeps the + // IP consistent across expansions and allows the previously inserted + // instructions to be reused by subsequent expansion. + while (Rewriter.isInsertedInstruction(&*IP) && IP != LowestIP) + ++IP; + + return IP; +} + +/// Emit instructions for the leading candidate expression for this LSRUse (this +/// is called "expanding"). 
+Value *LSRInstance::Expand(const LSRUse &LU, const LSRFixup &LF,
+                           const Formula &F, BasicBlock::iterator IP,
+                           SCEVExpander &Rewriter,
+                           SmallVectorImpl<WeakTrackingVH> &DeadInsts) const {
+  if (LU.RigidFormula)
+    return LF.OperandValToReplace;
+
+  // Determine an input position which will be dominated by the operands and
+  // which will dominate the result.
+  IP = AdjustInsertPositionForExpand(IP, LF, LU, Rewriter);
+  Rewriter.setInsertPoint(&*IP);
+
+  // Inform the Rewriter if we have a post-increment use, so that it can
+  // perform an advantageous expansion.
+  Rewriter.setPostInc(LF.PostIncLoops);
+
+  // This is the type that the user actually needs.
+  Type *OpTy = LF.OperandValToReplace->getType();
+  // This will be the type that we'll initially expand to.
+  Type *Ty = F.getType();
+  if (!Ty)
+    // No type known; just expand directly to the ultimate type.
+    Ty = OpTy;
+  else if (SE.getEffectiveSCEVType(Ty) == SE.getEffectiveSCEVType(OpTy))
+    // Expand directly to the ultimate type if it's the right size.
+    Ty = OpTy;
+  // This is the type to do integer arithmetic in.
+  Type *IntTy = SE.getEffectiveSCEVType(Ty);
+
+  // Build up a list of operands to add together to form the full base.
+  SmallVector<const SCEV *, 8> Ops;
+
+  // Expand the BaseRegs portion.
+  for (const SCEV *Reg : F.BaseRegs) {
+    assert(!Reg->isZero() && "Zero allocated in a base register!");
+
+    // If we're expanding for a post-inc user, make the post-inc adjustment.
+    Reg = denormalizeForPostIncUse(Reg, LF.PostIncLoops, SE);
+    Ops.push_back(SE.getUnknown(Rewriter.expandCodeFor(Reg, nullptr)));
+  }
+
+  // Expand the ScaledReg portion.
+  Value *ICmpScaledV = nullptr;
+  if (F.Scale != 0) {
+    const SCEV *ScaledS = F.ScaledReg;
+
+    // If we're expanding for a post-inc user, make the post-inc adjustment.
+    PostIncLoopSet &Loops = const_cast<PostIncLoopSet &>(LF.PostIncLoops);
+    ScaledS = denormalizeForPostIncUse(ScaledS, Loops, SE);
+
+    if (LU.Kind == LSRUse::ICmpZero) {
+      // Expand ScaledReg as if it were part of the base regs.
+      if (F.Scale == 1)
+        Ops.push_back(
+            SE.getUnknown(Rewriter.expandCodeFor(ScaledS, nullptr)));
+      else {
+        // An interesting way of "folding" with an icmp is to use a negated
+        // scale, which we'll implement by inserting it into the other operand
+        // of the icmp.
+        assert(F.Scale == -1 &&
+               "The only scale supported by ICmpZero uses is -1!");
+        ICmpScaledV = Rewriter.expandCodeFor(ScaledS, nullptr);
+      }
+    } else {
+      // Otherwise just expand the scaled register and an explicit scale,
+      // which is expected to be matched as part of the address.
+
+      // Flush the operand list to suppress SCEVExpander hoisting address
+      // modes, unless the addressing mode will not be folded.
+      if (!Ops.empty() && LU.Kind == LSRUse::Address &&
+          isAMCompletelyFolded(TTI, LU, F)) {
+        Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), Ty);
+        Ops.clear();
+        Ops.push_back(SE.getUnknown(FullV));
+      }
+      ScaledS = SE.getUnknown(Rewriter.expandCodeFor(ScaledS, nullptr));
+      if (F.Scale != 1)
+        ScaledS =
+            SE.getMulExpr(ScaledS, SE.getConstant(ScaledS->getType(), F.Scale));
+      Ops.push_back(ScaledS);
+    }
+  }
+
+  // Expand the GV portion.
+  if (F.BaseGV) {
+    // Flush the operand list to suppress SCEVExpander hoisting.
+    if (!Ops.empty()) {
+      Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), Ty);
+      Ops.clear();
+      Ops.push_back(SE.getUnknown(FullV));
+    }
+    Ops.push_back(SE.getUnknown(F.BaseGV));
+  }
+
+  // Flush the operand list to suppress SCEVExpander hoisting of both folded
+  // and unfolded offsets. LSR assumes they both live next to their uses.
+  if (!Ops.empty()) {
+    Value *FullV = Rewriter.expandCodeFor(SE.getAddExpr(Ops), Ty);
+    Ops.clear();
+    Ops.push_back(SE.getUnknown(FullV));
+  }
+
+  // Expand the immediate portion.
+  int64_t Offset = (uint64_t)F.BaseOffset + LF.Offset;
+  if (Offset != 0) {
+    if (LU.Kind == LSRUse::ICmpZero) {
+      // The other interesting way of "folding" with an ICmpZero is to use a
+      // negated immediate.
+      if (!ICmpScaledV)
+        ICmpScaledV = ConstantInt::get(IntTy, -(uint64_t)Offset);
+      else {
+        Ops.push_back(SE.getUnknown(ICmpScaledV));
+        ICmpScaledV = ConstantInt::get(IntTy, Offset);
+      }
+    } else {
+      // Just add the immediate values. These again are expected to be matched
+      // as part of the address.
+      Ops.push_back(SE.getUnknown(ConstantInt::getSigned(IntTy, Offset)));
+    }
+  }
+
+  // Expand the unfolded offset portion.
+  int64_t UnfoldedOffset = F.UnfoldedOffset;
+  if (UnfoldedOffset != 0) {
+    // Just add the immediate values.
+    Ops.push_back(SE.getUnknown(ConstantInt::getSigned(IntTy,
+                                                       UnfoldedOffset)));
+  }
+
+  // Emit instructions summing all the operands.
+  const SCEV *FullS = Ops.empty() ?
+                      SE.getConstant(IntTy, 0) :
+                      SE.getAddExpr(Ops);
+  Value *FullV = Rewriter.expandCodeFor(FullS, Ty);
+
+  // We're done expanding now, so reset the rewriter.
+  Rewriter.clearPostInc();
+
+  // An ICmpZero Formula represents an ICmp which we're handling as a
+  // comparison against zero. Now that we've expanded an expression for that
+  // form, update the ICmp's other operand.
+  if (LU.Kind == LSRUse::ICmpZero) {
+    ICmpInst *CI = cast<ICmpInst>(LF.UserInst);
+    DeadInsts.emplace_back(CI->getOperand(1));
+    assert(!F.BaseGV && "ICmp does not support folding a global value and "
+                        "a scale at the same time!");
+    if (F.Scale == -1) {
+      if (ICmpScaledV->getType() != OpTy) {
+        Instruction *Cast =
+            CastInst::Create(CastInst::getCastOpcode(ICmpScaledV, false,
+                                                     OpTy, false),
+                             ICmpScaledV, OpTy, "tmp", CI);
+        ICmpScaledV = Cast;
+      }
+      CI->setOperand(1, ICmpScaledV);
+    } else {
+      // A scale of 1 means that the scale has been expanded as part of the
+      // base regs.
+      assert((F.Scale == 0 || F.Scale == 1) &&
+             "ICmp does not support folding a global value and "
+             "a scale at the same time!");
+      Constant *C = ConstantInt::getSigned(SE.getEffectiveSCEVType(OpTy),
+                                           -(uint64_t)Offset);
+      if (C->getType() != OpTy)
+        C = ConstantExpr::getCast(CastInst::getCastOpcode(C, false,
+                                                          OpTy, false),
+                                  C, OpTy);
+
+      CI->setOperand(1, C);
+    }
+  }
+
+  return FullV;
+}
+
+/// Helper for Rewrite. PHI nodes are special because the use of their operands
+/// effectively happens in their predecessor blocks, so the expression may need
+/// to be expanded in multiple places.
+void LSRInstance::RewriteForPHI(
+    PHINode *PN, const LSRUse &LU, const LSRFixup &LF, const Formula &F,
+    SCEVExpander &Rewriter, SmallVectorImpl<WeakTrackingVH> &DeadInsts) const {
+  DenseMap<BasicBlock *, Value *> Inserted;
+  for (unsigned i = 0, e = PN->getNumIncomingValues(); i != e; ++i)
+    if (PN->getIncomingValue(i) == LF.OperandValToReplace) {
+      BasicBlock *BB = PN->getIncomingBlock(i);
+
+      // If this is a critical edge, split the edge so that we do not insert
+      // the code on all predecessor/successor paths. We do this unless this
+      // is the canonical backedge for this loop, which complicates post-inc
+      // users.
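+      //
+      // A sketch of the transformation (block names hypothetical): if %bb
+      // ends in a two-way branch and its successor %merge has several
+      // predecessors, the edge %bb -> %merge is critical. Splitting it gives
+      //
+      //   %bb -> %bb.split -> %merge
+      //
+      // and the expanded IV expression is emitted in %bb.split, so it only
+      // executes on the path that actually feeds this PHI operand.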
+ if (e != 1 && BB->getTerminator()->getNumSuccessors() > 1 && + !isa<IndirectBrInst>(BB->getTerminator()) && + !isa<CatchSwitchInst>(BB->getTerminator())) { + BasicBlock *Parent = PN->getParent(); + Loop *PNLoop = LI.getLoopFor(Parent); + if (!PNLoop || Parent != PNLoop->getHeader()) { + // Split the critical edge. + BasicBlock *NewBB = nullptr; + if (!Parent->isLandingPad()) { + NewBB = SplitCriticalEdge(BB, Parent, + CriticalEdgeSplittingOptions(&DT, &LI) + .setMergeIdenticalEdges() + .setDontDeleteUselessPHIs()); + } else { + SmallVector<BasicBlock*, 2> NewBBs; + SplitLandingPadPredecessors(Parent, BB, "", "", NewBBs, &DT, &LI); + NewBB = NewBBs[0]; + } + // If NewBB==NULL, then SplitCriticalEdge refused to split because all + // phi predecessors are identical. The simple thing to do is skip + // splitting in this case rather than complicate the API. + if (NewBB) { + // If PN is outside of the loop and BB is in the loop, we want to + // move the block to be immediately before the PHI block, not + // immediately after BB. + if (L->contains(BB) && !L->contains(PN)) + NewBB->moveBefore(PN->getParent()); + + // Splitting the edge can reduce the number of PHI entries we have. + e = PN->getNumIncomingValues(); + BB = NewBB; + i = PN->getBasicBlockIndex(BB); + } + } + } + + std::pair<DenseMap<BasicBlock *, Value *>::iterator, bool> Pair = + Inserted.insert(std::make_pair(BB, static_cast<Value *>(nullptr))); + if (!Pair.second) + PN->setIncomingValue(i, Pair.first->second); + else { + Value *FullV = Expand(LU, LF, F, BB->getTerminator()->getIterator(), + Rewriter, DeadInsts); + + // If this is reuse-by-noop-cast, insert the noop cast. + Type *OpTy = LF.OperandValToReplace->getType(); + if (FullV->getType() != OpTy) + FullV = + CastInst::Create(CastInst::getCastOpcode(FullV, false, + OpTy, false), + FullV, LF.OperandValToReplace->getType(), + "tmp", BB->getTerminator()); + + PN->setIncomingValue(i, FullV); + Pair.first->second = FullV; + } + } +} + +/// Emit instructions for the leading candidate expression for this LSRUse (this +/// is called "expanding"), and update the UserInst to reference the newly +/// expanded value. +void LSRInstance::Rewrite(const LSRUse &LU, const LSRFixup &LF, + const Formula &F, SCEVExpander &Rewriter, + SmallVectorImpl<WeakTrackingVH> &DeadInsts) const { + // First, find an insertion point that dominates UserInst. For PHI nodes, + // find the nearest block which dominates all the relevant uses. + if (PHINode *PN = dyn_cast<PHINode>(LF.UserInst)) { + RewriteForPHI(PN, LU, LF, F, Rewriter, DeadInsts); + } else { + Value *FullV = + Expand(LU, LF, F, LF.UserInst->getIterator(), Rewriter, DeadInsts); + + // If this is reuse-by-noop-cast, insert the noop cast. + Type *OpTy = LF.OperandValToReplace->getType(); + if (FullV->getType() != OpTy) { + Instruction *Cast = + CastInst::Create(CastInst::getCastOpcode(FullV, false, OpTy, false), + FullV, OpTy, "tmp", LF.UserInst); + FullV = Cast; + } + + // Update the user. ICmpZero is handled specially here (for now) because + // Expand may have updated one of the operands of the icmp already, and + // its new value may happen to be equal to LF.OperandValToReplace, in + // which case doing replaceUsesOfWith leads to replacing both operands + // with the same value. TODO: Reorganize this. 
+ if (LU.Kind == LSRUse::ICmpZero) + LF.UserInst->setOperand(0, FullV); + else + LF.UserInst->replaceUsesOfWith(LF.OperandValToReplace, FullV); + } + + DeadInsts.emplace_back(LF.OperandValToReplace); +} + +/// Rewrite all the fixup locations with new values, following the chosen +/// solution. +void LSRInstance::ImplementSolution( + const SmallVectorImpl<const Formula *> &Solution) { + // Keep track of instructions we may have made dead, so that + // we can remove them after we are done working. + SmallVector<WeakTrackingVH, 16> DeadInsts; + + SCEVExpander Rewriter(SE, L->getHeader()->getModule()->getDataLayout(), + "lsr"); +#ifndef NDEBUG + Rewriter.setDebugType(DEBUG_TYPE); +#endif + Rewriter.disableCanonicalMode(); + Rewriter.enableLSRMode(); + Rewriter.setIVIncInsertPos(L, IVIncInsertPos); + + // Mark phi nodes that terminate chains so the expander tries to reuse them. + for (const IVChain &Chain : IVChainVec) { + if (PHINode *PN = dyn_cast<PHINode>(Chain.tailUserInst())) + Rewriter.setChainedPhi(PN); + } + + // Expand the new value definitions and update the users. + for (size_t LUIdx = 0, NumUses = Uses.size(); LUIdx != NumUses; ++LUIdx) + for (const LSRFixup &Fixup : Uses[LUIdx].Fixups) { + Rewrite(Uses[LUIdx], Fixup, *Solution[LUIdx], Rewriter, DeadInsts); + Changed = true; + } + + for (const IVChain &Chain : IVChainVec) { + GenerateIVChain(Chain, Rewriter, DeadInsts); + Changed = true; + } + // Clean up after ourselves. This must be done before deleting any + // instructions. + Rewriter.clear(); + + Changed |= DeleteTriviallyDeadInstructions(DeadInsts); +} + +LSRInstance::LSRInstance(Loop *L, IVUsers &IU, ScalarEvolution &SE, + DominatorTree &DT, LoopInfo &LI, + const TargetTransformInfo &TTI) + : IU(IU), SE(SE), DT(DT), LI(LI), TTI(TTI), L(L) { + // If LoopSimplify form is not available, stay out of trouble. + if (!L->isLoopSimplifyForm()) + return; + + // If there's no interesting work to be done, bail early. + if (IU.empty()) return; + + // If there's too much analysis to be done, bail early. We won't be able to + // model the problem anyway. + unsigned NumUsers = 0; + for (const IVStrideUse &U : IU) { + if (++NumUsers > MaxIVUsers) { + (void)U; + DEBUG(dbgs() << "LSR skipping loop, too many IV Users in " << U << "\n"); + return; + } + // Bail out if we have a PHI on an EHPad that gets a value from a + // CatchSwitchInst. Because the CatchSwitchInst cannot be split, there is + // no good place to stick any instructions. + if (auto *PN = dyn_cast<PHINode>(U.getUser())) { + auto *FirstNonPHI = PN->getParent()->getFirstNonPHI(); + if (isa<FuncletPadInst>(FirstNonPHI) || + isa<CatchSwitchInst>(FirstNonPHI)) + for (BasicBlock *PredBB : PN->blocks()) + if (isa<CatchSwitchInst>(PredBB->getFirstNonPHI())) + return; + } + } + +#ifndef NDEBUG + // All dominating loops must have preheaders, or SCEVExpander may not be able + // to materialize an AddRecExpr whose Start is an outer AddRecExpr. + // + // IVUsers analysis should only create users that are dominated by simple loop + // headers. Since this loop should dominate all of its users, its user list + // should be empty if this loop itself is not within a simple loop nest. 
+  for (DomTreeNode *Rung = DT.getNode(L->getLoopPreheader());
+       Rung; Rung = Rung->getIDom()) {
+    BasicBlock *BB = Rung->getBlock();
+    const Loop *DomLoop = LI.getLoopFor(BB);
+    if (DomLoop && DomLoop->getHeader() == BB) {
+      assert(DomLoop->getLoopPreheader() && "LSR needs a simplified loop nest");
+    }
+  }
+#endif // NDEBUG
+
+  DEBUG(dbgs() << "\nLSR on loop ";
+        L->getHeader()->printAsOperand(dbgs(), /*PrintType=*/false);
+        dbgs() << ":\n");
+
+  // First, perform some low-level loop optimizations.
+  OptimizeShadowIV();
+  OptimizeLoopTermCond();
+
+  // If loop preparation eliminates all interesting IV users, bail.
+  if (IU.empty()) return;
+
+  // Skip nested loops until we can model them better with formulae.
+  if (!L->empty()) {
+    DEBUG(dbgs() << "LSR skipping outer loop " << *L << "\n");
+    return;
+  }
+
+  // Start collecting data and preparing for the solver.
+  CollectChains();
+  CollectInterestingTypesAndFactors();
+  CollectFixupsAndInitialFormulae();
+  CollectLoopInvariantFixupsAndFormulae();
+
+  assert(!Uses.empty() && "IVUsers reported at least one use");
+  DEBUG(dbgs() << "LSR found " << Uses.size() << " uses:\n";
+        print_uses(dbgs()));
+
+  // Now use the reuse data to generate a bunch of interesting ways
+  // to formulate the values needed for the uses.
+  GenerateAllReuseFormulae();
+
+  FilterOutUndesirableDedicatedRegisters();
+  NarrowSearchSpaceUsingHeuristics();
+
+  SmallVector<const Formula *, 8> Solution;
+  Solve(Solution);
+
+  // Release memory that is no longer needed.
+  Factors.clear();
+  Types.clear();
+  RegUses.clear();
+
+  if (Solution.empty())
+    return;
+
+#ifndef NDEBUG
+  // Formulae should be legal.
+  for (const LSRUse &LU : Uses) {
+    for (const Formula &F : LU.Formulae)
+      assert(isLegalUse(TTI, LU.MinOffset, LU.MaxOffset, LU.Kind, LU.AccessTy,
+                        F) && "Illegal formula generated!");
+  }
+#endif
+
+  // Now that we've decided what we want, make it so.
+  ImplementSolution(Solution);
+}
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+void LSRInstance::print_factors_and_types(raw_ostream &OS) const {
+  if (Factors.empty() && Types.empty()) return;
+
+  OS << "LSR has identified the following interesting factors and types: ";
+  bool First = true;
+
+  for (int64_t Factor : Factors) {
+    if (!First) OS << ", ";
+    First = false;
+    OS << '*' << Factor;
+  }
+
+  for (Type *Ty : Types) {
+    if (!First) OS << ", ";
+    First = false;
+    OS << '(' << *Ty << ')';
+  }
+  OS << '\n';
+}
+
+void LSRInstance::print_fixups(raw_ostream &OS) const {
+  OS << "LSR is examining the following fixup sites:\n";
+  for (const LSRUse &LU : Uses)
+    for (const LSRFixup &LF : LU.Fixups) {
+      OS << "  ";
+      LF.print(OS);
+      OS << '\n';
+    }
+}
+
+void LSRInstance::print_uses(raw_ostream &OS) const {
+  OS << "LSR is examining the following uses:\n";
+  for (const LSRUse &LU : Uses) {
+    OS << "  ";
+    LU.print(OS);
+    OS << '\n';
+    for (const Formula &F : LU.Formulae) {
+      OS << "    ";
+      F.print(OS);
+      OS << '\n';
+    }
+  }
+}
+
+void LSRInstance::print(raw_ostream &OS) const {
+  print_factors_and_types(OS);
+  print_fixups(OS);
+  print_uses(OS);
+}
+
+LLVM_DUMP_METHOD void LSRInstance::dump() const {
+  print(errs()); errs() << '\n';
+}
+#endif
+
+namespace {
+
+class LoopStrengthReduce : public LoopPass {
+public:
+  static char ID; // Pass ID, replacement for typeid
+
+  LoopStrengthReduce();
+
+private:
+  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+};
+
+} // end anonymous namespace
+
+LoopStrengthReduce::LoopStrengthReduce() : LoopPass(ID) {
+  initializeLoopStrengthReducePass(*PassRegistry::getPassRegistry());
+}
+
+void LoopStrengthReduce::getAnalysisUsage(AnalysisUsage &AU) const {
+  // We split critical edges, so we change the CFG. However, we do update
+  // many analyses if they are around.
+  AU.addPreservedID(LoopSimplifyID);
+
+  AU.addRequired<LoopInfoWrapperPass>();
+  AU.addPreserved<LoopInfoWrapperPass>();
+  AU.addRequiredID(LoopSimplifyID);
+  AU.addRequired<DominatorTreeWrapperPass>();
+  AU.addPreserved<DominatorTreeWrapperPass>();
+  AU.addRequired<ScalarEvolutionWrapperPass>();
+  AU.addPreserved<ScalarEvolutionWrapperPass>();
+  // Requiring LoopSimplify a second time here prevents IVUsers from running
+  // twice, since LoopSimplify was invalidated by running ScalarEvolution.
+  AU.addRequiredID(LoopSimplifyID);
+  AU.addRequired<IVUsersWrapperPass>();
+  AU.addPreserved<IVUsersWrapperPass>();
+  AU.addRequired<TargetTransformInfoWrapperPass>();
+}
+
+static bool ReduceLoopStrength(Loop *L, IVUsers &IU, ScalarEvolution &SE,
+                               DominatorTree &DT, LoopInfo &LI,
+                               const TargetTransformInfo &TTI) {
+  bool Changed = false;
+
+  // Run the main LSR transformation.
+  Changed |= LSRInstance(L, IU, SE, DT, LI, TTI).getChanged();
+
+  // Remove any extra phis created by processing inner loops.
+ Changed |= DeleteDeadPHIs(L->getHeader()); + if (EnablePhiElim && L->isLoopSimplifyForm()) { + SmallVector<WeakTrackingVH, 16> DeadInsts; + const DataLayout &DL = L->getHeader()->getModule()->getDataLayout(); + SCEVExpander Rewriter(SE, DL, "lsr"); +#ifndef NDEBUG + Rewriter.setDebugType(DEBUG_TYPE); +#endif + unsigned numFolded = Rewriter.replaceCongruentIVs(L, &DT, DeadInsts, &TTI); + if (numFolded) { + Changed = true; + DeleteTriviallyDeadInstructions(DeadInsts); + DeleteDeadPHIs(L->getHeader()); + } + } + return Changed; +} + +bool LoopStrengthReduce::runOnLoop(Loop *L, LPPassManager & /*LPM*/) { + if (skipLoop(L)) + return false; + + auto &IU = getAnalysis<IVUsersWrapperPass>().getIU(); + auto &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI( + *L->getHeader()->getParent()); + return ReduceLoopStrength(L, IU, SE, DT, LI, TTI); +} + +PreservedAnalyses LoopStrengthReducePass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &) { + if (!ReduceLoopStrength(&L, AM.getResult<IVUsersAnalysis>(L, AR), AR.SE, + AR.DT, AR.LI, AR.TTI)) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} + +char LoopStrengthReduce::ID = 0; + +INITIALIZE_PASS_BEGIN(LoopStrengthReduce, "loop-reduce", + "Loop Strength Reduction", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(IVUsersWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopSimplify) +INITIALIZE_PASS_END(LoopStrengthReduce, "loop-reduce", + "Loop Strength Reduction", false, false) + +Pass *llvm::createLoopStrengthReducePass() { return new LoopStrengthReduce(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp new file mode 100644 index 000000000000..15e7da5e1a7a --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp @@ -0,0 +1,1353 @@ +//===- LoopUnroll.cpp - Loop unroller pass --------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass implements a simple loop unroller. It works best when loops have +// been canonicalized by the -indvars pass, allowing it to determine the trip +// counts of loops easily. 
+//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LoopUnrollPass.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/LoopUnrollAnalyzer.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/ProfileSummaryInfo.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Scalar/LoopPassManager.h" +#include "llvm/Transforms/Utils/LoopSimplify.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/UnrollLoop.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <limits> +#include <string> +#include <tuple> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "loop-unroll" + +static cl::opt<unsigned> + UnrollThreshold("unroll-threshold", cl::Hidden, + cl::desc("The cost threshold for loop unrolling")); + +static cl::opt<unsigned> UnrollPartialThreshold( + "unroll-partial-threshold", cl::Hidden, + cl::desc("The cost threshold for partial loop unrolling")); + +static cl::opt<unsigned> UnrollMaxPercentThresholdBoost( + "unroll-max-percent-threshold-boost", cl::init(400), cl::Hidden, + cl::desc("The maximum 'boost' (represented as a percentage >= 100) applied " + "to the threshold when aggressively unrolling a loop due to the " + "dynamic cost savings. If completely unrolling a loop will reduce " + "the total runtime from X to Y, we boost the loop unroll " + "threshold to DefaultThreshold*std::min(MaxPercentThresholdBoost, " + "X/Y). 
This limit avoids excessive code bloat."));
+
+static cl::opt<unsigned> UnrollMaxIterationsCountToAnalyze(
+    "unroll-max-iteration-count-to-analyze", cl::init(10), cl::Hidden,
+    cl::desc("Don't allow loop unrolling to simulate more than this number of "
+             "iterations when checking full unroll profitability"));
+
+static cl::opt<unsigned> UnrollCount(
+    "unroll-count", cl::Hidden,
+    cl::desc("Use this unroll count for all loops including those with "
+             "unroll_count pragma values, for testing purposes"));
+
+static cl::opt<unsigned> UnrollMaxCount(
+    "unroll-max-count", cl::Hidden,
+    cl::desc("Set the max unroll count for partial and runtime unrolling, for "
+             "testing purposes"));
+
+static cl::opt<unsigned> UnrollFullMaxCount(
+    "unroll-full-max-count", cl::Hidden,
+    cl::desc(
+        "Set the max unroll count for full unrolling, for testing purposes"));
+
+static cl::opt<unsigned> UnrollPeelCount(
+    "unroll-peel-count", cl::Hidden,
+    cl::desc("Set the unroll peeling count, for testing purposes"));
+
+static cl::opt<bool>
+    UnrollAllowPartial("unroll-allow-partial", cl::Hidden,
+                       cl::desc("Allows loops to be partially unrolled until "
+                                "-unroll-threshold loop size is reached."));
+
+static cl::opt<bool> UnrollAllowRemainder(
+    "unroll-allow-remainder", cl::Hidden,
+    cl::desc("Allow generation of a loop remainder (extra iterations) "
+             "when unrolling a loop."));
+
+static cl::opt<bool>
+    UnrollRuntime("unroll-runtime", cl::ZeroOrMore, cl::Hidden,
+                  cl::desc("Unroll loops with run-time trip counts"));
+
+static cl::opt<unsigned> UnrollMaxUpperBound(
+    "unroll-max-upperbound", cl::init(8), cl::Hidden,
+    cl::desc(
+        "The max trip count upper bound that is considered in unrolling"));
+
+static cl::opt<unsigned> PragmaUnrollThreshold(
+    "pragma-unroll-threshold", cl::init(16 * 1024), cl::Hidden,
+    cl::desc("Unrolled size limit for loops with an unroll(full) or "
+             "unroll_count pragma."));
+
+static cl::opt<unsigned> FlatLoopTripCountThreshold(
+    "flat-loop-tripcount-threshold", cl::init(5), cl::Hidden,
+    cl::desc("If the runtime tripcount for the loop is lower than the "
+             "threshold, the loop is considered as flat and will be less "
+             "aggressively unrolled."));
+
+static cl::opt<bool>
+    UnrollAllowPeeling("unroll-allow-peeling", cl::init(true), cl::Hidden,
+                       cl::desc("Allows loops to be peeled when the dynamic "
+                                "trip count is known to be low."));
+
+static cl::opt<bool> UnrollUnrollRemainder(
+    "unroll-remainder", cl::Hidden,
+    cl::desc("Allow the loop remainder to be unrolled."));
+
+// This option isn't ever intended to be enabled; it serves to allow
+// experiments to check the assumptions about when this kind of revisit is
+// necessary.
+static cl::opt<bool> UnrollRevisitChildLoops(
+    "unroll-revisit-child-loops", cl::Hidden,
+    cl::desc("Enqueue and re-visit child loops in the loop PM after unrolling. "
+             "This shouldn't typically be needed as child loops (or their "
+             "clones) were already visited."));
+
+/// A magic value for use with the Threshold parameter to indicate
+/// that the loop unroll should be performed regardless of how much
+/// code expansion would result.
+static const unsigned NoThreshold = std::numeric_limits<unsigned>::max();
+
+/// Gather the various unrolling parameters based on the defaults, compiler
+/// flags, TTI overrides and user specified parameters.
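+///
+/// Later sources win: static defaults, then TTI target overrides, then
+/// size attributes, then cl::opt flags, then the User* parameters. For
+/// example (hypothetical invocation), "opt -loop-unroll -unroll-threshold=77"
+/// overrides both the built-in default of 150 and any TTI override, but an
+/// explicit UserThreshold passed in by the caller overrides the flag.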
+static TargetTransformInfo::UnrollingPreferences gatherUnrollingPreferences( + Loop *L, ScalarEvolution &SE, const TargetTransformInfo &TTI, int OptLevel, + Optional<unsigned> UserThreshold, Optional<unsigned> UserCount, + Optional<bool> UserAllowPartial, Optional<bool> UserRuntime, + Optional<bool> UserUpperBound, Optional<bool> UserAllowPeeling) { + TargetTransformInfo::UnrollingPreferences UP; + + // Set up the defaults + UP.Threshold = OptLevel > 2 ? 300 : 150; + UP.MaxPercentThresholdBoost = 400; + UP.OptSizeThreshold = 0; + UP.PartialThreshold = 150; + UP.PartialOptSizeThreshold = 0; + UP.Count = 0; + UP.PeelCount = 0; + UP.DefaultUnrollRuntimeCount = 8; + UP.MaxCount = std::numeric_limits<unsigned>::max(); + UP.FullUnrollMaxCount = std::numeric_limits<unsigned>::max(); + UP.BEInsns = 2; + UP.Partial = false; + UP.Runtime = false; + UP.AllowRemainder = true; + UP.UnrollRemainder = false; + UP.AllowExpensiveTripCount = false; + UP.Force = false; + UP.UpperBound = false; + UP.AllowPeeling = true; + + // Override with any target specific settings + TTI.getUnrollingPreferences(L, SE, UP); + + // Apply size attributes + if (L->getHeader()->getParent()->optForSize()) { + UP.Threshold = UP.OptSizeThreshold; + UP.PartialThreshold = UP.PartialOptSizeThreshold; + } + + // Apply any user values specified by cl::opt + if (UnrollThreshold.getNumOccurrences() > 0) + UP.Threshold = UnrollThreshold; + if (UnrollPartialThreshold.getNumOccurrences() > 0) + UP.PartialThreshold = UnrollPartialThreshold; + if (UnrollMaxPercentThresholdBoost.getNumOccurrences() > 0) + UP.MaxPercentThresholdBoost = UnrollMaxPercentThresholdBoost; + if (UnrollMaxCount.getNumOccurrences() > 0) + UP.MaxCount = UnrollMaxCount; + if (UnrollFullMaxCount.getNumOccurrences() > 0) + UP.FullUnrollMaxCount = UnrollFullMaxCount; + if (UnrollPeelCount.getNumOccurrences() > 0) + UP.PeelCount = UnrollPeelCount; + if (UnrollAllowPartial.getNumOccurrences() > 0) + UP.Partial = UnrollAllowPartial; + if (UnrollAllowRemainder.getNumOccurrences() > 0) + UP.AllowRemainder = UnrollAllowRemainder; + if (UnrollRuntime.getNumOccurrences() > 0) + UP.Runtime = UnrollRuntime; + if (UnrollMaxUpperBound == 0) + UP.UpperBound = false; + if (UnrollAllowPeeling.getNumOccurrences() > 0) + UP.AllowPeeling = UnrollAllowPeeling; + if (UnrollUnrollRemainder.getNumOccurrences() > 0) + UP.UnrollRemainder = UnrollUnrollRemainder; + + // Apply user values provided by argument + if (UserThreshold.hasValue()) { + UP.Threshold = *UserThreshold; + UP.PartialThreshold = *UserThreshold; + } + if (UserCount.hasValue()) + UP.Count = *UserCount; + if (UserAllowPartial.hasValue()) + UP.Partial = *UserAllowPartial; + if (UserRuntime.hasValue()) + UP.Runtime = *UserRuntime; + if (UserUpperBound.hasValue()) + UP.UpperBound = *UserUpperBound; + if (UserAllowPeeling.hasValue()) + UP.AllowPeeling = *UserAllowPeeling; + + return UP; +} + +namespace { + +/// A struct to densely store the state of an instruction after unrolling at +/// each iteration. +/// +/// This is designed to work like a tuple of <Instruction *, int> for the +/// purposes of hashing and lookup, but to be able to associate two boolean +/// states with each key. +struct UnrolledInstState { + Instruction *I; + int Iteration : 30; + unsigned IsFree : 1; + unsigned IsCounted : 1; +}; + +/// Hashing and equality testing for a set of the instruction states. 
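+///
+/// With these traits, a DenseSet<UnrolledInstState, UnrolledInstStateKeyInfo>
+/// behaves like a map from (Instruction *, Iteration) to the IsFree and
+/// IsCounted bits: a lookup such as
+///
+///   InstCostMap.find({I, Iteration, 0, 0})
+///
+/// hashes and compares only the pointer/iteration pair, so the two flag
+/// fields supplied in the braced key are ignored.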
+struct UnrolledInstStateKeyInfo {
+  using PtrInfo = DenseMapInfo<Instruction *>;
+  using PairInfo = DenseMapInfo<std::pair<Instruction *, int>>;
+
+  static inline UnrolledInstState getEmptyKey() {
+    return {PtrInfo::getEmptyKey(), 0, 0, 0};
+  }
+
+  static inline UnrolledInstState getTombstoneKey() {
+    return {PtrInfo::getTombstoneKey(), 0, 0, 0};
+  }
+
+  static inline unsigned getHashValue(const UnrolledInstState &S) {
+    return PairInfo::getHashValue({S.I, S.Iteration});
+  }
+
+  static inline bool isEqual(const UnrolledInstState &LHS,
+                             const UnrolledInstState &RHS) {
+    return PairInfo::isEqual({LHS.I, LHS.Iteration}, {RHS.I, RHS.Iteration});
+  }
+};
+
+struct EstimatedUnrollCost {
+  /// \brief The estimated cost after unrolling.
+  unsigned UnrolledCost;
+
+  /// \brief The estimated dynamic cost of executing the instructions in the
+  /// rolled form.
+  unsigned RolledDynamicCost;
+};
+
+} // end anonymous namespace
+
+/// \brief Figure out if the loop is worth full unrolling.
+///
+/// Complete loop unrolling can make some loads constant, and we need to know
+/// if that would expose any further optimization opportunities. This routine
+/// estimates this optimization. It computes the cost of the unrolled loop
+/// (UnrolledCost) and the dynamic cost of the original loop
+/// (RolledDynamicCost). By dynamic cost we mean that we won't count costs of
+/// blocks that are known not to be executed (i.e. if we have a branch in the
+/// loop and we know that at the given iteration its condition would be
+/// resolved to true, we won't add up the cost of the 'false'-block).
+/// \returns Optional value, holding the RolledDynamicCost and UnrolledCost. If
+/// the analysis failed (no benefits expected from the unrolling, or the loop
+/// is too big to analyze), the returned value is None.
+static Optional<EstimatedUnrollCost>
+analyzeLoopUnrollCost(const Loop *L, unsigned TripCount, DominatorTree &DT,
+                      ScalarEvolution &SE, const TargetTransformInfo &TTI,
+                      unsigned MaxUnrolledLoopSize) {
+  // We want to be able to scale offsets by the trip count and add more offsets
+  // to them without checking for overflows, and we already don't want to
+  // analyze *massive* trip counts, so we force the max to be reasonably small.
+  assert(UnrollMaxIterationsCountToAnalyze <
+             (unsigned)(std::numeric_limits<int>::max() / 2) &&
+         "The unroll iterations max is too large!");
+
+  // Only analyze inner loops. We can't properly estimate cost of nested loops
+  // and we won't visit inner loops again anyway.
+  if (!L->empty())
+    return None;
+
+  // Don't simulate loops with a big or unknown tripcount.
+  if (!UnrollMaxIterationsCountToAnalyze || !TripCount ||
+      TripCount > UnrollMaxIterationsCountToAnalyze)
+    return None;
+
+  SmallSetVector<BasicBlock *, 16> BBWorklist;
+  SmallSetVector<std::pair<BasicBlock *, BasicBlock *>, 4> ExitWorklist;
+  DenseMap<Value *, Constant *> SimplifiedValues;
+  SmallVector<std::pair<Value *, Constant *>, 4> SimplifiedInputValues;
+
+  // The estimated cost of the unrolled form of the loop. We try to estimate
+  // this by simplifying as much as we can while computing the estimate.
+  unsigned UnrolledCost = 0;
+
+  // We also track the estimated dynamic (that is, actually executed) cost in
+  // the rolled form. This helps identify cases when the savings from unrolling
+  // aren't just exposing dead control flows, but actual reduced dynamic
+  // instructions due to the simplifications which we expect to occur after
+  // unrolling.
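+  //
+  // For example (hypothetical numbers): if a load from a constant-initialized
+  // global simplifies to a constant on every simulated iteration, the
+  // compares and branches fed by it are counted in RolledDynamicCost but
+  // drop out of UnrolledCost, which is exactly the gap this analysis is
+  // looking for.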
+ unsigned RolledDynamicCost = 0; + + // We track the simplification of each instruction in each iteration. We use + // this to recursively merge costs into the unrolled cost on-demand so that + // we don't count the cost of any dead code. This is essentially a map from + // <instruction, int> to <bool, bool>, but stored as a densely packed struct. + DenseSet<UnrolledInstState, UnrolledInstStateKeyInfo> InstCostMap; + + // A small worklist used to accumulate cost of instructions from each + // observable and reached root in the loop. + SmallVector<Instruction *, 16> CostWorklist; + + // PHI-used worklist used between iterations while accumulating cost. + SmallVector<Instruction *, 4> PHIUsedList; + + // Helper function to accumulate cost for instructions in the loop. + auto AddCostRecursively = [&](Instruction &RootI, int Iteration) { + assert(Iteration >= 0 && "Cannot have a negative iteration!"); + assert(CostWorklist.empty() && "Must start with an empty cost list"); + assert(PHIUsedList.empty() && "Must start with an empty phi used list"); + CostWorklist.push_back(&RootI); + for (;; --Iteration) { + do { + Instruction *I = CostWorklist.pop_back_val(); + + // InstCostMap only uses I and Iteration as a key, the other two values + // don't matter here. + auto CostIter = InstCostMap.find({I, Iteration, 0, 0}); + if (CostIter == InstCostMap.end()) + // If an input to a PHI node comes from a dead path through the loop + // we may have no cost data for it here. What that actually means is + // that it is free. + continue; + auto &Cost = *CostIter; + if (Cost.IsCounted) + // Already counted this instruction. + continue; + + // Mark that we are counting the cost of this instruction now. + Cost.IsCounted = true; + + // If this is a PHI node in the loop header, just add it to the PHI set. + if (auto *PhiI = dyn_cast<PHINode>(I)) + if (PhiI->getParent() == L->getHeader()) { + assert(Cost.IsFree && "Loop PHIs shouldn't be evaluated as they " + "inherently simplify during unrolling."); + if (Iteration == 0) + continue; + + // Push the incoming value from the backedge into the PHI used list + // if it is an in-loop instruction. We'll use this to populate the + // cost worklist for the next iteration (as we count backwards). + if (auto *OpI = dyn_cast<Instruction>( + PhiI->getIncomingValueForBlock(L->getLoopLatch()))) + if (L->contains(OpI)) + PHIUsedList.push_back(OpI); + continue; + } + + // First accumulate the cost of this instruction. + if (!Cost.IsFree) { + UnrolledCost += TTI.getUserCost(I); + DEBUG(dbgs() << "Adding cost of instruction (iteration " << Iteration + << "): "); + DEBUG(I->dump()); + } + + // We must count the cost of every operand which is not free, + // recursively. If we reach a loop PHI node, simply add it to the set + // to be considered on the next iteration (backwards!). + for (Value *Op : I->operands()) { + // Check whether this operand is free due to being a constant or + // outside the loop. + auto *OpI = dyn_cast<Instruction>(Op); + if (!OpI || !L->contains(OpI)) + continue; + + // Otherwise accumulate its cost. + CostWorklist.push_back(OpI); + } + } while (!CostWorklist.empty()); + + if (PHIUsedList.empty()) + // We've exhausted the search. + break; + + assert(Iteration > 0 && + "Cannot track PHI-used values past the first iteration!"); + CostWorklist.append(PHIUsedList.begin(), PHIUsedList.end()); + PHIUsedList.clear(); + } + }; + + // Ensure that we don't violate the loop structure invariants relied on by + // this analysis. 
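+  // LoopSimplify form guarantees a preheader and a single latch, which the
+  // header-PHI handling below relies on; LCSSA guarantees that values live
+  // out of the loop are funneled through exit-block PHIs, which is what the
+  // ExitWorklist costing at the end of this function walks.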
+  assert(L->isLoopSimplifyForm() && "Must put loop into normal form first.");
+  assert(L->isLCSSAForm(DT) &&
+         "Must have loops in LCSSA form to track live-out values.");
+
+  DEBUG(dbgs() << "Starting LoopUnroll profitability analysis...\n");
+
+  // Simulate execution of each iteration of the loop counting instructions,
+  // which would be simplified.
+  // Since the same load will take different values on different iterations,
+  // we literally have to go through all of the loop's iterations.
+  for (unsigned Iteration = 0; Iteration < TripCount; ++Iteration) {
+    DEBUG(dbgs() << " Analyzing iteration " << Iteration << "\n");
+
+    // Prepare for the iteration by collecting any simplified entry or backedge
+    // inputs.
+    for (Instruction &I : *L->getHeader()) {
+      auto *PHI = dyn_cast<PHINode>(&I);
+      if (!PHI)
+        break;
+
+      // The loop header PHI nodes must have exactly two inputs: one from the
+      // loop preheader and one from the loop latch.
+      assert(
+          PHI->getNumIncomingValues() == 2 &&
+          "Must have an incoming value only for the preheader and the latch.");
+
+      Value *V = PHI->getIncomingValueForBlock(
+          Iteration == 0 ? L->getLoopPreheader() : L->getLoopLatch());
+      Constant *C = dyn_cast<Constant>(V);
+      if (Iteration != 0 && !C)
+        C = SimplifiedValues.lookup(V);
+      if (C)
+        SimplifiedInputValues.push_back({PHI, C});
+    }
+
+    // Now clear and re-populate the map for the next iteration.
+    SimplifiedValues.clear();
+    while (!SimplifiedInputValues.empty())
+      SimplifiedValues.insert(SimplifiedInputValues.pop_back_val());
+
+    UnrolledInstAnalyzer Analyzer(Iteration, SimplifiedValues, SE, L);
+
+    BBWorklist.clear();
+    BBWorklist.insert(L->getHeader());
+    // Note that we *must not* cache the size; this loop grows the worklist.
+    for (unsigned Idx = 0; Idx != BBWorklist.size(); ++Idx) {
+      BasicBlock *BB = BBWorklist[Idx];
+
+      // Visit all instructions in the given basic block and try to simplify
+      // them. We don't change the actual IR, just count optimization
+      // opportunities.
+      for (Instruction &I : *BB) {
+        if (isa<DbgInfoIntrinsic>(I))
+          continue;
+
+        // Track this instruction's expected baseline cost when executing the
+        // rolled loop form.
+        RolledDynamicCost += TTI.getUserCost(&I);
+
+        // Visit the instruction to analyze its loop cost after unrolling,
+        // and if the visitor returns true, mark the instruction as free after
+        // unrolling and continue.
+        bool IsFree = Analyzer.visit(I);
+        bool Inserted = InstCostMap.insert({&I, (int)Iteration,
+                                            (unsigned)IsFree,
+                                            /*IsCounted*/ false}).second;
+        (void)Inserted;
+        assert(Inserted && "Cannot have a state for an unvisited instruction!");
+
+        if (IsFree)
+          continue;
+
+        // Can't properly model a cost of a call.
+        // FIXME: With a proper cost model we should be able to do it.
+        if (isa<CallInst>(&I))
+          return None;
+
+        // If the instruction might have a side-effect recursively account for
+        // the cost of it and all the instructions leading up to it.
+        if (I.mayHaveSideEffects())
+          AddCostRecursively(I, Iteration);
+
+        // If the unrolled body turns out to be too big, bail out.
+        if (UnrolledCost > MaxUnrolledLoopSize) {
+          DEBUG(dbgs() << "  Exceeded threshold.. exiting.\n"
+                       << "  UnrolledCost: " << UnrolledCost
+                       << ", MaxUnrolledLoopSize: " << MaxUnrolledLoopSize
+                       << "\n");
+          return None;
+        }
+      }
+
+      TerminatorInst *TI = BB->getTerminator();
+
+      // Add in the live successors by first checking whether we have a
+      // terminator that may be simplified based on the values simplified by
+      // this call.
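+      //
+      // For example, if this iteration simplified the latch condition to the
+      // constant i1 false, only the successor actually taken on this
+      // iteration is enqueued below, and the instructions in the dead arm
+      // are never costed.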
+      BasicBlock *KnownSucc = nullptr;
+      if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
+        if (BI->isConditional()) {
+          if (Constant *SimpleCond =
+                  SimplifiedValues.lookup(BI->getCondition())) {
+            // Just take the first successor if condition is undef
+            if (isa<UndefValue>(SimpleCond))
+              KnownSucc = BI->getSuccessor(0);
+            else if (ConstantInt *SimpleCondVal =
+                         dyn_cast<ConstantInt>(SimpleCond))
+              KnownSucc = BI->getSuccessor(SimpleCondVal->isZero() ? 1 : 0);
+          }
+        }
+      } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
+        if (Constant *SimpleCond =
+                SimplifiedValues.lookup(SI->getCondition())) {
+          // Just take the first successor if condition is undef
+          if (isa<UndefValue>(SimpleCond))
+            KnownSucc = SI->getSuccessor(0);
+          else if (ConstantInt *SimpleCondVal =
+                       dyn_cast<ConstantInt>(SimpleCond))
+            KnownSucc = SI->findCaseValue(SimpleCondVal)->getCaseSuccessor();
+        }
+      }
+      if (KnownSucc) {
+        if (L->contains(KnownSucc))
+          BBWorklist.insert(KnownSucc);
+        else
+          ExitWorklist.insert({BB, KnownSucc});
+        continue;
+      }
+
+      // Add BB's successors to the worklist.
+      for (BasicBlock *Succ : successors(BB))
+        if (L->contains(Succ))
+          BBWorklist.insert(Succ);
+        else
+          ExitWorklist.insert({BB, Succ});
+      AddCostRecursively(*TI, Iteration);
+    }
+
+    // If we found no optimization opportunities on the first iteration, we
+    // won't find them on later ones either.
+    if (UnrolledCost == RolledDynamicCost) {
+      DEBUG(dbgs() << "  No opportunities found.. exiting.\n"
+                   << "  UnrolledCost: " << UnrolledCost << "\n");
+      return None;
+    }
+  }
+
+  while (!ExitWorklist.empty()) {
+    BasicBlock *ExitingBB, *ExitBB;
+    std::tie(ExitingBB, ExitBB) = ExitWorklist.pop_back_val();
+
+    for (Instruction &I : *ExitBB) {
+      auto *PN = dyn_cast<PHINode>(&I);
+      if (!PN)
+        break;
+
+      Value *Op = PN->getIncomingValueForBlock(ExitingBB);
+      if (auto *OpI = dyn_cast<Instruction>(Op))
+        if (L->contains(OpI))
+          AddCostRecursively(*OpI, TripCount - 1);
+    }
+  }
+
+  DEBUG(dbgs() << "Analysis finished:\n"
+               << "UnrolledCost: " << UnrolledCost << ", "
+               << "RolledDynamicCost: " << RolledDynamicCost << "\n");
+  return {{UnrolledCost, RolledDynamicCost}};
+}
+
+/// ApproximateLoopSize - Approximate the size of the loop.
+static unsigned ApproximateLoopSize(const Loop *L, unsigned &NumCalls,
+                                    bool &NotDuplicatable, bool &Convergent,
+                                    const TargetTransformInfo &TTI,
+                                    AssumptionCache *AC, unsigned BEInsns) {
+  SmallPtrSet<const Value *, 32> EphValues;
+  CodeMetrics::collectEphemeralValues(L, AC, EphValues);
+
+  CodeMetrics Metrics;
+  for (BasicBlock *BB : L->blocks())
+    Metrics.analyzeBasicBlock(BB, TTI, EphValues);
+  NumCalls = Metrics.NumInlineCandidates;
+  NotDuplicatable = Metrics.notDuplicatable;
+  Convergent = Metrics.convergent;
+
+  unsigned LoopSize = Metrics.NumInsts;
+
+  // Don't allow an estimate of size zero. This would allow unrolling of loops
+  // with huge iteration counts, which is a compile time problem even if it's
+  // not a problem for code quality. Also, the code using this size may assume
+  // that each loop has at least three instructions (likely a conditional
+  // branch, a comparison feeding that branch, and some kind of loop increment
+  // feeding that comparison instruction).
+  LoopSize = std::max(LoopSize, BEInsns + 1);
+
+  return LoopSize;
+}
+
+// Returns the loop hint metadata node with the given name (for example,
+// "llvm.loop.unroll.count"). If no such metadata node exists, then nullptr is
+// returned.
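+//
+// For example, "#pragma clang loop unroll_count(4)" ends up attached to the
+// loop's latch branch as loop metadata roughly like (IR sketch):
+//
+//   br i1 %exitcond, label %exit, label %loop, !llvm.loop !0
+//   ...
+//   !0 = distinct !{!0, !1}
+//   !1 = !{!"llvm.loop.unroll.count", i32 4}
+//
+// and GetUnrollMetadataForLoop walks the operands of !0 looking for the
+// requested name string.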
+static MDNode *GetUnrollMetadataForLoop(const Loop *L, StringRef Name) {
+  if (MDNode *LoopID = L->getLoopID())
+    return GetUnrollMetadata(LoopID, Name);
+  return nullptr;
+}
+
+// Returns true if the loop has an unroll(full) pragma.
+static bool HasUnrollFullPragma(const Loop *L) {
+  return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.full");
+}
+
+// Returns true if the loop has an unroll(enable) pragma. This metadata is used
+// for both "#pragma unroll" and "#pragma clang loop unroll(enable)" directives.
+static bool HasUnrollEnablePragma(const Loop *L) {
+  return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.enable");
+}
+
+// Returns true if the loop has an unroll(disable) pragma.
+static bool HasUnrollDisablePragma(const Loop *L) {
+  return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.disable");
+}
+
+// Returns true if the loop has a runtime unroll(disable) pragma.
+static bool HasRuntimeUnrollDisablePragma(const Loop *L) {
+  return GetUnrollMetadataForLoop(L, "llvm.loop.unroll.runtime.disable");
+}
+
+// If the loop has an unroll_count pragma, return the (necessarily
+// positive) value from the pragma. Otherwise return 0.
+static unsigned UnrollCountPragmaValue(const Loop *L) {
+  MDNode *MD = GetUnrollMetadataForLoop(L, "llvm.loop.unroll.count");
+  if (MD) {
+    assert(MD->getNumOperands() == 2 &&
+           "Unroll count hint metadata should have two operands.");
+    unsigned Count =
+        mdconst::extract<ConstantInt>(MD->getOperand(1))->getZExtValue();
+    assert(Count >= 1 && "Unroll count must be positive.");
+    return Count;
+  }
+  return 0;
+}
+
+// Computes the boosting factor for complete unrolling.
+// If fully unrolling the loop would save a lot of RolledDynamicCost, it would
+// be beneficial to fully unroll the loop even if UnrolledCost is large. We
+// use (RolledDynamicCost / UnrolledCost) to model the unroll benefits and
+// adjust the unroll threshold.
+static unsigned getFullUnrollBoostingFactor(const EstimatedUnrollCost &Cost,
+                                            unsigned MaxPercentThresholdBoost) {
+  if (Cost.RolledDynamicCost >= std::numeric_limits<unsigned>::max() / 100)
+    return 100;
+  else if (Cost.UnrolledCost != 0)
+    // The boosting factor is RolledDynamicCost / UnrolledCost
+    return std::min(100 * Cost.RolledDynamicCost / Cost.UnrolledCost,
+                    MaxPercentThresholdBoost);
+  else
+    return MaxPercentThresholdBoost;
+}
+
+// Returns the loop size estimation for the unrolled loop.
+static uint64_t getUnrolledLoopSize(
+    unsigned LoopSize,
+    TargetTransformInfo::UnrollingPreferences &UP) {
+  assert(LoopSize >= UP.BEInsns && "LoopSize should not be less than BEInsns!");
+  return (uint64_t)(LoopSize - UP.BEInsns) * UP.Count + UP.BEInsns;
+}
+
+// Returns true if the unroll count was set explicitly.
+// Calculates the unroll count and writes it to UP.Count.
+static bool computeUnrollCount(
+    Loop *L, const TargetTransformInfo &TTI, DominatorTree &DT, LoopInfo *LI,
+    ScalarEvolution &SE, OptimizationRemarkEmitter *ORE, unsigned &TripCount,
+    unsigned MaxTripCount, unsigned &TripMultiple, unsigned LoopSize,
+    TargetTransformInfo::UnrollingPreferences &UP, bool &UseUpperBound) {
+  // Check for an explicit Count.
+  // 1st priority is unroll count set by the "unroll-count" option.
+  bool UserUnrollCount = UnrollCount.getNumOccurrences() > 0;
+  if (UserUnrollCount) {
+    UP.Count = UnrollCount;
+    UP.AllowExpensiveTripCount = true;
+    UP.Force = true;
+    if (UP.AllowRemainder && getUnrolledLoopSize(LoopSize, UP) < UP.Threshold)
+      return true;
+  }
+
+  // 2nd priority is unroll count set by pragma.
+  unsigned PragmaCount = UnrollCountPragmaValue(L);
+  if (PragmaCount > 0) {
+    UP.Count = PragmaCount;
+    UP.Runtime = true;
+    UP.AllowExpensiveTripCount = true;
+    UP.Force = true;
+    if (UP.AllowRemainder &&
+        getUnrolledLoopSize(LoopSize, UP) < PragmaUnrollThreshold)
+      return true;
+  }
+  bool PragmaFullUnroll = HasUnrollFullPragma(L);
+  if (PragmaFullUnroll && TripCount != 0) {
+    UP.Count = TripCount;
+    if (getUnrolledLoopSize(LoopSize, UP) < PragmaUnrollThreshold)
+      return false;
+  }
+
+  bool PragmaEnableUnroll = HasUnrollEnablePragma(L);
+  bool ExplicitUnroll = PragmaCount > 0 || PragmaFullUnroll ||
+                        PragmaEnableUnroll || UserUnrollCount;
+
+  if (ExplicitUnroll && TripCount != 0) {
+    // If the loop has an unrolling pragma, we want to be more aggressive with
+    // unrolling limits. Set thresholds to at least the PragmaThreshold value
+    // which is larger than the default limits.
+    UP.Threshold = std::max<unsigned>(UP.Threshold, PragmaUnrollThreshold);
+    UP.PartialThreshold =
+        std::max<unsigned>(UP.PartialThreshold, PragmaUnrollThreshold);
+  }
+
+  // 3rd priority is full unroll count.
+  // Full unroll makes sense only when TripCount or its upper bound could be
+  // statically calculated.
+  // Also we need to check if we exceed FullUnrollMaxCount.
+  // If using the upper bound to unroll, TripMultiple should be set to 1
+  // because we do not know when the loop may exit.
+  // MaxTripCount and ExactTripCount cannot both be non zero since we only
+  // compute the former when the latter is zero.
+  unsigned ExactTripCount = TripCount;
+  assert((ExactTripCount == 0 || MaxTripCount == 0) &&
+         "ExactTripCount and MaxTripCount cannot both be non zero.");
+  unsigned FullUnrollTripCount = ExactTripCount ? ExactTripCount : MaxTripCount;
+  UP.Count = FullUnrollTripCount;
+  if (FullUnrollTripCount && FullUnrollTripCount <= UP.FullUnrollMaxCount) {
+    // When computing the unrolled size, note that BEInsns are not replicated
+    // like the rest of the loop body.
+    if (getUnrolledLoopSize(LoopSize, UP) < UP.Threshold) {
+      UseUpperBound = (MaxTripCount == FullUnrollTripCount);
+      TripCount = FullUnrollTripCount;
+      TripMultiple = UP.UpperBound ? 1 : TripMultiple;
+      return ExplicitUnroll;
+    } else {
+      // The loop isn't that small, but we still can fully unroll it if that
+      // helps to remove a significant number of instructions.
+      // To check that, run additional analysis on the loop.
+      if (Optional<EstimatedUnrollCost> Cost = analyzeLoopUnrollCost(
+              L, FullUnrollTripCount, DT, SE, TTI,
+              UP.Threshold * UP.MaxPercentThresholdBoost / 100)) {
+        unsigned Boost =
+            getFullUnrollBoostingFactor(*Cost, UP.MaxPercentThresholdBoost);
+        if (Cost->UnrolledCost < UP.Threshold * Boost / 100) {
+          UseUpperBound = (MaxTripCount == FullUnrollTripCount);
+          TripCount = FullUnrollTripCount;
+          TripMultiple = UP.UpperBound ? 1 : TripMultiple;
+          return ExplicitUnroll;
+        }
+      }
+    }
+  }
+
+  // 4th priority is loop peeling.
+  computePeelCount(L, LoopSize, UP, TripCount);
+  if (UP.PeelCount) {
+    UP.Runtime = false;
+    UP.Count = 1;
+    return ExplicitUnroll;
+  }
+
+  // 5th priority is partial unrolling.
+  // Try partial unroll only when TripCount could be statically calculated.
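+  //
+  // Worked example with hypothetical numbers: LoopSize = 30, BEInsns = 2,
+  // PartialThreshold = 150 and TripCount = 11. The size-limited count below
+  // is (150 - 2) / (30 - 2) = 5; since no count in [2, 5] divides 11 it
+  // decays to 1, and with AllowRemainder the count restarts at
+  // DefaultUnrollRuntimeCount = 8 and halves to 4, the first value whose
+  // unrolled size (28 * 4 + 2 = 114) fits under the threshold.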
+  if (TripCount) {
+    UP.Partial |= ExplicitUnroll;
+    if (!UP.Partial) {
+      DEBUG(dbgs() << "  will not try to unroll partially because "
+                   << "-unroll-allow-partial not given\n");
+      UP.Count = 0;
+      return false;
+    }
+    if (UP.Count == 0)
+      UP.Count = TripCount;
+    if (UP.PartialThreshold != NoThreshold) {
+      // Reduce unroll count to be modulo of TripCount for partial unrolling.
+      if (getUnrolledLoopSize(LoopSize, UP) > UP.PartialThreshold)
+        UP.Count =
+            (std::max(UP.PartialThreshold, UP.BEInsns + 1) - UP.BEInsns) /
+            (LoopSize - UP.BEInsns);
+      if (UP.Count > UP.MaxCount)
+        UP.Count = UP.MaxCount;
+      while (UP.Count != 0 && TripCount % UP.Count != 0)
+        UP.Count--;
+      if (UP.AllowRemainder && UP.Count <= 1) {
+        // If there is no Count that is modulo of TripCount, set Count to the
+        // largest power-of-two factor that satisfies the threshold limit.
+        // As we'll create a fixup loop, do this type of unrolling only if a
+        // remainder loop is allowed.
+        UP.Count = UP.DefaultUnrollRuntimeCount;
+        while (UP.Count != 0 &&
+               getUnrolledLoopSize(LoopSize, UP) > UP.PartialThreshold)
+          UP.Count >>= 1;
+      }
+      if (UP.Count < 2) {
+        if (PragmaEnableUnroll)
+          ORE->emit([&]() {
+            return OptimizationRemarkMissed(DEBUG_TYPE,
+                                            "UnrollAsDirectedTooLarge",
+                                            L->getStartLoc(), L->getHeader())
+                   << "Unable to unroll loop as directed by unroll(enable) "
+                      "pragma because unrolled size is too large.";
+          });
+        UP.Count = 0;
+      }
+    } else {
+      UP.Count = TripCount;
+    }
+    if (UP.Count > UP.MaxCount)
+      UP.Count = UP.MaxCount;
+    if ((PragmaFullUnroll || PragmaEnableUnroll) && TripCount &&
+        UP.Count != TripCount)
+      ORE->emit([&]() {
+        return OptimizationRemarkMissed(DEBUG_TYPE,
+                                        "FullUnrollAsDirectedTooLarge",
+                                        L->getStartLoc(), L->getHeader())
+               << "Unable to fully unroll loop as directed by unroll pragma "
+                  "because unrolled size is too large.";
+      });
+    return ExplicitUnroll;
+  }
+  assert(TripCount == 0 &&
+         "All cases when TripCount is constant should be covered here.");
+  if (PragmaFullUnroll)
+    ORE->emit([&]() {
+      return OptimizationRemarkMissed(
+                 DEBUG_TYPE, "CantFullUnrollAsDirectedRuntimeTripCount",
+                 L->getStartLoc(), L->getHeader())
+             << "Unable to fully unroll loop as directed by unroll(full) "
+                "pragma because loop has a runtime trip count.";
+    });
+
+  // 6th priority is runtime unrolling.
+  // Don't unroll a runtime trip count loop when it is disabled.
+  if (HasRuntimeUnrollDisablePragma(L)) {
+    UP.Count = 0;
+    return false;
+  }
+
+  // Check if the runtime trip count is too small when a profile is available.
+  if (L->getHeader()->getParent()->hasProfileData()) {
+    if (auto ProfileTripCount = getLoopEstimatedTripCount(L)) {
+      if (*ProfileTripCount < FlatLoopTripCountThreshold)
+        return false;
+      else
+        UP.AllowExpensiveTripCount = true;
+    }
+  }
+
+  // Reduce count based on the type of unrolling and the threshold values.
+  UP.Runtime |= PragmaEnableUnroll || PragmaCount > 0 || UserUnrollCount;
+  if (!UP.Runtime) {
+    DEBUG(dbgs() << "  will not try to unroll loop with runtime trip count "
+                 << "because -unroll-runtime not given\n");
+    UP.Count = 0;
+    return false;
+  }
+  if (UP.Count == 0)
+    UP.Count = UP.DefaultUnrollRuntimeCount;
+
+  // Reduce unroll count to be the largest power-of-two factor of
+  // the original count which satisfies the threshold limit.
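+  //
+  // For example (hypothetical sizes), with Count = 8, LoopSize = 40, the
+  // default BEInsns = 2 and PartialThreshold = 150: the unrolled size
+  // (40 - 2) * 8 + 2 = 306 exceeds the threshold, as does 38 * 4 + 2 = 154,
+  // so the count settles at 2, where 38 * 2 + 2 = 78 fits.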
+ while (UP.Count != 0 && + getUnrolledLoopSize(LoopSize, UP) > UP.PartialThreshold) + UP.Count >>= 1; + +#ifndef NDEBUG + unsigned OrigCount = UP.Count; +#endif + + if (!UP.AllowRemainder && UP.Count != 0 && (TripMultiple % UP.Count) != 0) { + while (UP.Count != 0 && TripMultiple % UP.Count != 0) + UP.Count >>= 1; + DEBUG(dbgs() << "Remainder loop is restricted (that could architecture " + "specific or because the loop contains a convergent " + "instruction), so unroll count must divide the trip " + "multiple, " + << TripMultiple << ". Reducing unroll count from " + << OrigCount << " to " << UP.Count << ".\n"); + + using namespace ore; + + if (PragmaCount > 0 && !UP.AllowRemainder) + ORE->emit([&]() { + return OptimizationRemarkMissed(DEBUG_TYPE, + "DifferentUnrollCountFromDirected", + L->getStartLoc(), L->getHeader()) + << "Unable to unroll loop the number of times directed by " + "unroll_count pragma because remainder loop is restricted " + "(that could architecture specific or because the loop " + "contains a convergent instruction) and so must have an " + "unroll " + "count that divides the loop trip multiple of " + << NV("TripMultiple", TripMultiple) << ". Unrolling instead " + << NV("UnrollCount", UP.Count) << " time(s)."; + }); + } + + if (UP.Count > UP.MaxCount) + UP.Count = UP.MaxCount; + DEBUG(dbgs() << " partially unrolling with count: " << UP.Count << "\n"); + if (UP.Count < 2) + UP.Count = 0; + return ExplicitUnroll; +} + +static LoopUnrollResult tryToUnrollLoop( + Loop *L, DominatorTree &DT, LoopInfo *LI, ScalarEvolution &SE, + const TargetTransformInfo &TTI, AssumptionCache &AC, + OptimizationRemarkEmitter &ORE, bool PreserveLCSSA, int OptLevel, + Optional<unsigned> ProvidedCount, Optional<unsigned> ProvidedThreshold, + Optional<bool> ProvidedAllowPartial, Optional<bool> ProvidedRuntime, + Optional<bool> ProvidedUpperBound, Optional<bool> ProvidedAllowPeeling) { + DEBUG(dbgs() << "Loop Unroll: F[" << L->getHeader()->getParent()->getName() + << "] Loop %" << L->getHeader()->getName() << "\n"); + if (HasUnrollDisablePragma(L)) + return LoopUnrollResult::Unmodified; + if (!L->isLoopSimplifyForm()) { + DEBUG( + dbgs() << " Not unrolling loop which is not in loop-simplify form.\n"); + return LoopUnrollResult::Unmodified; + } + + unsigned NumInlineCandidates; + bool NotDuplicatable; + bool Convergent; + TargetTransformInfo::UnrollingPreferences UP = gatherUnrollingPreferences( + L, SE, TTI, OptLevel, ProvidedThreshold, ProvidedCount, + ProvidedAllowPartial, ProvidedRuntime, ProvidedUpperBound, + ProvidedAllowPeeling); + // Exit early if unrolling is disabled. + if (UP.Threshold == 0 && (!UP.Partial || UP.PartialThreshold == 0)) + return LoopUnrollResult::Unmodified; + unsigned LoopSize = ApproximateLoopSize( + L, NumInlineCandidates, NotDuplicatable, Convergent, TTI, &AC, UP.BEInsns); + DEBUG(dbgs() << " Loop Size = " << LoopSize << "\n"); + if (NotDuplicatable) { + DEBUG(dbgs() << " Not unrolling loop which contains non-duplicatable" + << " instructions.\n"); + return LoopUnrollResult::Unmodified; + } + if (NumInlineCandidates != 0) { + DEBUG(dbgs() << " Not unrolling loop with inlinable calls.\n"); + return LoopUnrollResult::Unmodified; + } + + // Find trip count and trip multiple if count is not available + unsigned TripCount = 0; + unsigned MaxTripCount = 0; + unsigned TripMultiple = 1; + // If there are multiple exiting blocks but one of them is the latch, use the + // latch for the trip count estimation. 
+  // If there are multiple exiting blocks but one of them is the latch, use
+  // the latch for the trip count estimation. Otherwise insist on a single
+  // exiting block for the trip count estimation.
+  BasicBlock *ExitingBlock = L->getLoopLatch();
+  if (!ExitingBlock || !L->isLoopExiting(ExitingBlock))
+    ExitingBlock = L->getExitingBlock();
+  if (ExitingBlock) {
+    TripCount = SE.getSmallConstantTripCount(L, ExitingBlock);
+    TripMultiple = SE.getSmallConstantTripMultiple(L, ExitingBlock);
+  }
+
+  // If the loop contains a convergent operation, the prelude we'd add
+  // to do the first few instructions before we hit the unrolled loop
+  // is unsafe -- it adds a control-flow dependency to the convergent
+  // operation. Therefore we restrict the remainder loop (i.e. try
+  // unrolling without one).
+  //
+  // TODO: This is quite conservative. In practice, convergent_op()
+  // is likely to be called unconditionally in the loop. In this
+  // case, the program would be ill-formed (on most architectures)
+  // unless n were the same on all threads in a thread group.
+  // Assuming n is the same on all threads, any kind of unrolling is
+  // safe. But currently llvm's notion of convergence isn't powerful
+  // enough to express this.
+  if (Convergent)
+    UP.AllowRemainder = false;
+
+  // Try to find the trip count upper bound if we cannot find the exact trip
+  // count.
+  bool MaxOrZero = false;
+  if (!TripCount) {
+    MaxTripCount = SE.getSmallConstantMaxTripCount(L);
+    MaxOrZero = SE.isBackedgeTakenCountMaxOrZero(L);
+    // We can unroll by the upper bound amount if it's generally allowed or if
+    // we know that the loop is executed either the upper bound or zero times.
+    // (MaxOrZero unrolling keeps only the first loop test, so the number of
+    // loop tests remains the same compared to the non-unrolled version,
+    // whereas generic upper bound unrolling keeps all but the last loop test,
+    // so the number of loop tests goes up, which may be worse on targets with
+    // constrained branch predictor resources; it is therefore controlled by
+    // an option.) In addition we only unroll small upper bounds.
+    if (!(UP.UpperBound || MaxOrZero) || MaxTripCount > UnrollMaxUpperBound) {
+      MaxTripCount = 0;
+    }
+  }
+
+  // computeUnrollCount() decides whether it is beneficial to use the upper
+  // bound to fully unroll the loop.
+  bool UseUpperBound = false;
+  bool IsCountSetExplicitly =
+      computeUnrollCount(L, TTI, DT, LI, SE, &ORE, TripCount, MaxTripCount,
+                         TripMultiple, LoopSize, UP, UseUpperBound);
+  if (!UP.Count)
+    return LoopUnrollResult::Unmodified;
+  // The unroll factor (Count) must be less than or equal to TripCount.
+  if (TripCount && UP.Count > TripCount)
+    UP.Count = TripCount;
+
+  // Unroll the loop.
+  LoopUnrollResult UnrollResult = UnrollLoop(
+      L, UP.Count, TripCount, UP.Force, UP.Runtime, UP.AllowExpensiveTripCount,
+      UseUpperBound, MaxOrZero, TripMultiple, UP.PeelCount, UP.UnrollRemainder,
+      LI, &SE, &DT, &AC, &ORE, PreserveLCSSA);
+  if (UnrollResult == LoopUnrollResult::Unmodified)
+    return LoopUnrollResult::Unmodified;
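+  // [Editor's note: an illustrative aside.] For example, if TripCount is
+  // known to be 5 but computeUnrollCount settled on 8, the clamp above
+  // reduces the factor to 5 (a full unroll); a factor above the trip count
+  // would emit iterations that can never execute.
+
+  // If the loop has an unroll count pragma or was unrolled with an explicitly
+  // set count, mark it as unrolled to prevent unrolling beyond what was
+  // requested. If the loop was peeled, we already "used up" the profile
+  // information we had, so we don't want to unroll or peel again.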
+ if (UnrollResult != LoopUnrollResult::FullyUnrolled && + (IsCountSetExplicitly || UP.PeelCount)) + L->setLoopAlreadyUnrolled(); + + return UnrollResult; +} + +namespace { + +class LoopUnroll : public LoopPass { +public: + static char ID; // Pass ID, replacement for typeid + + int OptLevel; + Optional<unsigned> ProvidedCount; + Optional<unsigned> ProvidedThreshold; + Optional<bool> ProvidedAllowPartial; + Optional<bool> ProvidedRuntime; + Optional<bool> ProvidedUpperBound; + Optional<bool> ProvidedAllowPeeling; + + LoopUnroll(int OptLevel = 2, Optional<unsigned> Threshold = None, + Optional<unsigned> Count = None, + Optional<bool> AllowPartial = None, Optional<bool> Runtime = None, + Optional<bool> UpperBound = None, + Optional<bool> AllowPeeling = None) + : LoopPass(ID), OptLevel(OptLevel), ProvidedCount(std::move(Count)), + ProvidedThreshold(Threshold), ProvidedAllowPartial(AllowPartial), + ProvidedRuntime(Runtime), ProvidedUpperBound(UpperBound), + ProvidedAllowPeeling(AllowPeeling) { + initializeLoopUnrollPass(*PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override { + if (skipLoop(L)) + return false; + + Function &F = *L->getHeader()->getParent(); + + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + ScalarEvolution &SE = getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + const TargetTransformInfo &TTI = + getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + // For the old PM, we can't use OptimizationRemarkEmitter as an analysis + // pass. Function analyses need to be preserved across loop transformations + // but ORE cannot be preserved (see comment before the pass definition). + OptimizationRemarkEmitter ORE(&F); + bool PreserveLCSSA = mustPreserveAnalysisID(LCSSAID); + + LoopUnrollResult Result = tryToUnrollLoop( + L, DT, LI, SE, TTI, AC, ORE, PreserveLCSSA, OptLevel, ProvidedCount, + ProvidedThreshold, ProvidedAllowPartial, ProvidedRuntime, + ProvidedUpperBound, ProvidedAllowPeeling); + + if (Result == LoopUnrollResult::FullyUnrolled) + LPM.markLoopAsDeleted(*L); + + return Result != LoopUnrollResult::Unmodified; + } + + /// This transformation requires natural loop information & requires that + /// loop preheaders be inserted into the CFG... + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + // FIXME: Loop passes are required to preserve domtree, and for now we just + // recreate dom info if anything gets unrolled. + getLoopAnalysisUsage(AU); + } +}; + +} // end anonymous namespace + +char LoopUnroll::ID = 0; + +INITIALIZE_PASS_BEGIN(LoopUnroll, "loop-unroll", "Unroll loops", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(LoopUnroll, "loop-unroll", "Unroll loops", false, false) + +Pass *llvm::createLoopUnrollPass(int OptLevel, int Threshold, int Count, + int AllowPartial, int Runtime, int UpperBound, + int AllowPeeling) { + // TODO: It would make more sense for this function to take the optionals + // directly, but that's dangerous since it would silently break out of tree + // callers. + return new LoopUnroll( + OptLevel, Threshold == -1 ? None : Optional<unsigned>(Threshold), + Count == -1 ? 
None : Optional<unsigned>(Count), + AllowPartial == -1 ? None : Optional<bool>(AllowPartial), + Runtime == -1 ? None : Optional<bool>(Runtime), + UpperBound == -1 ? None : Optional<bool>(UpperBound), + AllowPeeling == -1 ? None : Optional<bool>(AllowPeeling)); +} + +Pass *llvm::createSimpleLoopUnrollPass(int OptLevel) { + return createLoopUnrollPass(OptLevel, -1, -1, 0, 0, 0, 0); +} + +PreservedAnalyses LoopFullUnrollPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &Updater) { + const auto &FAM = + AM.getResult<FunctionAnalysisManagerLoopProxy>(L, AR).getManager(); + Function *F = L.getHeader()->getParent(); + + auto *ORE = FAM.getCachedResult<OptimizationRemarkEmitterAnalysis>(*F); + // FIXME: This should probably be optional rather than required. + if (!ORE) + report_fatal_error( + "LoopFullUnrollPass: OptimizationRemarkEmitterAnalysis not " + "cached at a higher level"); + + // Keep track of the previous loop structure so we can identify new loops + // created by unrolling. + Loop *ParentL = L.getParentLoop(); + SmallPtrSet<Loop *, 4> OldLoops; + if (ParentL) + OldLoops.insert(ParentL->begin(), ParentL->end()); + else + OldLoops.insert(AR.LI.begin(), AR.LI.end()); + + std::string LoopName = L.getName(); + + bool Changed = + tryToUnrollLoop(&L, AR.DT, &AR.LI, AR.SE, AR.TTI, AR.AC, *ORE, + /*PreserveLCSSA*/ true, OptLevel, /*Count*/ None, + /*Threshold*/ None, /*AllowPartial*/ false, + /*Runtime*/ false, /*UpperBound*/ false, + /*AllowPeeling*/ false) != LoopUnrollResult::Unmodified; + if (!Changed) + return PreservedAnalyses::all(); + + // The parent must not be damaged by unrolling! +#ifndef NDEBUG + if (ParentL) + ParentL->verifyLoop(); +#endif + + // Unrolling can do several things to introduce new loops into a loop nest: + // - Full unrolling clones child loops within the current loop but then + // removes the current loop making all of the children appear to be new + // sibling loops. + // + // When a new loop appears as a sibling loop after fully unrolling, + // its nesting structure has fundamentally changed and we want to revisit + // it to reflect that. + // + // When unrolling has removed the current loop, we need to tell the + // infrastructure that it is gone. + // + // Finally, we support a debugging/testing mode where we revisit child loops + // as well. These are not expected to require further optimizations as either + // they or the loop they were cloned from have been directly visited already. + // But the debugging mode allows us to check this assumption. + bool IsCurrentLoopValid = false; + SmallVector<Loop *, 4> SibLoops; + if (ParentL) + SibLoops.append(ParentL->begin(), ParentL->end()); + else + SibLoops.append(AR.LI.begin(), AR.LI.end()); + erase_if(SibLoops, [&](Loop *SibLoop) { + if (SibLoop == &L) { + IsCurrentLoopValid = true; + return true; + } + + // Otherwise erase the loop from the list if it was in the old loops. + return OldLoops.count(SibLoop) != 0; + }); + Updater.addSiblingLoops(SibLoops); + + if (!IsCurrentLoopValid) { + Updater.markLoopAsDeleted(L, LoopName); + } else { + // We can only walk child loops if the current loop remained valid. + if (UnrollRevisitChildLoops) { + // Walk *all* of the child loops. 
+      SmallVector<Loop *, 4> ChildLoops(L.begin(), L.end());
+      Updater.addChildLoops(ChildLoops);
+    }
+  }
+
+  return getLoopPassPreservedAnalyses();
+}
+
+template <typename RangeT>
+static SmallVector<Loop *, 8> appendLoopsToWorklist(RangeT &&Loops) {
+  SmallVector<Loop *, 8> Worklist;
+  // We use an internal worklist to build up the preorder traversal without
+  // recursion.
+  SmallVector<Loop *, 4> PreOrderLoops, PreOrderWorklist;
+
+  for (Loop *RootL : Loops) {
+    assert(PreOrderLoops.empty() && "Must start with an empty preorder walk.");
+    assert(PreOrderWorklist.empty() &&
+           "Must start with an empty preorder walk worklist.");
+    PreOrderWorklist.push_back(RootL);
+    do {
+      Loop *L = PreOrderWorklist.pop_back_val();
+      PreOrderWorklist.append(L->begin(), L->end());
+      PreOrderLoops.push_back(L);
+    } while (!PreOrderWorklist.empty());
+
+    Worklist.append(PreOrderLoops.begin(), PreOrderLoops.end());
+    PreOrderLoops.clear();
+  }
+  return Worklist;
+}
+
+PreservedAnalyses LoopUnrollPass::run(Function &F,
+                                      FunctionAnalysisManager &AM) {
+  auto &SE = AM.getResult<ScalarEvolutionAnalysis>(F);
+  auto &LI = AM.getResult<LoopAnalysis>(F);
+  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  auto &AC = AM.getResult<AssumptionAnalysis>(F);
+  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+
+  LoopAnalysisManager *LAM = nullptr;
+  if (auto *LAMProxy = AM.getCachedResult<LoopAnalysisManagerFunctionProxy>(F))
+    LAM = &LAMProxy->getManager();
+
+  const ModuleAnalysisManager &MAM =
+      AM.getResult<ModuleAnalysisManagerFunctionProxy>(F).getManager();
+  ProfileSummaryInfo *PSI =
+      MAM.getCachedResult<ProfileSummaryAnalysis>(*F.getParent());
+
+  bool Changed = false;
+
+  // The unroller requires loops to be in simplified form, and also needs
+  // LCSSA. Since simplification may add new inner loops, it has to run before
+  // the legality and profitability checks. This means running the loop
+  // unroller will simplify all loops, regardless of whether anything ends up
+  // being unrolled.
+  for (auto &L : LI) {
+    Changed |= simplifyLoop(L, &DT, &LI, &SE, &AC, false /* PreserveLCSSA */);
+    Changed |= formLCSSARecursively(*L, DT, &LI, &SE);
+  }
+
+  SmallVector<Loop *, 8> Worklist = appendLoopsToWorklist(LI);
+
+  while (!Worklist.empty()) {
+    // Because the LoopInfo stores the loops in RPO, we walk the worklist
+    // from back to front so that we work forward across the CFG, which
+    // for unrolling is only needed to get optimization remarks emitted in
+    // a forward order.
+    Loop &L = *Worklist.pop_back_val();
+#ifndef NDEBUG
+    Loop *ParentL = L.getParentLoop();
+#endif
+
+    // The API here is quite complex to call, but there are only two
+    // interesting states we support: partial and full (or "simple")
+    // unrolling. However, to enable these things we actually pass "None" in
+    // for the optionals to avoid providing an explicit choice.
+    Optional<bool> AllowPartialParam, RuntimeParam, UpperBoundParam,
+        AllowPeeling;
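+    // [Editor's note: an illustrative aside.] The worklist built above is a
+    // preorder per loop nest, so popping from the back visits every loop
+    // before its parent; e.g. for a nest {A {B, C}} the worklist comes out as
+    // [A, C, B] and the processing order is B, C, then A.
+
+    // Check if the profile summary indicates that the profiled application
+    // has a huge working set size, in which case we disable peeling to avoid
+    // bloating it further.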
+ if (PSI && PSI->hasHugeWorkingSetSize()) + AllowPeeling = false; + std::string LoopName = L.getName(); + LoopUnrollResult Result = + tryToUnrollLoop(&L, DT, &LI, SE, TTI, AC, ORE, + /*PreserveLCSSA*/ true, OptLevel, /*Count*/ None, + /*Threshold*/ None, AllowPartialParam, RuntimeParam, + UpperBoundParam, AllowPeeling); + Changed |= Result != LoopUnrollResult::Unmodified; + + // The parent must not be damaged by unrolling! +#ifndef NDEBUG + if (Result != LoopUnrollResult::Unmodified && ParentL) + ParentL->verifyLoop(); +#endif + + // Clear any cached analysis results for L if we removed it completely. + if (LAM && Result == LoopUnrollResult::FullyUnrolled) + LAM->clear(L, LoopName); + } + + if (!Changed) + return PreservedAnalyses::all(); + + return getLoopPassPreservedAnalyses(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp new file mode 100644 index 000000000000..f2405d9b0c03 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LoopUnswitch.cpp @@ -0,0 +1,1613 @@ +//===- LoopUnswitch.cpp - Hoist loop-invariant conditionals in loop -------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass transforms loops that contain branches on loop-invariant conditions +// to multiple loops. For example, it turns the left into the right code: +// +// for (...) if (lic) +// A for (...) +// if (lic) A; B; C +// B else +// C for (...) +// A; C +// +// This can increase the size of the code exponentially (doubling it every time +// a loop is unswitched) so we only unswitch if the resultant code will be +// smaller than a threshold. +// +// This pass expects LICM to be run before it to hoist invariant conditions out +// of the loop, to make the unswitching opportunity obvious. 
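+// For instance (an illustrative note): a body containing two independent
+// invariant conditions can be unswitched twice, yielding up to four
+// specialized copies of the loop, which is why the size threshold below is
+// checked before every unswitch.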
+// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/DivergenceAnalysis.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include <algorithm> +#include <cassert> +#include <map> +#include <set> +#include <tuple> +#include <utility> +#include <vector> + +using namespace llvm; + +#define DEBUG_TYPE "loop-unswitch" + +STATISTIC(NumBranches, "Number of branches unswitched"); +STATISTIC(NumSwitches, "Number of switches unswitched"); +STATISTIC(NumGuards, "Number of guards unswitched"); +STATISTIC(NumSelects , "Number of selects unswitched"); +STATISTIC(NumTrivial , "Number of unswitches that are trivial"); +STATISTIC(NumSimplify, "Number of simplifications of unswitched code"); +STATISTIC(TotalInsts, "Total number of instructions analyzed"); + +// The specific value of 100 here was chosen based only on intuition and a +// few specific examples. +static cl::opt<unsigned> +Threshold("loop-unswitch-threshold", cl::desc("Max loop size to unswitch"), + cl::init(100), cl::Hidden); + +namespace { + + class LUAnalysisCache { + using UnswitchedValsMap = + DenseMap<const SwitchInst *, SmallPtrSet<const Value *, 8>>; + using UnswitchedValsIt = UnswitchedValsMap::iterator; + + struct LoopProperties { + unsigned CanBeUnswitchedCount; + unsigned WasUnswitchedCount; + unsigned SizeEstimation; + UnswitchedValsMap UnswitchedVals; + }; + + // Here we use std::map instead of DenseMap, since we need to keep valid + // LoopProperties pointer for current loop for better performance. + using LoopPropsMap = std::map<const Loop *, LoopProperties>; + using LoopPropsMapIt = LoopPropsMap::iterator; + + LoopPropsMap LoopsProperties; + UnswitchedValsMap *CurLoopInstructions = nullptr; + LoopProperties *CurrentLoopProperties = nullptr; + + // A loop unswitching with an estimated cost above this threshold + // is not performed. 
MaxSize is turned into an unswitching quota
+    // for the current loop, and reduced correspondingly, though note that
+    // the quota is returned by releaseMemory() when the loop has been
+    // processed, so that MaxSize will return to its previous
+    // value. So in most cases MaxSize will equal the Threshold flag
+    // when a new loop is processed. An exception to that is that
+    // MaxSize will have a smaller value while processing nested loops
+    // that were introduced due to loop unswitching of an outer loop.
+    //
+    // FIXME: The way that MaxSize works is subtle and depends on the
+    // pass manager processing loops and calling releaseMemory() in a
+    // specific order. It would be good to find a more straightforward
+    // way of doing what MaxSize does.
+    unsigned MaxSize;
+
+  public:
+    LUAnalysisCache() : MaxSize(Threshold) {}
+
+    // Analyze the loop: check its size and determine whether it is possible
+    // to unswitch it. Returns true if we can unswitch this loop.
+    bool countLoop(const Loop *L, const TargetTransformInfo &TTI,
+                   AssumptionCache *AC);
+
+    // Clean all data related to the given loop.
+    void forgetLoop(const Loop *L);
+
+    // Mark the case value as unswitched.
+    // Since a SwitchInst can be partly unswitched, keep track of all
+    // unswitched values in order to avoid extra unswitching in cloned loops.
+    void setUnswitched(const SwitchInst *SI, const Value *V);
+
+    // Check whether this case value was unswitched before.
+    bool isUnswitched(const SwitchInst *SI, const Value *V);
+
+    // Returns true if another unswitching could be done within the cost
+    // threshold.
+    bool CostAllowsUnswitching();
+
+    // Clone all loop-unswitch related loop properties and redistribute the
+    // unswitching quota. Note that the new loop's data is looked up through
+    // the VMap.
+    void cloneData(const Loop *NewLoop, const Loop *OldLoop,
+                   const ValueToValueMapTy &VMap);
+  };
+
+  class LoopUnswitch : public LoopPass {
+    LoopInfo *LI;  // Loop information
+    LPPassManager *LPM;
+    AssumptionCache *AC;
+
+    // Used to check if the second loop needs processing after
+    // RewriteLoopBodyWithConditionConstant rewrites the first loop.
+    std::vector<Loop*> LoopProcessWorklist;
+
+    LUAnalysisCache BranchesInfo;
+
+    bool OptimizeForSize;
+    bool redoLoop = false;
+
+    Loop *currentLoop = nullptr;
+    DominatorTree *DT = nullptr;
+    BasicBlock *loopHeader = nullptr;
+    BasicBlock *loopPreheader = nullptr;
+
+    bool SanitizeMemory;
+    LoopSafetyInfo SafetyInfo;
+
+    // LoopBlocks contains all of the basic blocks of the loop, including the
+    // preheader of the loop, the body of the loop, and the exit blocks of the
+    // loop, in that order.
+    std::vector<BasicBlock*> LoopBlocks;
+    // NewBlocks contains the cloned copies of the basic blocks in LoopBlocks.
+    std::vector<BasicBlock*> NewBlocks;
+
+    bool hasBranchDivergence;
+
+  public:
+    static char ID; // Pass ID, replacement for typeid
+
+    explicit LoopUnswitch(bool Os = false, bool hasBranchDivergence = false)
+        : LoopPass(ID), OptimizeForSize(Os),
+          hasBranchDivergence(hasBranchDivergence) {
+      initializeLoopUnswitchPass(*PassRegistry::getPassRegistry());
+    }
+
+    bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+    bool processCurrentLoop();
+    bool isUnreachableDueToPreviousUnswitching(BasicBlock *);
+
+    /// This transformation requires natural loop information & requires that
+    /// loop preheaders be inserted into the CFG.
+    ///
+    void getAnalysisUsage(AnalysisUsage &AU) const override {
+      AU.addRequired<AssumptionCacheTracker>();
+      AU.addRequired<TargetTransformInfoWrapperPass>();
+      if (hasBranchDivergence)
+        AU.addRequired<DivergenceAnalysis>();
+      getLoopAnalysisUsage(AU);
+    }
+
+  private:
+    void releaseMemory() override {
+      BranchesInfo.forgetLoop(currentLoop);
+    }
+
+    void initLoopData() {
+      loopHeader = currentLoop->getHeader();
+      loopPreheader = currentLoop->getLoopPreheader();
+    }
+
+    /// Split all of the edges from inside the loop to their exit blocks.
+    /// Update the appropriate Phi nodes as we do so.
+    void SplitExitEdges(Loop *L,
+                        const SmallVectorImpl<BasicBlock *> &ExitBlocks);
+
+    bool TryTrivialLoopUnswitch(bool &Changed);
+
+    bool UnswitchIfProfitable(Value *LoopCond, Constant *Val,
+                              TerminatorInst *TI = nullptr);
+    void UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val,
+                                  BasicBlock *ExitBlock, TerminatorInst *TI);
+    void UnswitchNontrivialCondition(Value *LIC, Constant *OnVal, Loop *L,
+                                     TerminatorInst *TI);
+
+    void RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
+                                              Constant *Val, bool isEqual);
+
+    void EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val,
+                                        BasicBlock *TrueDest,
+                                        BasicBlock *FalseDest,
+                                        BranchInst *OldBranch,
+                                        TerminatorInst *TI);
+
+    void SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L);
+
+    /// Given that Invariant is not equal to Val, simplify instructions in
+    /// the loop.
+    Value *SimplifyInstructionWithNotEqual(Instruction *Inst, Value *Invariant,
+                                           Constant *Val);
+  };
+
+} // end anonymous namespace
+
+// Analyze the loop: check its size and determine whether it is possible to
+// unswitch it. Returns true if we can unswitch this loop.
+bool LUAnalysisCache::countLoop(const Loop *L, const TargetTransformInfo &TTI,
+                                AssumptionCache *AC) {
+  LoopPropsMapIt PropsIt;
+  bool Inserted;
+  std::tie(PropsIt, Inserted) =
+      LoopsProperties.insert(std::make_pair(L, LoopProperties()));
+
+  LoopProperties &Props = PropsIt->second;
+
+  if (Inserted) {
+    // New loop.
+
+    // Limit the number of instructions to avoid causing significant code
+    // expansion, and the number of basic blocks, to avoid loops with
+    // large numbers of branches which cause loop unswitching to go crazy.
+    // This is a very ad-hoc heuristic.
+
+    SmallPtrSet<const Value *, 32> EphValues;
+    CodeMetrics::collectEphemeralValues(L, AC, EphValues);
+
+    // FIXME: This is overly conservative because it does not take into
+    // consideration code simplification opportunities and code that can
+    // be shared by the resultant unswitched loops.
+    CodeMetrics Metrics;
+    for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); I != E;
+         ++I)
+      Metrics.analyzeBasicBlock(*I, TTI, EphValues);
+
+    Props.SizeEstimation = Metrics.NumInsts;
+    Props.CanBeUnswitchedCount = MaxSize / (Props.SizeEstimation);
+    Props.WasUnswitchedCount = 0;
+    MaxSize -= Props.SizeEstimation * Props.CanBeUnswitchedCount;
+
+    if (Metrics.notDuplicatable) {
+      DEBUG(dbgs() << "NOT unswitching loop %"
+                   << L->getHeader()->getName() << ", contents cannot be "
+                   << "duplicated!\n");
+      return false;
+    }
+  }
+
+  // Be careful: these links are valid only until a new loop is added.
+  CurrentLoopProperties = &Props;
+  CurLoopInstructions = &Props.UnswitchedVals;
+
+  return true;
+}
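+// [Editor's note: an illustrative aside, not part of the pass.] With the
+// default Threshold of 100, a loop of SizeEstimation 30 gets
+// CanBeUnswitchedCount = 100 / 30 = 3, and MaxSize drops by 3 * 30 = 90 to
+// 10, so an equally sized sibling processed next gets no quota until
+// forgetLoop() below hands the reserved amount back; the CanBe/Was counters
+// (split across clones by cloneData) are what let it return the full sum.
+
+// Clean all data related to the given loop.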
+void LUAnalysisCache::forgetLoop(const Loop *L) {
+  LoopPropsMapIt LIt = LoopsProperties.find(L);
+
+  if (LIt != LoopsProperties.end()) {
+    LoopProperties &Props = LIt->second;
+    MaxSize += (Props.CanBeUnswitchedCount + Props.WasUnswitchedCount) *
+               Props.SizeEstimation;
+    LoopsProperties.erase(LIt);
+  }
+
+  CurrentLoopProperties = nullptr;
+  CurLoopInstructions = nullptr;
+}
+
+// Mark the case value as unswitched.
+// Since a SwitchInst can be partly unswitched, keep track of all unswitched
+// values in order to avoid extra unswitching in cloned loops.
+void LUAnalysisCache::setUnswitched(const SwitchInst *SI, const Value *V) {
+  (*CurLoopInstructions)[SI].insert(V);
+}
+
+// Check whether this case value was unswitched before.
+bool LUAnalysisCache::isUnswitched(const SwitchInst *SI, const Value *V) {
+  return (*CurLoopInstructions)[SI].count(V);
+}
+
+bool LUAnalysisCache::CostAllowsUnswitching() {
+  return CurrentLoopProperties->CanBeUnswitchedCount > 0;
+}
+
+// Clone all loop-unswitch related loop properties and redistribute the
+// unswitching quota. Note that the new loop's data is looked up through the
+// VMap.
+void LUAnalysisCache::cloneData(const Loop *NewLoop, const Loop *OldLoop,
+                                const ValueToValueMapTy &VMap) {
+  LoopProperties &NewLoopProps = LoopsProperties[NewLoop];
+  LoopProperties &OldLoopProps = *CurrentLoopProperties;
+  UnswitchedValsMap &Insts = OldLoopProps.UnswitchedVals;
+
+  // Reallocate the "can-be-unswitched" quota.
+  --OldLoopProps.CanBeUnswitchedCount;
+  ++OldLoopProps.WasUnswitchedCount;
+  NewLoopProps.WasUnswitchedCount = 0;
+  unsigned Quota = OldLoopProps.CanBeUnswitchedCount;
+  NewLoopProps.CanBeUnswitchedCount = Quota / 2;
+  OldLoopProps.CanBeUnswitchedCount = Quota - Quota / 2;
+
+  NewLoopProps.SizeEstimation = OldLoopProps.SizeEstimation;
+
+  // Clone the unswitched-values info: for the new loop's switches, clone the
+  // info about values that were already unswitched and have redundant
+  // successors.
+  for (UnswitchedValsIt I = Insts.begin(); I != Insts.end(); ++I) {
+    const SwitchInst *OldInst = I->first;
+    Value *NewI = VMap.lookup(OldInst);
+    const SwitchInst *NewInst = cast_or_null<SwitchInst>(NewI);
+    assert(NewInst && "All instructions that are in SrcBB must be in VMap.");
+
+    NewLoopProps.UnswitchedVals[NewInst] = OldLoopProps.UnswitchedVals[OldInst];
+  }
+}
+
+char LoopUnswitch::ID = 0;
+
+INITIALIZE_PASS_BEGIN(LoopUnswitch, "loop-unswitch", "Unswitch loops",
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(LoopPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DivergenceAnalysis)
+INITIALIZE_PASS_END(LoopUnswitch, "loop-unswitch", "Unswitch loops",
+                    false, false)
+
+Pass *llvm::createLoopUnswitchPass(bool Os, bool hasBranchDivergence) {
+  return new LoopUnswitch(Os, hasBranchDivergence);
+}
+
+/// Operator chain lattice.
+enum OperatorChain {
+  OC_OpChainNone,  ///< There is no operator.
+  OC_OpChainOr,    ///< There are only ORs.
+  OC_OpChainAnd,   ///< There are only ANDs.
+  OC_OpChainMixed  ///< There are ANDs and ORs.
+};
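+// [Editor's note: a minimal sketch of the lattice transition implemented by
+// the switch inside FindLIVLoopCondition below; the helper name is
+// hypothetical and the pass does not define it.]
+//
+//   static OperatorChain combineChain(OperatorChain Chain, bool IsAnd) {
+//     switch (Chain) {
+//     case OC_OpChainNone:  return IsAnd ? OC_OpChainAnd : OC_OpChainOr;
+//     case OC_OpChainAnd:   return IsAnd ? OC_OpChainAnd : OC_OpChainMixed;
+//     case OC_OpChainOr:    return IsAnd ? OC_OpChainOr  : OC_OpChainMixed;
+//     case OC_OpChainMixed: return OC_OpChainMixed;
+//     }
+//     return OC_OpChainMixed; // Mixed is absorbing: no value simplifies it.
+//   }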
+/// Cond is a condition that occurs in L. If it is invariant in the loop, or
+/// has an invariant piece, return the invariant. Otherwise, return null.
+//
+/// NOTE: FindLIVLoopCondition will not return a partial LIV by walking up a
+/// mixed operator chain, as we cannot reliably find a value which will
+/// simplify the operator chain. If the chain is AND-only or OR-only, we can
+/// use 0 or ~0 to simplify the chain.
+///
+/// NOTE: In the case of a partial LIV and a mixed operator chain, we may be
+/// able to simplify the condition itself to a loop variant condition, but at
+/// the cost of creating an entirely new loop.
+static Value *FindLIVLoopCondition(Value *Cond, Loop *L, bool &Changed,
+                                   OperatorChain &ParentChain,
+                                   DenseMap<Value *, Value *> &Cache) {
+  auto CacheIt = Cache.find(Cond);
+  if (CacheIt != Cache.end())
+    return CacheIt->second;
+
+  // We have started analyzing a new instruction; bump the
+  // scanned-instructions counter.
+  ++TotalInsts;
+
+  // We can never unswitch on vector conditions.
+  if (Cond->getType()->isVectorTy())
+    return nullptr;
+
+  // Constants should be folded, not unswitched on!
+  if (isa<Constant>(Cond)) return nullptr;
+
+  // TODO: Handle: br (VARIANT|INVARIANT).
+
+  // Hoist simple values out.
+  if (L->makeLoopInvariant(Cond, Changed)) {
+    Cache[Cond] = Cond;
+    return Cond;
+  }
+
+  // Walk up the operator chain to find partial invariant conditions.
+  if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Cond))
+    if (BO->getOpcode() == Instruction::And ||
+        BO->getOpcode() == Instruction::Or) {
+      // Given the previous operator, compute the current operator chain
+      // status.
+      OperatorChain NewChain;
+      switch (ParentChain) {
+      case OC_OpChainNone:
+        NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd :
+                                                         OC_OpChainOr;
+        break;
+      case OC_OpChainOr:
+        NewChain = BO->getOpcode() == Instruction::Or ? OC_OpChainOr :
+                                                        OC_OpChainMixed;
+        break;
+      case OC_OpChainAnd:
+        NewChain = BO->getOpcode() == Instruction::And ? OC_OpChainAnd :
+                                                         OC_OpChainMixed;
+        break;
+      case OC_OpChainMixed:
+        NewChain = OC_OpChainMixed;
+        break;
+      }
+
+      // If we reach a Mixed state, we do not want to keep walking up, as we
+      // cannot reliably find a value that will simplify the chain. With this
+      // check, we return null on the first sight of a mixed chain and the
+      // caller will either backtrack to find a partial LIV in the other
+      // operand or return null.
+      if (NewChain != OC_OpChainMixed) {
+        // Update the current operator chain type before we search up the
+        // chain.
+        ParentChain = NewChain;
+        // If either the left or right side is invariant, we can unswitch on
+        // this, which will cause the branch to go away in one loop and the
+        // condition to simplify in the other one.
+        if (Value *LHS = FindLIVLoopCondition(BO->getOperand(0), L, Changed,
+                                              ParentChain, Cache)) {
+          Cache[Cond] = LHS;
+          return LHS;
+        }
+        // We did not manage to find a partial LIV in operand(0). Backtrack
+        // and try operand(1).
+        ParentChain = NewChain;
+        if (Value *RHS = FindLIVLoopCondition(BO->getOperand(1), L, Changed,
+                                              ParentChain, Cache)) {
+          Cache[Cond] = RHS;
+          return RHS;
+        }
+      }
+    }
+
+  Cache[Cond] = nullptr;
+  return nullptr;
+}
+
+/// Cond is a condition that occurs in L. If it is invariant in the loop, or
+/// has an invariant piece, return the invariant along with the operator chain
+/// type. Otherwise, return null.
+static std::pair<Value *, OperatorChain> FindLIVLoopCondition(Value *Cond,
+                                                              Loop *L,
+                                                              bool &Changed) {
+  DenseMap<Value *, Value *> Cache;
+  OperatorChain OpChain = OC_OpChainNone;
+  Value *FCond = FindLIVLoopCondition(Cond, L, Changed, OpChain, Cache);
+  // If we do find a LIV, it cannot have been obtained by walking up a mixed
+  // operator chain.
+  assert((!FCond || OpChain != OC_OpChainMixed) &&
+         "Do not expect a partial LIV with mixed operator chain");
+  return {FCond, OpChain};
+}
+
+bool LoopUnswitch::runOnLoop(Loop *L, LPPassManager &LPM_Ref) {
+  if (skipLoop(L))
+    return false;
+
+  AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(
+      *L->getHeader()->getParent());
+  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  LPM = &LPM_Ref;
+  DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  currentLoop = L;
+  Function *F = currentLoop->getHeader()->getParent();
+
+  SanitizeMemory = F->hasFnAttribute(Attribute::SanitizeMemory);
+  if (SanitizeMemory)
+    computeLoopSafetyInfo(&SafetyInfo, L);
+
+  bool Changed = false;
+  do {
+    assert(currentLoop->isLCSSAForm(*DT));
+    redoLoop = false;
+    Changed |= processCurrentLoop();
+  } while(redoLoop);
+
+  return Changed;
+}
+
+// Return true if the BasicBlock BB is unreachable from the loop header;
+// return false otherwise.
+bool LoopUnswitch::isUnreachableDueToPreviousUnswitching(BasicBlock *BB) {
+  auto *Node = DT->getNode(BB)->getIDom();
+  BasicBlock *DomBB = Node->getBlock();
+  while (currentLoop->contains(DomBB)) {
+    BranchInst *BInst = dyn_cast<BranchInst>(DomBB->getTerminator());
+
+    Node = DT->getNode(DomBB)->getIDom();
+    DomBB = Node->getBlock();
+
+    if (!BInst || !BInst->isConditional())
+      continue;
+
+    Value *Cond = BInst->getCondition();
+    if (!isa<ConstantInt>(Cond))
+      continue;
+
+    BasicBlock *UnreachableSucc =
+        Cond == ConstantInt::getTrue(Cond->getContext())
+            ? BInst->getSuccessor(1)
+            : BInst->getSuccessor(0);
+
+    if (DT->dominates(UnreachableSucc, BB))
+      return true;
+  }
+  return false;
+}
+
+/// FIXME: Remove this workaround when the freeze-related patches are done.
+/// LoopUnswitch and equality propagation in GVN disagree about whether a
+/// branch on undef/poison has undefined behavior. This function rules out
+/// some common cases in which we have already seen that discrepancy cause
+/// problems; details can be found in PR31652. Note that if this function
+/// returns true, unswitching is unsafe, but a false result does not
+/// necessarily mean it is safe.
+static bool EqualityPropUnSafe(Value &LoopCond) {
+  ICmpInst *CI = dyn_cast<ICmpInst>(&LoopCond);
+  if (!CI || !CI->isEquality())
+    return false;
+
+  Value *LHS = CI->getOperand(0);
+  Value *RHS = CI->getOperand(1);
+  if (isa<UndefValue>(LHS) || isa<UndefValue>(RHS))
+    return true;
+
+  auto hasUndefInPHI = [](PHINode &PN) {
+    for (Value *Opd : PN.incoming_values()) {
+      if (isa<UndefValue>(Opd))
+        return true;
+    }
+    return false;
+  };
+  PHINode *LPHI = dyn_cast<PHINode>(LHS);
+  PHINode *RPHI = dyn_cast<PHINode>(RHS);
+  if ((LPHI && hasUndefInPHI(*LPHI)) || (RPHI && hasUndefInPHI(*RPHI)))
+    return true;
+
+  auto hasUndefInSelect = [](SelectInst &SI) {
+    if (isa<UndefValue>(SI.getTrueValue()) ||
+        isa<UndefValue>(SI.getFalseValue()))
+      return true;
+    return false;
+  };
+  SelectInst *LSI = dyn_cast<SelectInst>(LHS);
+  SelectInst *RSI = dyn_cast<SelectInst>(RHS);
+  if ((LSI && hasUndefInSelect(*LSI)) || (RSI && hasUndefInSelect(*RSI)))
+    return true;
+  return false;
+}
+
+/// Do the actual work: unswitch the current loop if it is possible and
+/// profitable.
+bool LoopUnswitch::processCurrentLoop() {
+  bool Changed = false;
+
+  initLoopData();
+
+  // If LoopSimplify was unable to form a preheader, don't do any unswitching.
+  if (!loopPreheader)
+    return false;
+
+  // Loops with indirectbr cannot be cloned.
+  if (!currentLoop->isSafeToClone())
+    return false;
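+  // [Editor's note: an illustrative aside on EqualityPropUnSafe above.] If
+  // the unswitch condition were, say, `%c = icmp eq i32 %x, undef`, the loop
+  // copy taken when %c is "true" would let equality propagation (e.g. GVN)
+  // replace %x with undef, while the branch on %c could itself be folded
+  // either way; the two choices of undef need not agree, which is the
+  // discrepancy PR31652 describes.
+
+  // Without dedicated exits, splitting the exit edge may fail.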
+  if (!currentLoop->hasDedicatedExits())
+    return false;
+
+  LLVMContext &Context = loopHeader->getContext();
+
+  // Analyze the loop cost, and stop unswitching if the loop's contents cannot
+  // be duplicated.
+  if (!BranchesInfo.countLoop(
+          currentLoop, getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
+                           *currentLoop->getHeader()->getParent()),
+          AC))
+    return false;
+
+  // Try trivial unswitching first, before looping over the other basic blocks
+  // in the loop.
+  if (TryTrivialLoopUnswitch(Changed)) {
+    return true;
+  }
+
+  // Run through the instructions in the loop, keeping track of three things:
+  //
+  //  - That we do not unswitch loops containing convergent operations, as we
+  //    might be making them control dependent on the unswitch value when they
+  //    were not before.
+  //    FIXME: This could be refined to only bail if the convergent operation
+  //    is not already control-dependent on the unswitch value.
+  //
+  //  - That basic blocks in the loop contain invokes whose predecessor edges
+  //    we cannot split.
+  //
+  //  - The set of guard intrinsics encountered (these are non-terminator
+  //    instructions that are also profitable to unswitch).
+
+  SmallVector<IntrinsicInst *, 4> Guards;
+
+  for (const auto BB : currentLoop->blocks()) {
+    for (auto &I : *BB) {
+      auto CS = CallSite(&I);
+      if (!CS) continue;
+      if (CS.hasFnAttr(Attribute::Convergent))
+        return false;
+      if (auto *II = dyn_cast<InvokeInst>(&I))
+        if (!II->getUnwindDest()->canSplitPredecessors())
+          return false;
+      if (auto *II = dyn_cast<IntrinsicInst>(&I))
+        if (II->getIntrinsicID() == Intrinsic::experimental_guard)
+          Guards.push_back(II);
+    }
+  }
+
+  // Do not do non-trivial unswitching while optimizing for size.
+  // FIXME: Use Function::optForSize().
+  if (OptimizeForSize ||
+      loopHeader->getParent()->hasFnAttribute(Attribute::OptimizeForSize))
+    return false;
+
+  for (IntrinsicInst *Guard : Guards) {
+    Value *LoopCond =
+        FindLIVLoopCondition(Guard->getOperand(0), currentLoop, Changed).first;
+    if (LoopCond &&
+        UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context))) {
+      // NB! Unswitching (if successful) could have erased some of the
+      // instructions in Guards, leaving dangling pointers there. This is fine
+      // because we're returning now and won't look at Guards again.
+      ++NumGuards;
+      return true;
+    }
+  }
+
+  // Loop over all of the basic blocks in the loop. If we find an interior
+  // block that is branching on a loop-invariant condition, we can unswitch
+  // this loop.
+  for (Loop::block_iterator I = currentLoop->block_begin(),
+       E = currentLoop->block_end(); I != E; ++I) {
+    TerminatorInst *TI = (*I)->getTerminator();
+
+    // Unswitching on a potentially uninitialized predicate is not
+    // MSan-friendly. Limit this to the cases when the original predicate is
+    // guaranteed to execute, to avoid creating a use-of-uninitialized-value
+    // in the code that did not have one.
+    // This is a workaround for the discrepancy between LLVM IR and MSan
+    // semantics. See PR28054 for more details.
+    if (SanitizeMemory &&
+        !isGuaranteedToExecute(*TI, DT, currentLoop, &SafetyInfo))
+      continue;
+
+    if (BranchInst *BI = dyn_cast<BranchInst>(TI)) {
+      // Some branches may be rendered unreachable because of previous
+      // unswitching. Unswitch only those branches that are reachable.
+      if (isUnreachableDueToPreviousUnswitching(*I))
+        continue;
+
+      // If this isn't branching on an invariant condition, we can't unswitch
+      // it.
+      if (BI->isConditional()) {
+        // See if this, or some part of it, is loop invariant. If so, we can
+        // unswitch on it if we desire.
+        Value *LoopCond = FindLIVLoopCondition(BI->getCondition(),
+                                               currentLoop, Changed).first;
+        if (LoopCond && !EqualityPropUnSafe(*LoopCond) &&
+            UnswitchIfProfitable(LoopCond, ConstantInt::getTrue(Context), TI)) {
+          ++NumBranches;
+          return true;
+        }
+      }
+    } else if (SwitchInst *SI = dyn_cast<SwitchInst>(TI)) {
+      Value *SC = SI->getCondition();
+      Value *LoopCond;
+      OperatorChain OpChain;
+      std::tie(LoopCond, OpChain) =
+          FindLIVLoopCondition(SC, currentLoop, Changed);
+
+      unsigned NumCases = SI->getNumCases();
+      if (LoopCond && NumCases) {
+        // Find a value to unswitch on:
+        // FIXME: this should choose the most expensive case!
+        // FIXME: scan for a case with a non-critical edge?
+        Constant *UnswitchVal = nullptr;
+        // Find a case value such that at least one case value is unswitched
+        // out.
+        if (OpChain == OC_OpChainAnd) {
+          // If the chain only has ANDs and the switch has a case value of 0,
+          // dropping a 0 into the chain will unswitch the 0 case value out.
+          auto *AllZero = cast<ConstantInt>(Constant::getNullValue(SC->getType()));
+          if (BranchesInfo.isUnswitched(SI, AllZero))
+            continue;
+          // We are unswitching 0 out.
+          UnswitchVal = AllZero;
+        } else if (OpChain == OC_OpChainOr) {
+          // If the chain only has ORs and the switch has a case value of ~0,
+          // dropping a ~0 into the chain will unswitch the ~0 case value out.
+          auto *AllOne = cast<ConstantInt>(Constant::getAllOnesValue(SC->getType()));
+          if (BranchesInfo.isUnswitched(SI, AllOne))
+            continue;
+          // We are unswitching ~0 out.
+          UnswitchVal = AllOne;
+        } else {
+          assert(OpChain == OC_OpChainNone &&
+                 "Expect to unswitch on trivial chain");
+          // Do not process the same value again and again. At this point we
+          // have some cases already unswitched and some not yet unswitched;
+          // find the first one that is not yet unswitched.
+          for (auto Case : SI->cases()) {
+            Constant *UnswitchValCandidate = Case.getCaseValue();
+            if (!BranchesInfo.isUnswitched(SI, UnswitchValCandidate)) {
+              UnswitchVal = UnswitchValCandidate;
+              break;
+            }
+          }
+        }
+
+        if (!UnswitchVal)
+          continue;
+
+        if (UnswitchIfProfitable(LoopCond, UnswitchVal)) {
+          ++NumSwitches;
+          // In the case of a full LIV, UnswitchVal is the value we unswitched
+          // out. In the case of a partial LIV, we only unswitch when it is an
+          // AND-chain or an OR-chain; in both cases the switch input value
+          // simplifies to UnswitchVal.
+          BranchesInfo.setUnswitched(SI, UnswitchVal);
+          return true;
+        }
+      }
+    }
+
+    // Scan the instructions to check for unswitchable values.
+    for (BasicBlock::iterator BBI = (*I)->begin(), E = (*I)->end();
+         BBI != E; ++BBI)
+      if (SelectInst *SI = dyn_cast<SelectInst>(BBI)) {
+        Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
+                                               currentLoop, Changed).first;
+        if (LoopCond && UnswitchIfProfitable(LoopCond,
+                                             ConstantInt::getTrue(Context))) {
+          ++NumSelects;
+          return true;
+        }
+      }
+  }
+  return Changed;
+}
+
+/// Check to see if all paths from BB exit the loop with no side effects
+/// (including infinite loops).
+///
+/// If true, we return true and set ExitBB to the block we
+/// exit through.
+///
+static bool isTrivialLoopExitBlockHelper(Loop *L, BasicBlock *BB,
+                                         BasicBlock *&ExitBB,
+                                         std::set<BasicBlock*> &Visited) {
+  if (!Visited.insert(BB).second) {
+    // Already visited. Without more analysis, this could indicate an infinite
+    // loop.
+    return false;
+  }
+  if (!L->contains(BB)) {
+    // Otherwise, this is a loop exit, which is fine so long as it is the
+    // first exit.
+ if (ExitBB) return false; + ExitBB = BB; + return true; + } + + // Otherwise, this is an unvisited intra-loop node. Check all successors. + for (succ_iterator SI = succ_begin(BB), E = succ_end(BB); SI != E; ++SI) { + // Check to see if the successor is a trivial loop exit. + if (!isTrivialLoopExitBlockHelper(L, *SI, ExitBB, Visited)) + return false; + } + + // Okay, everything after this looks good, check to make sure that this block + // doesn't include any side effects. + for (Instruction &I : *BB) + if (I.mayHaveSideEffects()) + return false; + + return true; +} + +/// Return true if the specified block unconditionally leads to an exit from +/// the specified loop, and has no side-effects in the process. If so, return +/// the block that is exited to, otherwise return null. +static BasicBlock *isTrivialLoopExitBlock(Loop *L, BasicBlock *BB) { + std::set<BasicBlock*> Visited; + Visited.insert(L->getHeader()); // Branches to header make infinite loops. + BasicBlock *ExitBB = nullptr; + if (isTrivialLoopExitBlockHelper(L, BB, ExitBB, Visited)) + return ExitBB; + return nullptr; +} + +/// We have found that we can unswitch currentLoop when LoopCond == Val to +/// simplify the loop. If we decide that this is profitable, +/// unswitch the loop, reprocess the pieces, then return true. +bool LoopUnswitch::UnswitchIfProfitable(Value *LoopCond, Constant *Val, + TerminatorInst *TI) { + // Check to see if it would be profitable to unswitch current loop. + if (!BranchesInfo.CostAllowsUnswitching()) { + DEBUG(dbgs() << "NOT unswitching loop %" + << currentLoop->getHeader()->getName() + << " at non-trivial condition '" << *Val + << "' == " << *LoopCond << "\n" + << ". Cost too high.\n"); + return false; + } + if (hasBranchDivergence && + getAnalysis<DivergenceAnalysis>().isDivergent(LoopCond)) { + DEBUG(dbgs() << "NOT unswitching loop %" + << currentLoop->getHeader()->getName() + << " at non-trivial condition '" << *Val + << "' == " << *LoopCond << "\n" + << ". Condition is divergent.\n"); + return false; + } + + UnswitchNontrivialCondition(LoopCond, Val, currentLoop, TI); + return true; +} + +/// Recursively clone the specified loop and all of its children, +/// mapping the blocks with the specified map. +static Loop *CloneLoop(Loop *L, Loop *PL, ValueToValueMapTy &VM, + LoopInfo *LI, LPPassManager *LPM) { + Loop &New = *LI->AllocateLoop(); + if (PL) + PL->addChildLoop(&New); + else + LI->addTopLevelLoop(&New); + LPM->addLoop(New); + + // Add all of the blocks in L to the new loop. + for (Loop::block_iterator I = L->block_begin(), E = L->block_end(); + I != E; ++I) + if (LI->getLoopFor(*I) == L) + New.addBasicBlockToLoop(cast<BasicBlock>(VM[*I]), *LI); + + // Add all of the subloops to the new loop. + for (Loop *I : *L) + CloneLoop(I, &New, VM, LI, LPM); + + return &New; +} + +/// Emit a conditional branch on two values if LIC == Val, branch to TrueDst, +/// otherwise branch to FalseDest. Insert the code immediately before OldBranch +/// and remove (but not erase!) it from the function. +void LoopUnswitch::EmitPreheaderBranchOnCondition(Value *LIC, Constant *Val, + BasicBlock *TrueDest, + BasicBlock *FalseDest, + BranchInst *OldBranch, + TerminatorInst *TI) { + assert(OldBranch->isUnconditional() && "Preheader is not split correctly"); + // Insert a conditional branch on LIC to the two preheaders. The original + // code is the true version and the new code is the false version. 
+ Value *BranchVal = LIC; + bool Swapped = false; + if (!isa<ConstantInt>(Val) || + Val->getType() != Type::getInt1Ty(LIC->getContext())) + BranchVal = new ICmpInst(OldBranch, ICmpInst::ICMP_EQ, LIC, Val); + else if (Val != ConstantInt::getTrue(Val->getContext())) { + // We want to enter the new loop when the condition is true. + std::swap(TrueDest, FalseDest); + Swapped = true; + } + + // Old branch will be removed, so save its parent and successor to update the + // DomTree. + auto *OldBranchSucc = OldBranch->getSuccessor(0); + auto *OldBranchParent = OldBranch->getParent(); + + // Insert the new branch. + BranchInst *BI = + IRBuilder<>(OldBranch).CreateCondBr(BranchVal, TrueDest, FalseDest, TI); + if (Swapped) + BI->swapProfMetadata(); + + // Remove the old branch so there is only one branch at the end. This is + // needed to perform DomTree's internal DFS walk on the function's CFG. + OldBranch->removeFromParent(); + + // Inform the DT about the new branch. + if (DT) { + // First, add both successors. + SmallVector<DominatorTree::UpdateType, 3> Updates; + if (TrueDest != OldBranchParent) + Updates.push_back({DominatorTree::Insert, OldBranchParent, TrueDest}); + if (FalseDest != OldBranchParent) + Updates.push_back({DominatorTree::Insert, OldBranchParent, FalseDest}); + // If both of the new successors are different from the old one, inform the + // DT that the edge was deleted. + if (OldBranchSucc != TrueDest && OldBranchSucc != FalseDest) { + Updates.push_back({DominatorTree::Delete, OldBranchParent, OldBranchSucc}); + } + + DT->applyUpdates(Updates); + } + + // If either edge is critical, split it. This helps preserve LoopSimplify + // form for enclosing loops. + auto Options = CriticalEdgeSplittingOptions(DT, LI).setPreserveLCSSA(); + SplitCriticalEdge(BI, 0, Options); + SplitCriticalEdge(BI, 1, Options); +} + +/// Given a loop that has a trivial unswitchable condition in it (a cond branch +/// from its header block to its latch block, where the path through the loop +/// that doesn't execute its body has no side-effects), unswitch it. This +/// doesn't involve any code duplication, just moving the conditional branch +/// outside of the loop and updating loop info. +void LoopUnswitch::UnswitchTrivialCondition(Loop *L, Value *Cond, Constant *Val, + BasicBlock *ExitBlock, + TerminatorInst *TI) { + DEBUG(dbgs() << "loop-unswitch: Trivial-Unswitch loop %" + << loopHeader->getName() << " [" << L->getBlocks().size() + << " blocks] in Function " + << L->getHeader()->getParent()->getName() << " on cond: " << *Val + << " == " << *Cond << "\n"); + + // First step, split the preheader, so that we know that there is a safe place + // to insert the conditional branch. We will change loopPreheader to have a + // conditional branch on Cond. + BasicBlock *NewPH = SplitEdge(loopPreheader, loopHeader, DT, LI); + + // Now that we have a place to insert the conditional branch, create a place + // to branch to: this is the exit block out of the loop that we should + // short-circuit to. + + // Split this block now, so that the loop maintains its exit block, and so + // that the jump from the preheader can execute the contents of the exit block + // without actually branching to it (the exit block should be dominated by the + // loop header, not the preheader). 
+ assert(!L->contains(ExitBlock) && "Exit block is in the loop?"); + BasicBlock *NewExit = SplitBlock(ExitBlock, &ExitBlock->front(), DT, LI); + + // Okay, now we have a position to branch from and a position to branch to, + // insert the new conditional branch. + auto *OldBranch = dyn_cast<BranchInst>(loopPreheader->getTerminator()); + assert(OldBranch && "Failed to split the preheader"); + EmitPreheaderBranchOnCondition(Cond, Val, NewExit, NewPH, OldBranch, TI); + LPM->deleteSimpleAnalysisValue(OldBranch, L); + + // EmitPreheaderBranchOnCondition removed the OldBranch from the function. + // Delete it, as it is no longer needed. + delete OldBranch; + + // We need to reprocess this loop, it could be unswitched again. + redoLoop = true; + + // Now that we know that the loop is never entered when this condition is a + // particular value, rewrite the loop with this info. We know that this will + // at least eliminate the old branch. + RewriteLoopBodyWithConditionConstant(L, Cond, Val, false); + ++NumTrivial; +} + +/// Check if the first non-constant condition starting from the loop header is +/// a trivial unswitch condition: that is, a condition controls whether or not +/// the loop does anything at all. If it is a trivial condition, unswitching +/// produces no code duplications (equivalently, it produces a simpler loop and +/// a new empty loop, which gets deleted). Therefore always unswitch trivial +/// condition. +bool LoopUnswitch::TryTrivialLoopUnswitch(bool &Changed) { + BasicBlock *CurrentBB = currentLoop->getHeader(); + TerminatorInst *CurrentTerm = CurrentBB->getTerminator(); + LLVMContext &Context = CurrentBB->getContext(); + + // If loop header has only one reachable successor (currently via an + // unconditional branch or constant foldable conditional branch, but + // should also consider adding constant foldable switch instruction in + // future), we should keep looking for trivial condition candidates in + // the successor as well. An alternative is to constant fold conditions + // and merge successors into loop header (then we only need to check header's + // terminator). The reason for not doing this in LoopUnswitch pass is that + // it could potentially break LoopPassManager's invariants. Folding dead + // branches could either eliminate the current loop or make other loops + // unreachable. LCSSA form might also not be preserved after deleting + // branches. The following code keeps traversing loop header's successors + // until it finds the trivial condition candidate (condition that is not a + // constant). Since unswitching generates branches with constant conditions, + // this scenario could be very common in practice. + SmallSet<BasicBlock*, 8> Visited; + + while (true) { + // If we exit loop or reach a previous visited block, then + // we can not reach any trivial condition candidates (unfoldable + // branch instructions or switch instructions) and no unswitch + // can happen. Exit and return false. + if (!currentLoop->contains(CurrentBB) || !Visited.insert(CurrentBB).second) + return false; + + // Check if this loop will execute any side-effecting instructions (e.g. + // stores, calls, volatile loads) in the part of the loop that the code + // *would* execute. Check the header first. 
+    for (Instruction &I : *CurrentBB)
+      if (I.mayHaveSideEffects())
+        return false;
+
+    if (BranchInst *BI = dyn_cast<BranchInst>(CurrentTerm)) {
+      if (BI->isUnconditional()) {
+        CurrentBB = BI->getSuccessor(0);
+      } else if (BI->getCondition() == ConstantInt::getTrue(Context)) {
+        CurrentBB = BI->getSuccessor(0);
+      } else if (BI->getCondition() == ConstantInt::getFalse(Context)) {
+        CurrentBB = BI->getSuccessor(1);
+      } else {
+        // Found a trivial condition candidate: a non-foldable conditional
+        // branch.
+        break;
+      }
+    } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
+      // At this point, any constant-foldable instructions should probably
+      // have been folded.
+      ConstantInt *Cond = dyn_cast<ConstantInt>(SI->getCondition());
+      if (!Cond)
+        break;
+      // Find the target block we are definitely going to.
+      CurrentBB = SI->findCaseValue(Cond)->getCaseSuccessor();
+    } else {
+      // We do not understand these terminator instructions.
+      break;
+    }
+
+    CurrentTerm = CurrentBB->getTerminator();
+  }
+
+  // CondVal is the value of the condition that makes the unswitch trivial,
+  // and LoopExitBB is the basic block that the loop exits to when the
+  // trivial condition is met.
+  Constant *CondVal = nullptr;
+  BasicBlock *LoopExitBB = nullptr;
+
+  if (BranchInst *BI = dyn_cast<BranchInst>(CurrentTerm)) {
+    // If this isn't branching on an invariant condition, we can't unswitch it.
+    if (!BI->isConditional())
+      return false;
+
+    Value *LoopCond = FindLIVLoopCondition(BI->getCondition(),
+                                           currentLoop, Changed).first;
+
+    // Unswitch only if the trivial condition itself is an LIV (not a partial
+    // LIV, which could occur in an and/or chain).
+    if (!LoopCond || LoopCond != BI->getCondition())
+      return false;
+
+    // Check to see if a successor of the branch is guaranteed to
+    // exit through a unique exit block without having any
+    // side-effects. If so, determine the value of Cond that causes
+    // it to do this.
+    if ((LoopExitBB = isTrivialLoopExitBlock(currentLoop,
+                                             BI->getSuccessor(0)))) {
+      CondVal = ConstantInt::getTrue(Context);
+    } else if ((LoopExitBB = isTrivialLoopExitBlock(currentLoop,
+                                                    BI->getSuccessor(1)))) {
+      CondVal = ConstantInt::getFalse(Context);
+    }
+
+    // If we didn't find a single unique LoopExit block, or if the loop exit
+    // block contains phi nodes, this isn't trivial.
+    if (!LoopExitBB || isa<PHINode>(LoopExitBB->begin()))
+      return false; // Can't handle this.
+
+    if (EqualityPropUnSafe(*LoopCond))
+      return false;
+
+    UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB,
+                             CurrentTerm);
+    ++NumBranches;
+    return true;
+  } else if (SwitchInst *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
+    // If this isn't switching on an invariant condition, we can't unswitch it.
+    Value *LoopCond = FindLIVLoopCondition(SI->getCondition(),
+                                           currentLoop, Changed).first;
+
+    // Unswitch only if the trivial condition itself is an LIV (not a partial
+    // LIV, which could occur in an and/or chain).
+    if (!LoopCond || LoopCond != SI->getCondition())
+      return false;
+
+    // Check to see if a successor of the switch is guaranteed to go to the
+    // latch block or to exit through a unique exit block without having any
+    // side-effects. If so, determine the value of Cond that causes it to do
+    // this.
+    // Note that we can't trivially unswitch on the default case or
+    // on already unswitched cases.
+    for (auto Case : SI->cases()) {
+      BasicBlock *LoopExitCandidate;
+      if ((LoopExitCandidate =
+               isTrivialLoopExitBlock(currentLoop, Case.getCaseSuccessor()))) {
+        // Okay, we found a trivial case; remember the value that makes it
+        // trivial.
+        ConstantInt *CaseVal = Case.getCaseValue();
+
+        // Check that it was not unswitched before, since already-unswitched
+        // trivial values look trivial too.
+        if (BranchesInfo.isUnswitched(SI, CaseVal))
+          continue;
+        LoopExitBB = LoopExitCandidate;
+        CondVal = CaseVal;
+        break;
+      }
+    }
+
+    // If we didn't find a single unique LoopExit block, or if the loop exit
+    // block contains phi nodes, this isn't trivial.
+    if (!LoopExitBB || isa<PHINode>(LoopExitBB->begin()))
+      return false; // Can't handle this.
+
+    UnswitchTrivialCondition(currentLoop, LoopCond, CondVal, LoopExitBB,
+                             nullptr);
+
+    // We are only unswitching full LIVs.
+    BranchesInfo.setUnswitched(SI, CondVal);
+    ++NumSwitches;
+    return true;
+  }
+  return false;
+}
+
+/// Split all of the edges from inside the loop to their exit blocks.
+/// Update the appropriate Phi nodes as we do so.
+void LoopUnswitch::SplitExitEdges(Loop *L,
+                                  const SmallVectorImpl<BasicBlock *> &ExitBlocks){
+
+  for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) {
+    BasicBlock *ExitBlock = ExitBlocks[i];
+    SmallVector<BasicBlock *, 4> Preds(pred_begin(ExitBlock),
+                                       pred_end(ExitBlock));
+
+    // Although SplitBlockPredecessors doesn't preserve loop-simplify in
+    // general, if we call it on all predecessors of all exits then it does.
+    SplitBlockPredecessors(ExitBlock, Preds, ".us-lcssa", DT, LI,
+                           /*PreserveLCSSA*/ true);
+  }
+}
+
+/// We determined that the loop is profitable to unswitch when LIC equals Val.
+/// Split it into loop versions and test the condition outside of either loop.
+/// Return the loops created as Out1/Out2.
+void LoopUnswitch::UnswitchNontrivialCondition(Value *LIC, Constant *Val,
+                                               Loop *L, TerminatorInst *TI) {
+  Function *F = loopHeader->getParent();
+  DEBUG(dbgs() << "loop-unswitch: Unswitching loop %"
+               << loopHeader->getName() << " [" << L->getBlocks().size()
+               << " blocks] in Function " << F->getName()
+               << " when '" << *Val << "' == " << *LIC << "\n");
+
+  if (auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>())
+    SEWP->getSE().forgetLoop(L);
+
+  LoopBlocks.clear();
+  NewBlocks.clear();
+
+  // First step: split the preheader and exit blocks, and add these blocks to
+  // the LoopBlocks list.
+  BasicBlock *NewPreheader = SplitEdge(loopPreheader, loopHeader, DT, LI);
+  LoopBlocks.push_back(NewPreheader);
+
+  // We want the loop to come after the preheader, but before the exit blocks.
+  LoopBlocks.insert(LoopBlocks.end(), L->block_begin(), L->block_end());
+
+  SmallVector<BasicBlock*, 8> ExitBlocks;
+  L->getUniqueExitBlocks(ExitBlocks);
+
+  // Split all of the edges from inside the loop to their exit blocks. Update
+  // the appropriate Phi nodes as we do so.
+  SplitExitEdges(L, ExitBlocks);
+
+  // The exit blocks may have been changed due to edge splitting; recompute.
+  ExitBlocks.clear();
+  L->getUniqueExitBlocks(ExitBlocks);
+
+  // Add the exit blocks to the loop blocks.
+  LoopBlocks.insert(LoopBlocks.end(), ExitBlocks.begin(), ExitBlocks.end());
+
+  // Next step: clone all of the basic blocks that make up the loop (including
+  // the loop preheader and exit blocks), keeping track of the mapping between
+  // the instructions and blocks.
+  NewBlocks.reserve(LoopBlocks.size());
+  ValueToValueMapTy VMap;
+  for (unsigned i = 0, e = LoopBlocks.size(); i != e; ++i) {
+    BasicBlock *NewBB = CloneBasicBlock(LoopBlocks[i], VMap, ".us", F);
+
+    NewBlocks.push_back(NewBB);
+    VMap[LoopBlocks[i]] = NewBB; // Keep the BB mapping.
+    LPM->cloneBasicBlockSimpleAnalysis(LoopBlocks[i], NewBB, L);
+  }
+
+  // Splice the newly inserted blocks into the function right before the
+  // original preheader.
+  F->getBasicBlockList().splice(NewPreheader->getIterator(),
+                                F->getBasicBlockList(),
+                                NewBlocks[0]->getIterator(), F->end());
+
+  // Now we create the new Loop object for the versioned loop.
+  Loop *NewLoop = CloneLoop(L, L->getParentLoop(), VMap, LI, LPM);
+
+  // Recalculate the unswitching quota and inherit simplified-switch info for
+  // NewLoop; possibly clone more loop-unswitch related loop properties.
+  BranchesInfo.cloneData(NewLoop, L, VMap);
+
+  Loop *ParentLoop = L->getParentLoop();
+  if (ParentLoop) {
+    // Make sure to add the cloned preheader and exit blocks to the parent loop
+    // as well.
+    ParentLoop->addBasicBlockToLoop(NewBlocks[0], *LI);
+  }
+
+  for (unsigned i = 0, e = ExitBlocks.size(); i != e; ++i) {
+    BasicBlock *NewExit = cast<BasicBlock>(VMap[ExitBlocks[i]]);
+    // The new exit block should be in the same loop as the old one.
+    if (Loop *ExitBBLoop = LI->getLoopFor(ExitBlocks[i]))
+      ExitBBLoop->addBasicBlockToLoop(NewExit, *LI);
+
+    assert(NewExit->getTerminator()->getNumSuccessors() == 1 &&
+           "Exit block should have been split to have one successor!");
+    BasicBlock *ExitSucc = NewExit->getTerminator()->getSuccessor(0);
+
+    // If the successor of the exit block had PHI nodes, add an entry for
+    // NewExit.
+    for (PHINode &PN : ExitSucc->phis()) {
+      Value *V = PN.getIncomingValueForBlock(ExitBlocks[i]);
+      ValueToValueMapTy::iterator It = VMap.find(V);
+      if (It != VMap.end()) V = It->second;
+      PN.addIncoming(V, NewExit);
+    }
+
+    if (LandingPadInst *LPad = NewExit->getLandingPadInst()) {
+      PHINode *PN = PHINode::Create(LPad->getType(), 0, "",
+                                    &*ExitSucc->getFirstInsertionPt());
+
+      for (pred_iterator I = pred_begin(ExitSucc), E = pred_end(ExitSucc);
+           I != E; ++I) {
+        BasicBlock *BB = *I;
+        LandingPadInst *LPI = BB->getLandingPadInst();
+        LPI->replaceAllUsesWith(PN);
+        PN->addIncoming(LPI, BB);
+      }
+    }
+  }
+
+  // Rewrite the code to refer to itself.
+  for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i) {
+    for (Instruction &I : *NewBlocks[i]) {
+      RemapInstruction(&I, VMap,
+                       RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+      if (auto *II = dyn_cast<IntrinsicInst>(&I))
+        if (II->getIntrinsicID() == Intrinsic::assume)
+          AC->registerAssumption(II);
+    }
+  }
+
+  // Rewrite the original preheader to select between versions of the loop.
+  BranchInst *OldBR = cast<BranchInst>(loopPreheader->getTerminator());
+  assert(OldBR->isUnconditional() && OldBR->getSuccessor(0) == LoopBlocks[0] &&
+         "Preheader splitting did not work correctly!");
+
+  // Emit the new branch that selects between the two versions of this loop.
+  EmitPreheaderBranchOnCondition(LIC, Val, NewBlocks[0], LoopBlocks[0], OldBR,
+                                 TI);
+  LPM->deleteSimpleAnalysisValue(OldBR, L);
+
+  // OldBR was replaced by a new one and removed (but not erased) by
+  // EmitPreheaderBranchOnCondition. It is no longer needed, so delete it.
+  delete OldBR;
+
+  LoopProcessWorklist.push_back(NewLoop);
+  redoLoop = true;
+
+  // Keep a WeakTrackingVH holding onto LIC. If the first call to
+  // RewriteLoopBodyWithConditionConstant deletes the instruction (for example
+  // by simplifying a PHI that feeds into the condition that we're unswitching
+  // on), we don't rewrite the second iteration.
+  WeakTrackingVH LICHandle(LIC);
+
+  // Now we rewrite the original code to know that the condition is false and
+  // the new code to know that the condition is true, since the preheader
+  // branch enters the cloned loop when LIC == Val.
+  RewriteLoopBodyWithConditionConstant(L, LIC, Val, false);
+
+  // It's possible that simplifying one loop could cause the condition of the
+  // other to change to another value or a constant. If it's a constant, don't
+  // simplify it.
+  if (!LoopProcessWorklist.empty() && LoopProcessWorklist.back() == NewLoop &&
+      LICHandle && !isa<Constant>(LICHandle))
+    RewriteLoopBodyWithConditionConstant(NewLoop, LICHandle, Val, true);
+}
+
+/// Remove all instances of I from the worklist vector specified.
+static void RemoveFromWorklist(Instruction *I,
+                               std::vector<Instruction*> &Worklist) {
+
+  Worklist.erase(std::remove(Worklist.begin(), Worklist.end(), I),
+                 Worklist.end());
+}
+
+/// When we find that I really equals V, remove I from the
+/// program, replacing all uses with V, and update the worklist.
+static void ReplaceUsesOfWith(Instruction *I, Value *V,
+                              std::vector<Instruction*> &Worklist,
+                              Loop *L, LPPassManager *LPM) {
+  DEBUG(dbgs() << "Replace with '" << *V << "': " << *I << "\n");
+
+  // Add operands to the worklist; they may be dead now.
+  for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
+    if (Instruction *Use = dyn_cast<Instruction>(I->getOperand(i)))
+      Worklist.push_back(Use);
+
+  // Add users to the worklist; they may be simplified now.
+  for (User *U : I->users())
+    Worklist.push_back(cast<Instruction>(U));
+  LPM->deleteSimpleAnalysisValue(I, L);
+  RemoveFromWorklist(I, Worklist);
+  I->replaceAllUsesWith(V);
+  if (!I->mayHaveSideEffects())
+    I->eraseFromParent();
+  ++NumSimplify;
+}
+
+/// We know either that the value LIC has the value specified by Val in the
+/// specified loop, or we know it does NOT have that value.
+/// Rewrite any uses of LIC or of properties correlated to it.
+void LoopUnswitch::RewriteLoopBodyWithConditionConstant(Loop *L, Value *LIC,
+                                                        Constant *Val,
+                                                        bool IsEqual) {
+  assert(!isa<Constant>(LIC) && "Why are we unswitching on a constant?");
+
+  // FIXME: Support correlated properties, like:
+  //  for (...)
+  //    if (li1 < li2)
+  //      ...
+  //    if (li1 > li2)
+  //      ...
+
+  // Fold boolean conditions (X|LIC), (X&LIC). Fold conditional branches,
+  // selects, switches.
+  std::vector<Instruction*> Worklist;
+  LLVMContext &Context = Val->getContext();
+
+  // If we know that LIC == Val, or that LIC == NotVal, just replace uses of LIC
+  // in the loop with the appropriate one directly.
+  if (IsEqual || (isa<ConstantInt>(Val) &&
+      Val->getType()->isIntegerTy(1))) {
+    Value *Replacement;
+    if (IsEqual)
+      Replacement = Val;
+    else
+      Replacement = ConstantInt::get(Type::getInt1Ty(Val->getContext()),
+                                     !cast<ConstantInt>(Val)->getZExtValue());
+
+    for (User *U : LIC->users()) {
+      Instruction *UI = dyn_cast<Instruction>(U);
+      if (!UI || !L->contains(UI))
+        continue;
+      Worklist.push_back(UI);
+    }
+
+    for (Instruction *UI : Worklist)
+      UI->replaceUsesOfWith(LIC, Replacement);
+
+    SimplifyCode(Worklist, L);
+    return;
+  }
+
+  // Otherwise, we don't know the precise value of LIC, but we do know that it
+  // is certainly NOT "Val". As such, simplify any uses in the loop that we
+  // can. This case occurs when we unswitch switch statements.
+  for (User *U : LIC->users()) {
+    Instruction *UI = dyn_cast<Instruction>(U);
+    if (!UI || !L->contains(UI))
+      continue;
+
+    // At this point, we know LIC is definitely not Val. Try to use some simple
+    // logic to simplify the user w.r.t. the context.
+    if (Value *Replacement = SimplifyInstructionWithNotEqual(UI, LIC, Val)) {
+      if (LI->replacementPreservesLCSSAForm(UI, Replacement)) {
+        // This in-loop instruction has been simplified w.r.t.
its context,
+        // i.e. LIC != Val, make sure we propagate its replacement value to
+        // all its users.
+        //
+        // We cannot delete UI, the LIC user, yet, because that would
+        // invalidate the LIC->users() iterator. However, we can make this
+        // instruction dead by replacing all its users, and push it onto the
+        // worklist so that it can be properly deleted and its operands
+        // simplified.
+        UI->replaceAllUsesWith(Replacement);
+      }
+    }
+
+    // This is a LIC user, push it into the worklist so that SimplifyCode can
+    // attempt to simplify it.
+    Worklist.push_back(UI);
+
+    // If we know that LIC is not Val, use this info to simplify code.
+    SwitchInst *SI = dyn_cast<SwitchInst>(UI);
+    if (!SI || !isa<ConstantInt>(Val)) continue;
+
+    // NOTE: if a case value for the switch is unswitched out, we record it
+    // after the unswitch finishes. We cannot record it here as the switch
+    // is not a direct user of the partial LIV.
+    SwitchInst::CaseHandle DeadCase =
+        *SI->findCaseValue(cast<ConstantInt>(Val));
+    // Default case is live for multiple values.
+    if (DeadCase == *SI->case_default())
+      continue;
+
+    // Found a dead case value. Don't remove PHI nodes in the
+    // successor if they become single-entry; those PHI nodes may
+    // be in the Users list.
+
+    BasicBlock *Switch = SI->getParent();
+    BasicBlock *SISucc = DeadCase.getCaseSuccessor();
+    BasicBlock *Latch = L->getLoopLatch();
+
+    if (!SI->findCaseDest(SISucc)) continue; // Edge is critical.
+    // If the DeadCase successor dominates the loop latch, then the
+    // transformation isn't safe since it will delete the sole predecessor edge
+    // to the latch.
+    if (Latch && DT->dominates(SISucc, Latch))
+      continue;
+
+    // FIXME: This is a hack. We need to keep the successor around
+    // and hooked up so as to preserve the loop structure, because
+    // trying to update it is complicated. So instead we preserve the
+    // loop structure and put the block on a dead code path.
+    SplitEdge(Switch, SISucc, DT, LI);
+    // Compute the successors instead of relying on the return value
+    // of SplitEdge, since it may have split the switch successor
+    // after PHI nodes.
+    BasicBlock *NewSISucc = DeadCase.getCaseSuccessor();
+    BasicBlock *OldSISucc = *succ_begin(NewSISucc);
+    // Create an "unreachable" destination.
+    BasicBlock *Abort = BasicBlock::Create(Context, "us-unreachable",
+                                           Switch->getParent(),
+                                           OldSISucc);
+    new UnreachableInst(Context, Abort);
+    // Force the new case destination to branch to the "unreachable"
+    // block while maintaining a (dead) CFG edge to the old block.
+    NewSISucc->getTerminator()->eraseFromParent();
+    BranchInst::Create(Abort, OldSISucc,
+                       ConstantInt::getTrue(Context), NewSISucc);
+    // Release the PHI operands for this edge.
+    for (PHINode &PN : NewSISucc->phis())
+      PN.setIncomingValue(PN.getBasicBlockIndex(Switch),
+                          UndefValue::get(PN.getType()));
+    // Tell the domtree about the new block. We don't fully update the
+    // domtree here -- instead we force it to do a full recomputation
+    // after the pass is complete -- but we do need to inform it of
+    // new blocks.
+    DT->addNewBlock(Abort, NewSISucc);
+  }
+
+  SimplifyCode(Worklist, L);
+}
+
+/// Now that we have simplified some instructions in the loop, walk over it and
+/// constant prop, dce, and fold control flow where possible. Note that this is
+/// effectively a very simple loop-structure-aware optimizer. During processing
+/// of this loop, L could very well be deleted, so it must not be used.
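+///
+/// As an illustration (a sketch, not actual pass output), unswitching tends to
+/// leave IR such as
+///   %s = select i1 false, i32 %x, i32 %y
+/// which instruction simplification reduces to %y, and unconditional branches
+/// to single-predecessor blocks, which are folded here by merging the blocks.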
+///
+/// FIXME: When the loop optimizer is more mature, separate this out to a new
+/// pass.
+///
+void LoopUnswitch::SimplifyCode(std::vector<Instruction*> &Worklist, Loop *L) {
+  const DataLayout &DL = L->getHeader()->getModule()->getDataLayout();
+  while (!Worklist.empty()) {
+    Instruction *I = Worklist.back();
+    Worklist.pop_back();
+
+    // Simple DCE.
+    if (isInstructionTriviallyDead(I)) {
+      DEBUG(dbgs() << "Remove dead instruction '" << *I << "'\n");
+
+      // Add operands to the worklist; they may be dead now.
+      for (unsigned i = 0, e = I->getNumOperands(); i != e; ++i)
+        if (Instruction *Use = dyn_cast<Instruction>(I->getOperand(i)))
+          Worklist.push_back(Use);
+      LPM->deleteSimpleAnalysisValue(I, L);
+      RemoveFromWorklist(I, Worklist);
+      I->eraseFromParent();
+      ++NumSimplify;
+      continue;
+    }
+
+    // See if instruction simplification can hack this up. This is common for
+    // things like "select false, X, Y" after unswitching made the condition be
+    // 'false'. TODO: update the domtree properly so we can pass it here.
+    if (Value *V = SimplifyInstruction(I, DL))
+      if (LI->replacementPreservesLCSSAForm(I, V)) {
+        ReplaceUsesOfWith(I, V, Worklist, L, LPM);
+        continue;
+      }
+
+    // Special case hacks that appear commonly in unswitched code.
+    if (BranchInst *BI = dyn_cast<BranchInst>(I)) {
+      if (BI->isUnconditional()) {
+        // If BI's parent is the only pred of the successor, fold the two blocks
+        // together.
+        BasicBlock *Pred = BI->getParent();
+        BasicBlock *Succ = BI->getSuccessor(0);
+        BasicBlock *SinglePred = Succ->getSinglePredecessor();
+        if (!SinglePred) continue; // Nothing to do.
+        assert(SinglePred == Pred && "CFG broken");
+
+        DEBUG(dbgs() << "Merging blocks: " << Pred->getName() << " <- "
+              << Succ->getName() << "\n");
+
+        // Resolve any single entry PHI nodes in Succ.
+        while (PHINode *PN = dyn_cast<PHINode>(Succ->begin()))
+          ReplaceUsesOfWith(PN, PN->getIncomingValue(0), Worklist, L, LPM);
+
+        // If Succ has any successors with PHI nodes, update them to have
+        // entries coming from Pred instead of Succ.
+        Succ->replaceAllUsesWith(Pred);
+
+        // Move all of the successor contents from Succ to Pred.
+        Pred->getInstList().splice(BI->getIterator(), Succ->getInstList(),
+                                   Succ->begin(), Succ->end());
+        LPM->deleteSimpleAnalysisValue(BI, L);
+        RemoveFromWorklist(BI, Worklist);
+        BI->eraseFromParent();
+
+        // Remove Succ from the loop tree.
+        LI->removeBlock(Succ);
+        LPM->deleteSimpleAnalysisValue(Succ, L);
+        Succ->eraseFromParent();
+        ++NumSimplify;
+        continue;
+      }
+
+      continue;
+    }
+  }
+}
+
+/// Simple simplifications we can do given the information that Cond is
+/// definitely not equal to Val.
+Value *LoopUnswitch::SimplifyInstructionWithNotEqual(Instruction *Inst,
+                                                     Value *Invariant,
+                                                     Constant *Val) {
+  // icmp eq cond, val -> false
+  ICmpInst *CI = dyn_cast<ICmpInst>(Inst);
+  if (CI && CI->isEquality()) {
+    Value *Op0 = CI->getOperand(0);
+    Value *Op1 = CI->getOperand(1);
+    if ((Op0 == Invariant && Op1 == Val) || (Op0 == Val && Op1 == Invariant)) {
+      LLVMContext &Ctx = Inst->getContext();
+      if (CI->getPredicate() == CmpInst::ICMP_EQ)
+        return ConstantInt::getFalse(Ctx);
+      else
+        return ConstantInt::getTrue(Ctx);
+    }
+  }
+
+  // FIXME: there may be other opportunities, e.g. comparison with floating
+  // point, or Invariant - Val != 0, etc.
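+  // A worked sketch of the case that is handled above, assuming the loop was
+  // unswitched on Val == 5 so that Invariant != 5 holds in this version:
+  //   %c = icmp eq i32 %inv, 5   ; known false here
+  //   %c = icmp ne i32 %inv, 5   ; known true here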
+  return nullptr;
+}
diff --git a/contrib/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp b/contrib/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp
new file mode 100644
index 000000000000..53b25e688e82
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/LoopVersioningLICM.cpp
@@ -0,0 +1,598 @@
+//===- LoopVersioningLICM.cpp - LICM Loop Versioning ----------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// When alias analysis is uncertain about the aliasing between any two accesses,
+// it will return MayAlias. This uncertainty from alias analysis restricts LICM
+// from proceeding further. In cases where alias analysis is uncertain we might
+// use loop versioning as an alternative.
+//
+// Loop versioning will create a version of the loop with aggressive aliasing
+// assumptions in addition to the original with conservative (default) aliasing
+// assumptions. The version of the loop making aggressive aliasing assumptions
+// will have all the memory accesses marked as no-alias. These two versions of
+// the loop will be preceded by a memory runtime check. This runtime check
+// consists of bounds checks for all unique memory accesses in the loop, and it
+// ensures the lack of memory aliasing. The result of the runtime check
+// determines which of the loop versions is executed: If the runtime check
+// detects any memory aliasing, then the original loop is executed. Otherwise,
+// the version with aggressive aliasing assumptions is used.
+//
+// Following are the top level steps:
+//
+// a) Perform LoopVersioningLICM's feasibility check.
+// b) If the loop is a candidate for versioning, create a memory bounds check
+//    by considering all the memory accesses in the loop body.
+// c) Clone the original loop and set all memory accesses as no-alias in the
+//    new loop.
+// d) Set the original loop and the versioned loop as branch targets of the
+//    runtime check result.
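+//
+// As a rough source-level sketch (names are illustrative), given
+//
+//   for (i = 0; i < n; i++)
+//     a[i] = b[i] + *invariant_ptr;   // *invariant_ptr may alias a/b
+//
+// the versioned loop runs under a proven no-alias assumption, so LICM can
+// hoist the load of *invariant_ptr out of it, while the original loop remains
+// the fallback whenever the runtime check detects aliasing.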
+//
+// It transforms the loop as shown below:
+//
+//                         +----------------+
+//                         |Runtime Memcheck|
+//                         +----------------+
+//                                 |
+//              +----------+----------------+----------+
+//              |                                      |
+//    +---------+----------+               +-----------+----------+
+//    |Orig Loop Preheader |               |Cloned Loop Preheader |
+//    +--------------------+               +----------------------+
+//              |                                      |
+//    +--------------------+               +----------------------+
+//    |Orig Loop Body      |               |Cloned Loop Body      |
+//    +--------------------+               +----------------------+
+//              |                                      |
+//    +--------------------+               +----------------------+
+//    |Orig Loop Exit Block|               |Cloned Loop Exit Block|
+//    +--------------------+               +-----------+----------+
+//              |                                      |
+//              +----------+--------------+-----------+
+//                                 |
+//                            +-----+----+
+//                            |Join Block|
+//                            +----------+
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AliasSetTracker.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/LoopAccessAnalysis.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/LoopPass.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+#include "llvm/Transforms/Utils/LoopVersioning.h"
+#include <cassert>
+#include <memory>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "loop-versioning-licm"
+
+static const char *LICMVersioningMetaData = "llvm.loop.licm_versioning.disable";
+
+/// Threshold for the minimum allowed percentage of possible
+/// invariant instructions in a loop.
+static cl::opt<float>
+    LVInvarThreshold("licm-versioning-invariant-threshold",
+                     cl::desc("LoopVersioningLICM's minimum allowed percentage "
+                              "of possible invariant instructions per loop"),
+                     cl::init(25), cl::Hidden);
+
+/// Threshold for maximum allowed loop nest/depth
+static cl::opt<unsigned> LVLoopDepthThreshold(
+    "licm-versioning-max-depth-threshold",
+    cl::desc(
+        "LoopVersioningLICM's threshold for maximum allowed loop nest/depth"),
+    cl::init(2), cl::Hidden);
+
+/// \brief Create MDNode for input string.
+static MDNode *createStringMetadata(Loop *TheLoop, StringRef Name, unsigned V) {
+  LLVMContext &Context = TheLoop->getHeader()->getContext();
+  Metadata *MDs[] = {
+      MDString::get(Context, Name),
+      ConstantAsMetadata::get(ConstantInt::get(Type::getInt32Ty(Context), V))};
+  return MDNode::get(Context, MDs);
+}
+
+/// \brief Set input string into loop metadata by keeping other values intact.
+void llvm::addStringMetadataToLoop(Loop *TheLoop, const char *MDString,
+                                   unsigned V) {
+  SmallVector<Metadata *, 4> MDs(1);
+  // If the loop already has metadata, retain it.
+  MDNode *LoopID = TheLoop->getLoopID();
+  if (LoopID) {
+    for (unsigned i = 1, ie = LoopID->getNumOperands(); i < ie; ++i) {
+      MDNode *Node = cast<MDNode>(LoopID->getOperand(i));
+      MDs.push_back(Node);
+    }
+  }
+  // Add new metadata.
+  MDs.push_back(createStringMetadata(TheLoop, MDString, V));
+  // Replace the current metadata node with the new one.
+  LLVMContext &Context = TheLoop->getHeader()->getContext();
+  MDNode *NewLoopID = MDNode::get(Context, MDs);
+  // Set operand 0 to refer to the loop id itself.
+  NewLoopID->replaceOperandWith(0, NewLoopID);
+  TheLoop->setLoopID(NewLoopID);
+}
+
+namespace {
+
+struct LoopVersioningLICM : public LoopPass {
+  static char ID;
+
+  LoopVersioningLICM()
+      : LoopPass(ID), LoopDepthThreshold(LVLoopDepthThreshold),
+        InvariantThreshold(LVInvarThreshold) {
+    initializeLoopVersioningLICMPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnLoop(Loop *L, LPPassManager &LPM) override;
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    AU.addRequired<AAResultsWrapperPass>();
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequiredID(LCSSAID);
+    AU.addRequired<LoopAccessLegacyAnalysis>();
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.addRequiredID(LoopSimplifyID);
+    AU.addRequired<ScalarEvolutionWrapperPass>();
+    AU.addPreserved<AAResultsWrapperPass>();
+    AU.addPreserved<GlobalsAAWrapperPass>();
+  }
+
+  StringRef getPassName() const override { return "Loop Versioning for LICM"; }
+
+  void reset() {
+    AA = nullptr;
+    SE = nullptr;
+    LAA = nullptr;
+    CurLoop = nullptr;
+    LoadAndStoreCounter = 0;
+    InvariantCounter = 0;
+    IsReadOnlyLoop = true;
+    CurAST.reset();
+  }
+
+  class AutoResetter {
+  public:
+    AutoResetter(LoopVersioningLICM &LVLICM) : LVLICM(LVLICM) {}
+    ~AutoResetter() { LVLICM.reset(); }
+
+  private:
+    LoopVersioningLICM &LVLICM;
+  };
+
+private:
+  // Current AliasAnalysis information
+  AliasAnalysis *AA = nullptr;
+
+  // Current ScalarEvolution
+  ScalarEvolution *SE = nullptr;
+
+  // Current LoopAccessAnalysis
+  LoopAccessLegacyAnalysis *LAA = nullptr;
+
+  // Current Loop's LoopAccessInfo
+  const LoopAccessInfo *LAI = nullptr;
+
+  // The current loop we are working on.
+  Loop *CurLoop = nullptr;
+
+  // AliasSet information for the current loop.
+  std::unique_ptr<AliasSetTracker> CurAST;
+
+  // Maximum loop nest threshold
+  unsigned LoopDepthThreshold;
+
+  // Minimum invariant threshold
+  float InvariantThreshold;
+
+  // Counter to track the number of loads and stores.
+  unsigned LoadAndStoreCounter = 0;
+
+  // Counter to track the number of invariant accesses.
+  unsigned InvariantCounter = 0;
+
+  // Read only loop marker.
+  bool IsReadOnlyLoop = true;
+
+  bool isLegalForVersioning();
+  bool legalLoopStructure();
+  bool legalLoopInstructions();
+  bool legalLoopMemoryAccesses();
+  bool isLoopAlreadyVisited();
+  void setNoAliasToLoop(Loop *VerLoop);
+  bool instructionSafeForVersioning(Instruction *I);
+};
+
+} // end anonymous namespace
+
+/// \brief Check the loop structure and confirm it's good for LoopVersioningLICM.
+bool LoopVersioningLICM::legalLoopStructure() {
+  // Loop must be in loop simplify form.
+  if (!CurLoop->isLoopSimplifyForm()) {
+    DEBUG(dbgs() << "    loop is not in loop-simplify form.\n");
+    return false;
+  }
+  // Loop should be an innermost loop; if not, return false.
+  if (!CurLoop->getSubLoops().empty()) {
+    DEBUG(dbgs() << "    loop is not innermost\n");
+    return false;
+  }
+  // Loop should have a single backedge; if not, return false.
+  if (CurLoop->getNumBackEdges() != 1) {
+    DEBUG(dbgs() << "    loop has multiple backedges\n");
+    return false;
+  }
+  // Loop must have a single exiting block; if not, return false.
+  if (!CurLoop->getExitingBlock()) {
+    DEBUG(dbgs() << "    loop has multiple exiting blocks\n");
+    return false;
+  }
+  // We only handle bottom-tested loops, i.e. loops in which the condition is
+  // checked at the end of each iteration.
With that we can assume that all
+  // instructions in the loop are executed the same number of times.
+  if (CurLoop->getExitingBlock() != CurLoop->getLoopLatch()) {
+    DEBUG(dbgs() << "    loop is not bottom tested\n");
+    return false;
+  }
+  // Parallel loops must not have aliasing loop-invariant memory accesses.
+  // Hence we don't need to version anything in this case.
+  if (CurLoop->isAnnotatedParallel()) {
+    DEBUG(dbgs() << "    Parallel loop is not worth versioning\n");
+    return false;
+  }
+  // Loop depths greater than LoopDepthThreshold are not allowed.
+  if (CurLoop->getLoopDepth() > LoopDepthThreshold) {
+    DEBUG(dbgs() << "    loop depth is greater than threshold\n");
+    return false;
+  }
+  // We need to be able to compute the loop trip count in order
+  // to generate the bounds checks.
+  const SCEV *ExitCount = SE->getBackedgeTakenCount(CurLoop);
+  if (ExitCount == SE->getCouldNotCompute()) {
+    DEBUG(dbgs() << "    loop does not have a computable trip count\n");
+    return false;
+  }
+  return true;
+}
+
+/// \brief Check the memory accesses in the loop and confirm they are good for
+/// LoopVersioningLICM.
+bool LoopVersioningLICM::legalLoopMemoryAccesses() {
+  bool HasMayAlias = false;
+  bool TypeSafety = false;
+  bool HasMod = false;
+  // Memory check:
+  // The transform phase will generate a versioned loop and also a runtime
+  // check to ensure the pointers are independent and don't alias.
+  // In the versioned variant of the loop, alias metadata asserts that all
+  // accesses are mutually independent.
+  //
+  // Aliasing pointers within an alias domain are avoided because, with
+  // multiple aliasing domains, we may not be able to hoist a potentially
+  // loop-invariant access out of the loop.
+  //
+  // Iterate over the alias tracker sets, and confirm there is no
+  // must-alias set.
+  for (const auto &I : *CurAST) {
+    const AliasSet &AS = I;
+    // Skip Forward Alias Sets, as this should be ignored as part of
+    // the AliasSetTracker object.
+    if (AS.isForwardingAliasSet())
+      continue;
+    // With MustAlias it's not worth adding a runtime bounds check.
+    if (AS.isMustAlias())
+      return false;
+    Value *SomePtr = AS.begin()->getValue();
+    bool TypeCheck = true;
+    // Check for Mod & MayAlias
+    HasMayAlias |= AS.isMayAlias();
+    HasMod |= AS.isMod();
+    for (const auto &A : AS) {
+      Value *Ptr = A.getValue();
+      // The alias set should have pointers of the same data type.
+      TypeCheck = (TypeCheck && (SomePtr->getType() == Ptr->getType()));
+    }
+    // At least one alias set should have pointers of the same data type.
+    TypeSafety |= TypeCheck;
+  }
+  // At least one alias set must have had pointers of the same type.
+  if (!TypeSafety) {
+    DEBUG(dbgs() << "    Alias tracker type safety failed!\n");
+    return false;
+  }
+  // The loop body must modify memory; reject read-only loops.
+  if (!HasMod) {
+    DEBUG(dbgs() << "    No memory modified in loop body\n");
+    return false;
+  }
+  // Make sure some alias set has a may-alias case.
+  // If there is no memory-aliasing ambiguity, return false.
+  if (!HasMayAlias) {
+    DEBUG(dbgs() << "    No ambiguity in memory access.\n");
+    return false;
+  }
+  return true;
+}
+
+/// \brief Check whether a loop instruction is safe for loop versioning.
+/// Returns true if it is safe, false otherwise.
+/// The following are checked:
+/// 1) All loads and stores in the loop body are non-atomic and non-volatile.
+/// 2) Function calls are safe, i.e. they do not access memory.
+/// 3) The loop body has no may-throw instructions.
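+/// For instance (illustrative IR, not from a real run):
+///   %v = load atomic i32, i32* %p seq_cst, align 4 ; rejected: not simple
+///   call void @ext()                               ; rejected: may access memory
+///   %w = load i32, i32* %q, align 4                ; accepted: simple load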
+bool LoopVersioningLICM::instructionSafeForVersioning(Instruction *I) {
+  assert(I != nullptr && "Null instruction found!");
+  // Check function call safety.
+  if (isa<CallInst>(I) && !AA->doesNotAccessMemory(CallSite(I))) {
+    DEBUG(dbgs() << "    Unsafe call site found.\n");
+    return false;
+  }
+  // Avoid loops with any possibility of throwing.
+  if (I->mayThrow()) {
+    DEBUG(dbgs() << "    May throw instruction found in loop body\n");
+    return false;
+  }
+  // If the current instruction is a load,
+  // make sure it's a simple load (non-atomic and non-volatile).
+  if (I->mayReadFromMemory()) {
+    LoadInst *Ld = dyn_cast<LoadInst>(I);
+    if (!Ld || !Ld->isSimple()) {
+      DEBUG(dbgs() << "    Found a non-simple load.\n");
+      return false;
+    }
+    LoadAndStoreCounter++;
+    Value *Ptr = Ld->getPointerOperand();
+    // Check loop invariance.
+    if (SE->isLoopInvariant(SE->getSCEV(Ptr), CurLoop))
+      InvariantCounter++;
+  }
+  // If the current instruction is a store,
+  // make sure it's a simple store (non-atomic and non-volatile).
+  else if (I->mayWriteToMemory()) {
+    StoreInst *St = dyn_cast<StoreInst>(I);
+    if (!St || !St->isSimple()) {
+      DEBUG(dbgs() << "    Found a non-simple store.\n");
+      return false;
+    }
+    LoadAndStoreCounter++;
+    Value *Ptr = St->getPointerOperand();
+    // Check loop invariance.
+    if (SE->isLoopInvariant(SE->getSCEV(Ptr), CurLoop))
+      InvariantCounter++;
+
+    IsReadOnlyLoop = false;
+  }
+  return true;
+}
+
+/// \brief Check the loop instructions and confirm they are good for
+/// LoopVersioningLICM.
+bool LoopVersioningLICM::legalLoopInstructions() {
+  // Reset the counters.
+  LoadAndStoreCounter = 0;
+  InvariantCounter = 0;
+  IsReadOnlyLoop = true;
+  // Iterate over the loop blocks and the instructions of each block, and
+  // check instruction safety.
+  for (auto *Block : CurLoop->getBlocks())
+    for (auto &Inst : *Block) {
+      // If an instruction is unsafe just return false.
+      if (!instructionSafeForVersioning(&Inst))
+        return false;
+    }
+  // Get LoopAccessInfo from the current loop.
+  LAI = &LAA->getInfo(CurLoop);
+  // Check LoopAccessInfo for the need of a runtime check.
+  if (LAI->getRuntimePointerChecking()->getChecks().empty()) {
+    DEBUG(dbgs() << "    LAA: Runtime check not found !!\n");
+    return false;
+  }
+  // The number of runtime checks should be less than
+  // RuntimeMemoryCheckThreshold.
+  if (LAI->getNumRuntimePointerChecks() >
+      VectorizerParams::RuntimeMemoryCheckThreshold) {
+    DEBUG(dbgs() << "    LAA: Runtime checks are more than threshold !!\n");
+    return false;
+  }
+  // The loop should have at least one invariant load or store instruction.
+  if (!InvariantCounter) {
+    DEBUG(dbgs() << "    Invariant not found !!\n");
+    return false;
+  }
+  // Read-only loops are not allowed.
+  if (IsReadOnlyLoop) {
+    DEBUG(dbgs() << "    Found a read-only loop!\n");
+    return false;
+  }
+  // Profitability check:
+  // Check the invariant threshold; it should be within the limit.
+  if (InvariantCounter * 100 < InvariantThreshold * LoadAndStoreCounter) {
+    DEBUG(dbgs()
+          << "    Invariant loads & stores are less than defined threshold\n");
+    DEBUG(dbgs() << "    Invariant loads & stores: "
+                 << ((InvariantCounter * 100) / LoadAndStoreCounter) << "%\n");
+    DEBUG(dbgs() << "    Invariant loads & stores threshold: "
+                 << InvariantThreshold << "%\n");
+    return false;
+  }
+  return true;
+}
+
+/// \brief Check whether the loop was already visited by inspecting its
+/// metadata; returns true if it was, false otherwise.
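+/// A loop that was already versioned carries metadata of roughly this shape
+/// (illustrative):
+///   br i1 %cond, label %header, label %exit, !llvm.loop !0
+///   !0 = distinct !{!0, !1}
+///   !1 = !{!"llvm.loop.licm_versioning.disable", i32 0}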
+bool LoopVersioningLICM::isLoopAlreadyVisited() {
+  // Check for LoopVersioningLICM metadata on the loop.
+  if (findStringMetadataForLoop(CurLoop, LICMVersioningMetaData)) {
+    return true;
+  }
+  return false;
+}
+
+/// \brief Checks legality for LoopVersioningLICM by considering the following:
+/// a) loop structure legality b) loop instruction legality
+/// c) loop memory access legality.
+/// Returns true if legal, false otherwise.
+bool LoopVersioningLICM::isLegalForVersioning() {
+  DEBUG(dbgs() << "Loop: " << *CurLoop);
+  // Make sure we are not revisiting the same loop.
+  if (isLoopAlreadyVisited()) {
+    DEBUG(
+        dbgs() << "    Revisiting loop in LoopVersioningLICM not allowed.\n\n");
+    return false;
+  }
+  // Check loop structure legality.
+  if (!legalLoopStructure()) {
+    DEBUG(
+        dbgs() << "    Loop structure not suitable for LoopVersioningLICM\n\n");
+    return false;
+  }
+  // Check loop instruction legality.
+  if (!legalLoopInstructions()) {
+    DEBUG(dbgs()
+          << "    Loop instructions not suitable for LoopVersioningLICM\n\n");
+    return false;
+  }
+  // Check loop memory access legality.
+  if (!legalLoopMemoryAccesses()) {
+    DEBUG(dbgs()
+          << "    Loop memory access not suitable for LoopVersioningLICM\n\n");
+    return false;
+  }
+  // Loop versioning is feasible, return true.
+  DEBUG(dbgs() << "    Loop Versioning found to be beneficial\n\n");
+  return true;
+}
+
+/// \brief Update the loop with aggressive aliasing assumptions.
+/// It marks all pairs of memory operations as no-alias, relying on the
+/// assumption that the loop has no must-alias memory access pairs; loops
+/// with must-aliasing memory accesses were already rejected during the
+/// LoopVersioningLICM legality check.
+void LoopVersioningLICM::setNoAliasToLoop(Loop *VerLoop) {
+  // Get the latch terminator instruction.
+  Instruction *I = VerLoop->getLoopLatch()->getTerminator();
+  // Create an alias scope domain.
+  MDBuilder MDB(I->getContext());
+  MDNode *NewDomain = MDB.createAnonymousAliasScopeDomain("LVDomain");
+  StringRef Name = "LVAliasScope";
+  SmallVector<Metadata *, 4> Scopes, NoAliases;
+  MDNode *NewScope = MDB.createAnonymousAliasScope(NewDomain, Name);
+  // Iterate over each instruction of the loop and
+  // set no-alias for all load and store instructions.
+  for (auto *Block : CurLoop->getBlocks()) {
+    for (auto &Inst : *Block) {
+      // Only interested in instructions that may read or modify memory.
+      if (!Inst.mayReadFromMemory() && !Inst.mayWriteToMemory())
+        continue;
+      Scopes.push_back(NewScope);
+      NoAliases.push_back(NewScope);
+      // Set no-alias for the current instruction.
+      Inst.setMetadata(
+          LLVMContext::MD_noalias,
+          MDNode::concatenate(Inst.getMetadata(LLVMContext::MD_noalias),
+                              MDNode::get(Inst.getContext(), NoAliases)));
+      // Set the alias scope for the current instruction.
+      Inst.setMetadata(
+          LLVMContext::MD_alias_scope,
+          MDNode::concatenate(Inst.getMetadata(LLVMContext::MD_alias_scope),
+                              MDNode::get(Inst.getContext(), Scopes)));
+    }
+  }
+}
+
+bool LoopVersioningLICM::runOnLoop(Loop *L, LPPassManager &LPM) {
+  // This will automatically release all resources held by the current
+  // LoopVersioningLICM object.
+  AutoResetter Resetter(*this);
+
+  if (skipLoop(L))
+    return false;
+  // Get analysis information.
+  AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
+  SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+  LAA = &getAnalysis<LoopAccessLegacyAnalysis>();
+  LAI = nullptr;
+  // Set the current loop.
+  CurLoop = L;
+  CurAST.reset(new AliasSetTracker(*AA));
+
+  // Loop over the body of this loop and construct the AliasSetTracker.
+  LoopInfo *LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+  for (auto *Block : L->getBlocks()) {
+    if (LI->getLoopFor(Block) == L) // Ignore blocks in subloops.
+      CurAST->add(*Block);          // Incorporate the specified basic block.
+  }
+
+  bool Changed = false;
+
+  // Check the feasibility of LoopVersioningLICM.
+  // If versioning is found to be feasible and beneficial, then proceed;
+  // otherwise simply return, releasing the resources.
+  if (isLegalForVersioning()) {
+    // Do loop versioning:
+    // Create a memcheck for the memory accessed inside the loop.
+    // Clone the original loop, and set up the blocks properly.
+    DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    LoopVersioning LVer(*LAI, CurLoop, LI, DT, SE, true);
+    LVer.versionLoop();
+    // Set loop-versioning metadata on the original loop.
+    addStringMetadataToLoop(LVer.getNonVersionedLoop(), LICMVersioningMetaData);
+    // Set loop-versioning metadata on the versioned loop.
+    addStringMetadataToLoop(LVer.getVersionedLoop(), LICMVersioningMetaData);
+    // Set "llvm.mem.parallel_loop_access" metadata on the versioned loop.
+    addStringMetadataToLoop(LVer.getVersionedLoop(),
+                            "llvm.mem.parallel_loop_access");
+    // Update the versioned loop with aggressive aliasing assumptions.
+    setNoAliasToLoop(LVer.getVersionedLoop());
+    Changed = true;
+  }
+  return Changed;
+}
+
+char LoopVersioningLICM::ID = 0;
+
+INITIALIZE_PASS_BEGIN(LoopVersioningLICM, "loop-versioning-licm",
+                      "Loop Versioning For LICM", false, false)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopAccessLegacyAnalysis)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+INITIALIZE_PASS_END(LoopVersioningLICM, "loop-versioning-licm",
+                    "Loop Versioning For LICM", false, false)
+
+Pass *llvm::createLoopVersioningLICMPass() { return new LoopVersioningLICM(); }
diff --git a/contrib/llvm/lib/Transforms/Scalar/LowerAtomic.cpp b/contrib/llvm/lib/Transforms/Scalar/LowerAtomic.cpp
new file mode 100644
index 000000000000..c165c5ece95c
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/LowerAtomic.cpp
@@ -0,0 +1,172 @@
+//===- LowerAtomic.cpp - Lower atomic intrinsics --------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers atomic intrinsics to non-atomic form for use in a known
+// non-preemptible environment.
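+//
+// For example, an atomic read-modify-write such as
+//   %old = atomicrmw add i32* %p, i32 1 seq_cst
+// is rewritten (a sketch of the effect) into the non-atomic sequence
+//   %old = load i32, i32* %p
+//   %new = add i32 %old, 1
+//   store i32 %new, i32* %p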
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LowerAtomic.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Scalar.h"
+using namespace llvm;
+
+#define DEBUG_TYPE "loweratomic"
+
+static bool LowerAtomicCmpXchgInst(AtomicCmpXchgInst *CXI) {
+  IRBuilder<> Builder(CXI);
+  Value *Ptr = CXI->getPointerOperand();
+  Value *Cmp = CXI->getCompareOperand();
+  Value *Val = CXI->getNewValOperand();
+
+  LoadInst *Orig = Builder.CreateLoad(Ptr);
+  Value *Equal = Builder.CreateICmpEQ(Orig, Cmp);
+  Value *Res = Builder.CreateSelect(Equal, Val, Orig);
+  Builder.CreateStore(Res, Ptr);
+
+  Res = Builder.CreateInsertValue(UndefValue::get(CXI->getType()), Orig, 0);
+  Res = Builder.CreateInsertValue(Res, Equal, 1);
+
+  CXI->replaceAllUsesWith(Res);
+  CXI->eraseFromParent();
+  return true;
+}
+
+static bool LowerAtomicRMWInst(AtomicRMWInst *RMWI) {
+  IRBuilder<> Builder(RMWI);
+  Value *Ptr = RMWI->getPointerOperand();
+  Value *Val = RMWI->getValOperand();
+
+  LoadInst *Orig = Builder.CreateLoad(Ptr);
+  Value *Res = nullptr;
+
+  switch (RMWI->getOperation()) {
+  default: llvm_unreachable("Unexpected RMW operation");
+  case AtomicRMWInst::Xchg:
+    Res = Val;
+    break;
+  case AtomicRMWInst::Add:
+    Res = Builder.CreateAdd(Orig, Val);
+    break;
+  case AtomicRMWInst::Sub:
+    Res = Builder.CreateSub(Orig, Val);
+    break;
+  case AtomicRMWInst::And:
+    Res = Builder.CreateAnd(Orig, Val);
+    break;
+  case AtomicRMWInst::Nand:
+    Res = Builder.CreateNot(Builder.CreateAnd(Orig, Val));
+    break;
+  case AtomicRMWInst::Or:
+    Res = Builder.CreateOr(Orig, Val);
+    break;
+  case AtomicRMWInst::Xor:
+    Res = Builder.CreateXor(Orig, Val);
+    break;
+  case AtomicRMWInst::Max:
+    Res = Builder.CreateSelect(Builder.CreateICmpSLT(Orig, Val),
+                               Val, Orig);
+    break;
+  case AtomicRMWInst::Min:
+    Res = Builder.CreateSelect(Builder.CreateICmpSLT(Orig, Val),
+                               Orig, Val);
+    break;
+  case AtomicRMWInst::UMax:
+    Res = Builder.CreateSelect(Builder.CreateICmpULT(Orig, Val),
+                               Val, Orig);
+    break;
+  case AtomicRMWInst::UMin:
+    Res = Builder.CreateSelect(Builder.CreateICmpULT(Orig, Val),
+                               Orig, Val);
+    break;
+  }
+  Builder.CreateStore(Res, Ptr);
+  RMWI->replaceAllUsesWith(Orig);
+  RMWI->eraseFromParent();
+  return true;
+}
+
+static bool LowerFenceInst(FenceInst *FI) {
+  FI->eraseFromParent();
+  return true;
+}
+
+static bool LowerLoadInst(LoadInst *LI) {
+  LI->setAtomic(AtomicOrdering::NotAtomic);
+  return true;
+}
+
+static bool LowerStoreInst(StoreInst *SI) {
+  SI->setAtomic(AtomicOrdering::NotAtomic);
+  return true;
+}
+
+static bool runOnBasicBlock(BasicBlock &BB) {
+  bool Changed = false;
+  for (BasicBlock::iterator DI = BB.begin(), DE = BB.end(); DI != DE;) {
+    Instruction *Inst = &*DI++;
+    if (FenceInst *FI = dyn_cast<FenceInst>(Inst))
+      Changed |= LowerFenceInst(FI);
+    else if (AtomicCmpXchgInst *CXI = dyn_cast<AtomicCmpXchgInst>(Inst))
+      Changed |= LowerAtomicCmpXchgInst(CXI);
+    else if (AtomicRMWInst *RMWI = dyn_cast<AtomicRMWInst>(Inst))
+      Changed |= LowerAtomicRMWInst(RMWI);
+    else if (LoadInst *LI = dyn_cast<LoadInst>(Inst)) {
+      if (LI->isAtomic())
+        Changed |= LowerLoadInst(LI);
+    } else if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) {
+      if (SI->isAtomic())
+        Changed |= LowerStoreInst(SI);
+    }
+  }
+  return Changed;
+}
+
+static bool lowerAtomics(Function &F) {
+  bool Changed = false;
+  for (BasicBlock &BB : F) {
+    Changed |= runOnBasicBlock(BB);
+  }
+  return Changed;
+}
+
+PreservedAnalyses LowerAtomicPass::run(Function &F,
FunctionAnalysisManager &) {
+  if (lowerAtomics(F))
+    return PreservedAnalyses::none();
+  return PreservedAnalyses::all();
+}
+
+namespace {
+class LowerAtomicLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  LowerAtomicLegacyPass() : FunctionPass(ID) {
+    initializeLowerAtomicLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    // Don't skip optnone functions; atomics still need to be lowered.
+    FunctionAnalysisManager DummyFAM;
+    auto PA = Impl.run(F, DummyFAM);
+    return !PA.areAllPreserved();
+  }
+
+private:
+  LowerAtomicPass Impl;
+};
+}
+
+char LowerAtomicLegacyPass::ID = 0;
+INITIALIZE_PASS(LowerAtomicLegacyPass, "loweratomic",
+                "Lower atomic intrinsics to non-atomic form", false, false)
+
+Pass *llvm::createLowerAtomicPass() { return new LowerAtomicLegacyPass(); }
diff --git a/contrib/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp b/contrib/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
new file mode 100644
index 000000000000..46f8a3564265
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/LowerExpectIntrinsic.cpp
@@ -0,0 +1,383 @@
+//===- LowerExpectIntrinsic.cpp - Lower expect intrinsic ------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass lowers the 'expect' intrinsic to LLVM metadata.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/LowerExpectIntrinsic.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/MDBuilder.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Scalar.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "lower-expect-intrinsic"
+
+STATISTIC(ExpectIntrinsicsHandled,
+          "Number of 'expect' intrinsic instructions handled");
+
+// These default values are chosen to represent an extremely skewed outcome for
+// a condition, but they leave some room for interpretation by later passes.
+//
+// If the documentation for __builtin_expect() made it explicit that it should
+// only be used in extreme cases, we could make this ratio higher. As it stands,
+// programmers may be using __builtin_expect() / llvm.expect to annotate that a
+// branch is likely or unlikely to be taken.
+//
+// There is a known dependency on this ratio in CodeGenPrepare when transforming
+// 'select' instructions. It may be worthwhile to hoist these values to some
+// shared space, so they can be used directly by other passes.
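+//
+// For instance, with the default weights below, C code such as
+//   if (__builtin_expect(x, 1)) { ... }
+// ends up (roughly) as a conditional branch annotated with
+//   !{!"branch_weights", i32 2000, i32 1}
+// i.e. the "likely" edge is treated as 2000x more probable than the other.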
+
+static cl::opt<uint32_t> LikelyBranchWeight(
+    "likely-branch-weight", cl::Hidden, cl::init(2000),
+    cl::desc("Weight of the branch likely to be taken (default = 2000)"));
+static cl::opt<uint32_t> UnlikelyBranchWeight(
+    "unlikely-branch-weight", cl::Hidden, cl::init(1),
+    cl::desc("Weight of the branch unlikely to be taken (default = 1)"));
+
+static bool handleSwitchExpect(SwitchInst &SI) {
+  CallInst *CI = dyn_cast<CallInst>(SI.getCondition());
+  if (!CI)
+    return false;
+
+  Function *Fn = CI->getCalledFunction();
+  if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect)
+    return false;
+
+  Value *ArgValue = CI->getArgOperand(0);
+  ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+  if (!ExpectedValue)
+    return false;
+
+  SwitchInst::CaseHandle Case = *SI.findCaseValue(ExpectedValue);
+  unsigned n = SI.getNumCases(); // +1 for default case.
+  SmallVector<uint32_t, 16> Weights(n + 1, UnlikelyBranchWeight);
+
+  if (Case == *SI.case_default())
+    Weights[0] = LikelyBranchWeight;
+  else
+    Weights[Case.getCaseIndex() + 1] = LikelyBranchWeight;
+
+  SI.setMetadata(LLVMContext::MD_prof,
+                 MDBuilder(CI->getContext()).createBranchWeights(Weights));
+
+  SI.setCondition(ArgValue);
+  return true;
+}
+
+/// Handler for PHINodes that define the value argument to an
+/// @llvm.expect call.
+///
+/// If the operand of the phi has a constant value and it 'contradicts'
+/// with the expected value of phi def, then the corresponding incoming
+/// edge of the phi is unlikely to be taken. Using that information,
+/// the branch probability info for the originating branch can be inferred.
+static void handlePhiDef(CallInst *Expect) {
+  Value &Arg = *Expect->getArgOperand(0);
+  ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(Expect->getArgOperand(1));
+  if (!ExpectedValue)
+    return;
+  const APInt &ExpectedPhiValue = ExpectedValue->getValue();
+
+  // Walk backward up a chain of instructions that have 'copy' semantics,
+  // 'stripping' the copies until a PHI node or an instruction of unknown
+  // kind is reached. Negation via xor is also handled.
+  //
+  //       C = PHI(...);
+  //       B = C;
+  //       A = B;
+  //       D = __builtin_expect(A, 0);
+  //
+  Value *V = &Arg;
+  SmallVector<Instruction *, 4> Operations;
+  while (!isa<PHINode>(V)) {
+    if (ZExtInst *ZExt = dyn_cast<ZExtInst>(V)) {
+      V = ZExt->getOperand(0);
+      Operations.push_back(ZExt);
+      continue;
+    }
+
+    if (SExtInst *SExt = dyn_cast<SExtInst>(V)) {
+      V = SExt->getOperand(0);
+      Operations.push_back(SExt);
+      continue;
+    }
+
+    BinaryOperator *BinOp = dyn_cast<BinaryOperator>(V);
+    if (!BinOp || BinOp->getOpcode() != Instruction::Xor)
+      return;
+
+    ConstantInt *CInt = dyn_cast<ConstantInt>(BinOp->getOperand(1));
+    if (!CInt)
+      return;
+
+    V = BinOp->getOperand(0);
+    Operations.push_back(BinOp);
+  }
+
+  // Execute the recorded operations on the input 'Value'.
+  auto ApplyOperations = [&](const APInt &Value) {
+    APInt Result = Value;
+    for (auto Op : llvm::reverse(Operations)) {
+      switch (Op->getOpcode()) {
+      case Instruction::Xor:
+        Result ^= cast<ConstantInt>(Op->getOperand(1))->getValue();
+        break;
+      case Instruction::ZExt:
+        Result = Result.zext(Op->getType()->getIntegerBitWidth());
+        break;
+      case Instruction::SExt:
+        Result = Result.sext(Op->getType()->getIntegerBitWidth());
+        break;
+      default:
+        llvm_unreachable("Unexpected operation");
+      }
+    }
+    return Result;
+  };
+
+  auto *PhiDef = dyn_cast<PHINode>(V);
+
+  // Get the first dominating conditional branch of operand i's
+  // incoming block.
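+  // Two shapes are recognized (a sketch): either the incoming block itself
+  // ends in a conditional branch, or its single predecessor does (e.g. an
+  // empty critical-edge block sits between the branch and the phi).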
+  auto GetDomConditional = [&](unsigned i) -> BranchInst * {
+    BasicBlock *BB = PhiDef->getIncomingBlock(i);
+    BranchInst *BI = dyn_cast<BranchInst>(BB->getTerminator());
+    if (BI && BI->isConditional())
+      return BI;
+    BB = BB->getSinglePredecessor();
+    if (!BB)
+      return nullptr;
+    BI = dyn_cast<BranchInst>(BB->getTerminator());
+    if (!BI || BI->isUnconditional())
+      return nullptr;
+    return BI;
+  };
+
+  // Now walk through all Phi operands to find phi operands with values
+  // conflicting with the expected phi output value. Any such operand
+  // indicates that the incoming edge to that operand is unlikely.
+  for (unsigned i = 0, e = PhiDef->getNumIncomingValues(); i != e; ++i) {
+
+    Value *PhiOpnd = PhiDef->getIncomingValue(i);
+    ConstantInt *CI = dyn_cast<ConstantInt>(PhiOpnd);
+    if (!CI)
+      continue;
+
+    // Not an interesting case -- we cannot infer anything useful when the
+    // operand value matches the expected phi output.
+    if (ExpectedPhiValue == ApplyOperations(CI->getValue()))
+      continue;
+
+    BranchInst *BI = GetDomConditional(i);
+    if (!BI)
+      continue;
+
+    MDBuilder MDB(PhiDef->getContext());
+
+    // There are two situations in which an operand of the PhiDef comes
+    // from a given successor of a branch instruction BI.
+    // 1) When the incoming block of the operand is the successor block;
+    // 2) When the incoming block is BI's enclosing block and the
+    // successor is the PhiDef's enclosing block.
+    //
+    // Returns true if the operand coming from OpndIncomingBB flows along
+    // the outgoing edge of BI that leads to the Succ block.
+    auto *OpndIncomingBB = PhiDef->getIncomingBlock(i);
+    auto IsOpndComingFromSuccessor = [&](BasicBlock *Succ) {
+      if (OpndIncomingBB == Succ)
+        // If this successor is the incoming block for this
+        // Phi operand, then this successor does lead to the Phi.
+        return true;
+      if (OpndIncomingBB == BI->getParent() && Succ == PhiDef->getParent())
+        // Otherwise, if the edge is directly from the branch
+        // to the Phi, this successor is the one feeding this
+        // Phi operand.
+        return true;
+      return false;
+    };
+
+    if (IsOpndComingFromSuccessor(BI->getSuccessor(1)))
+      BI->setMetadata(
+          LLVMContext::MD_prof,
+          MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight));
+    else if (IsOpndComingFromSuccessor(BI->getSuccessor(0)))
+      BI->setMetadata(
+          LLVMContext::MD_prof,
+          MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight));
+  }
+}
+
+// Handle both BranchInst and SelectInst.
+template <class BrSelInst> static bool handleBrSelExpect(BrSelInst &BSI) {
+
+  // Handle non-optimized IR code like:
+  //   %expval = call i64 @llvm.expect.i64(i64 %conv1, i64 1)
+  //   %tobool = icmp ne i64 %expval, 0
+  //   br i1 %tobool, label %if.then, label %if.end
+  //
+  // Or the following simpler case:
+  //   %expval = call i1 @llvm.expect.i1(i1 %cmp, i1 1)
+  //   br i1 %expval, label %if.then, label %if.end
+
+  CallInst *CI;
+
+  ICmpInst *CmpI = dyn_cast<ICmpInst>(BSI.getCondition());
+  CmpInst::Predicate Predicate;
+  ConstantInt *CmpConstOperand = nullptr;
+  if (!CmpI) {
+    CI = dyn_cast<CallInst>(BSI.getCondition());
+    Predicate = CmpInst::ICMP_NE;
+  } else {
+    Predicate = CmpI->getPredicate();
+    if (Predicate != CmpInst::ICMP_NE && Predicate != CmpInst::ICMP_EQ)
+      return false;
+
+    CmpConstOperand = dyn_cast<ConstantInt>(CmpI->getOperand(1));
+    if (!CmpConstOperand)
+      return false;
+    CI = dyn_cast<CallInst>(CmpI->getOperand(0));
+  }
+
+  if (!CI)
+    return false;
+
+  uint64_t ValueComparedTo = 0;
+  if (CmpConstOperand) {
+    if (CmpConstOperand->getBitWidth() > 64)
+      return false;
+    ValueComparedTo = CmpConstOperand->getZExtValue();
+  }
+
+  Function *Fn = CI->getCalledFunction();
+  if (!Fn || Fn->getIntrinsicID() != Intrinsic::expect)
+    return false;
+
+  Value *ArgValue = CI->getArgOperand(0);
+  ConstantInt *ExpectedValue = dyn_cast<ConstantInt>(CI->getArgOperand(1));
+  if (!ExpectedValue)
+    return false;
+
+  MDBuilder MDB(CI->getContext());
+  MDNode *Node;
+
+  if ((ExpectedValue->getZExtValue() == ValueComparedTo) ==
+      (Predicate == CmpInst::ICMP_EQ))
+    Node = MDB.createBranchWeights(LikelyBranchWeight, UnlikelyBranchWeight);
+  else
+    Node = MDB.createBranchWeights(UnlikelyBranchWeight, LikelyBranchWeight);
+
+  BSI.setMetadata(LLVMContext::MD_prof, Node);
+
+  if (CmpI)
+    CmpI->setOperand(0, ArgValue);
+  else
+    BSI.setCondition(ArgValue);
+  return true;
+}
+
+static bool handleBranchExpect(BranchInst &BI) {
+  if (BI.isUnconditional())
+    return false;
+
+  return handleBrSelExpect<BranchInst>(BI);
+}
+
+static bool lowerExpectIntrinsic(Function &F) {
+  bool Changed = false;
+
+  for (BasicBlock &BB : F) {
+    // Create "branch_weights" metadata.
+    if (BranchInst *BI = dyn_cast<BranchInst>(BB.getTerminator())) {
+      if (handleBranchExpect(*BI))
+        ExpectIntrinsicsHandled++;
+    } else if (SwitchInst *SI = dyn_cast<SwitchInst>(BB.getTerminator())) {
+      if (handleSwitchExpect(*SI))
+        ExpectIntrinsicsHandled++;
+    }
+
+    // Remove llvm.expect intrinsics. Iterate backwards in order
+    // to process select instructions before the intrinsic gets
+    // removed.
+    for (auto BI = BB.rbegin(), BE = BB.rend(); BI != BE;) {
+      Instruction *Inst = &*BI++;
+      CallInst *CI = dyn_cast<CallInst>(Inst);
+      if (!CI) {
+        if (SelectInst *SI = dyn_cast<SelectInst>(Inst)) {
+          if (handleBrSelExpect(*SI))
+            ExpectIntrinsicsHandled++;
+        }
+        continue;
+      }
+
+      Function *Fn = CI->getCalledFunction();
+      if (Fn && Fn->getIntrinsicID() == Intrinsic::expect) {
+        // Before erasing the llvm.expect, walk backward to find the
+        // phi that defines llvm.expect's first arg, and
+        // infer branch probability:
+        handlePhiDef(CI);
+        Value *Exp = CI->getArgOperand(0);
+        CI->replaceAllUsesWith(Exp);
+        CI->eraseFromParent();
+        Changed = true;
+      }
+    }
+  }
+
+  return Changed;
+}
+
+PreservedAnalyses LowerExpectIntrinsicPass::run(Function &F,
+                                                FunctionAnalysisManager &) {
+  if (lowerExpectIntrinsic(F))
+    return PreservedAnalyses::none();
+
+  return PreservedAnalyses::all();
+}
+
+namespace {
+/// \brief Legacy pass for lowering expect intrinsics out of the IR.
+/// +/// When this pass is run over a function it uses expect intrinsics which feed +/// branches and switches to provide branch weight metadata for those +/// terminators. It then removes the expect intrinsics from the IR so the rest +/// of the optimizer can ignore them. +class LowerExpectIntrinsic : public FunctionPass { +public: + static char ID; + LowerExpectIntrinsic() : FunctionPass(ID) { + initializeLowerExpectIntrinsicPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { return lowerExpectIntrinsic(F); } +}; +} + +char LowerExpectIntrinsic::ID = 0; +INITIALIZE_PASS(LowerExpectIntrinsic, "lower-expect", + "Lower 'expect' Intrinsics", false, false) + +FunctionPass *llvm::createLowerExpectIntrinsicPass() { + return new LowerExpectIntrinsic(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp b/contrib/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp new file mode 100644 index 000000000000..070114a84cc5 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/LowerGuardIntrinsic.cpp @@ -0,0 +1,137 @@ +//===- LowerGuardIntrinsic.cpp - Lower the guard intrinsic ---------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass lowers the llvm.experimental.guard intrinsic to a conditional call +// to @llvm.experimental.deoptimize. Once this happens, the guard can no longer +// be widened. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/LowerGuardIntrinsic.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Module.h" +#include "llvm/Pass.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" + +using namespace llvm; + +static cl::opt<uint32_t> PredicatePassBranchWeight( + "guards-predicate-pass-branch-weight", cl::Hidden, cl::init(1 << 20), + cl::desc("The probability of a guard failing is assumed to be the " + "reciprocal of this value (default = 1 << 20)")); + +namespace { +struct LowerGuardIntrinsicLegacyPass : public FunctionPass { + static char ID; + LowerGuardIntrinsicLegacyPass() : FunctionPass(ID) { + initializeLowerGuardIntrinsicLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; +}; +} + +static void MakeGuardControlFlowExplicit(Function *DeoptIntrinsic, + CallInst *CI) { + OperandBundleDef DeoptOB(*CI->getOperandBundle(LLVMContext::OB_deopt)); + SmallVector<Value *, 4> Args(std::next(CI->arg_begin()), CI->arg_end()); + + auto *CheckBB = CI->getParent(); + auto *DeoptBlockTerm = + SplitBlockAndInsertIfThen(CI->getArgOperand(0), CI, true); + + auto *CheckBI = cast<BranchInst>(CheckBB->getTerminator()); + + // SplitBlockAndInsertIfThen inserts control flow that branches to + // DeoptBlockTerm if the condition is true. We want the opposite. 
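+  // After swapping, the intended shape is roughly (labels illustrative):
+  //   br i1 %cond, label %guarded, label %deopt
+  // guarded:   ; condition held, normal execution continues
+  // deopt:     ; condition failed
+  //   call @llvm.experimental.deoptimize(...) [ "deopt"(...) ]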
+ CheckBI->swapSuccessors(); + + CheckBI->getSuccessor(0)->setName("guarded"); + CheckBI->getSuccessor(1)->setName("deopt"); + + if (auto *MD = CI->getMetadata(LLVMContext::MD_make_implicit)) + CheckBI->setMetadata(LLVMContext::MD_make_implicit, MD); + + MDBuilder MDB(CI->getContext()); + CheckBI->setMetadata(LLVMContext::MD_prof, + MDB.createBranchWeights(PredicatePassBranchWeight, 1)); + + IRBuilder<> B(DeoptBlockTerm); + auto *DeoptCall = B.CreateCall(DeoptIntrinsic, Args, {DeoptOB}, ""); + + if (DeoptIntrinsic->getReturnType()->isVoidTy()) { + B.CreateRetVoid(); + } else { + DeoptCall->setName("deoptcall"); + B.CreateRet(DeoptCall); + } + + DeoptCall->setCallingConv(CI->getCallingConv()); + DeoptBlockTerm->eraseFromParent(); +} + +static bool lowerGuardIntrinsic(Function &F) { + // Check if we can cheaply rule out the possibility of not having any work to + // do. + auto *GuardDecl = F.getParent()->getFunction( + Intrinsic::getName(Intrinsic::experimental_guard)); + if (!GuardDecl || GuardDecl->use_empty()) + return false; + + SmallVector<CallInst *, 8> ToLower; + for (auto &I : instructions(F)) + if (auto *CI = dyn_cast<CallInst>(&I)) + if (auto *F = CI->getCalledFunction()) + if (F->getIntrinsicID() == Intrinsic::experimental_guard) + ToLower.push_back(CI); + + if (ToLower.empty()) + return false; + + auto *DeoptIntrinsic = Intrinsic::getDeclaration( + F.getParent(), Intrinsic::experimental_deoptimize, {F.getReturnType()}); + DeoptIntrinsic->setCallingConv(GuardDecl->getCallingConv()); + + for (auto *CI : ToLower) { + MakeGuardControlFlowExplicit(DeoptIntrinsic, CI); + CI->eraseFromParent(); + } + + return true; +} + +bool LowerGuardIntrinsicLegacyPass::runOnFunction(Function &F) { + return lowerGuardIntrinsic(F); +} + +char LowerGuardIntrinsicLegacyPass::ID = 0; +INITIALIZE_PASS(LowerGuardIntrinsicLegacyPass, "lower-guard-intrinsic", + "Lower the guard intrinsic to normal control flow", false, + false) + +Pass *llvm::createLowerGuardIntrinsicPass() { + return new LowerGuardIntrinsicLegacyPass(); +} + +PreservedAnalyses LowerGuardIntrinsicPass::run(Function &F, + FunctionAnalysisManager &AM) { + if (lowerGuardIntrinsic(F)) + return PreservedAnalyses::none(); + + return PreservedAnalyses::all(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp b/contrib/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp new file mode 100644 index 000000000000..9c870b42a747 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/MemCpyOptimizer.cpp @@ -0,0 +1,1492 @@ +//===- MemCpyOptimizer.cpp - Optimize use of memcpy and friends -----------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass performs various transformations related to eliminating memcpy +// calls, or transforming sets of stores into memset's. 
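+//
+// Two representative rewrites (sketches):
+//   memcpy(tmp, src, n); memcpy(dst, tmp, n);  -->  memcpy(dst, src, n)
+//   A[0] = 0; A[1] = 0; A[2] = 0; A[3] = 0;    -->  memset(A, 0, 4 * sizeof(*A))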
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/MemCpyOptimizer.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AliasAnalysis.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/MemoryDependenceAnalysis.h" +#include "llvm/Analysis/MemoryLocation.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include <algorithm> +#include <cassert> +#include <cstdint> +#include <utility> + +using namespace llvm; + +#define DEBUG_TYPE "memcpyopt" + +STATISTIC(NumMemCpyInstr, "Number of memcpy instructions deleted"); +STATISTIC(NumMemSetInfer, "Number of memsets inferred"); +STATISTIC(NumMoveToCpy, "Number of memmoves converted to memcpy"); +STATISTIC(NumCpyToSet, "Number of memcpys converted to memset"); + +static int64_t GetOffsetFromIndex(const GEPOperator *GEP, unsigned Idx, + bool &VariableIdxFound, + const DataLayout &DL) { + // Skip over the first indices. + gep_type_iterator GTI = gep_type_begin(GEP); + for (unsigned i = 1; i != Idx; ++i, ++GTI) + /*skip along*/; + + // Compute the offset implied by the rest of the indices. + int64_t Offset = 0; + for (unsigned i = Idx, e = GEP->getNumOperands(); i != e; ++i, ++GTI) { + ConstantInt *OpC = dyn_cast<ConstantInt>(GEP->getOperand(i)); + if (!OpC) + return VariableIdxFound = true; + if (OpC->isZero()) continue; // No offset. + + // Handle struct indices, which add their field offset to the pointer. + if (StructType *STy = GTI.getStructTypeOrNull()) { + Offset += DL.getStructLayout(STy)->getElementOffset(OpC->getZExtValue()); + continue; + } + + // Otherwise, we have a sequential type like an array or vector. Multiply + // the index by the ElementSize. + uint64_t Size = DL.getTypeAllocSize(GTI.getIndexedType()); + Offset += Size*OpC->getSExtValue(); + } + + return Offset; +} + +/// Return true if Ptr1 is provably equal to Ptr2 plus a constant offset, and +/// return that constant offset. For example, Ptr1 might be &A[42], and Ptr2 +/// might be &A[40]. In this case offset would be -8. +static bool IsPointerOffset(Value *Ptr1, Value *Ptr2, int64_t &Offset, + const DataLayout &DL) { + Ptr1 = Ptr1->stripPointerCasts(); + Ptr2 = Ptr2->stripPointerCasts(); + + // Handle the trivial case first. 
+ if (Ptr1 == Ptr2) { + Offset = 0; + return true; + } + + GEPOperator *GEP1 = dyn_cast<GEPOperator>(Ptr1); + GEPOperator *GEP2 = dyn_cast<GEPOperator>(Ptr2); + + bool VariableIdxFound = false; + + // If one pointer is a GEP and the other isn't, then see if the GEP is a + // constant offset from the base, as in "P" and "gep P, 1". + if (GEP1 && !GEP2 && GEP1->getOperand(0)->stripPointerCasts() == Ptr2) { + Offset = -GetOffsetFromIndex(GEP1, 1, VariableIdxFound, DL); + return !VariableIdxFound; + } + + if (GEP2 && !GEP1 && GEP2->getOperand(0)->stripPointerCasts() == Ptr1) { + Offset = GetOffsetFromIndex(GEP2, 1, VariableIdxFound, DL); + return !VariableIdxFound; + } + + // Right now we handle the case when Ptr1/Ptr2 are both GEPs with an identical + // base. After that base, they may have some number of common (and + // potentially variable) indices. After that they handle some constant + // offset, which determines their offset from each other. At this point, we + // handle no other case. + if (!GEP1 || !GEP2 || GEP1->getOperand(0) != GEP2->getOperand(0)) + return false; + + // Skip any common indices and track the GEP types. + unsigned Idx = 1; + for (; Idx != GEP1->getNumOperands() && Idx != GEP2->getNumOperands(); ++Idx) + if (GEP1->getOperand(Idx) != GEP2->getOperand(Idx)) + break; + + int64_t Offset1 = GetOffsetFromIndex(GEP1, Idx, VariableIdxFound, DL); + int64_t Offset2 = GetOffsetFromIndex(GEP2, Idx, VariableIdxFound, DL); + if (VariableIdxFound) return false; + + Offset = Offset2-Offset1; + return true; +} + +namespace { + +/// Represents a range of memset'd bytes with the ByteVal value. +/// This allows us to analyze stores like: +/// store 0 -> P+1 +/// store 0 -> P+0 +/// store 0 -> P+3 +/// store 0 -> P+2 +/// which sometimes happens with stores to arrays of structs etc. When we see +/// the first store, we make a range [1, 2). The second store extends the range +/// to [0, 2). The third makes a new range [2, 3). The fourth store joins the +/// two ranges into [0, 3) which is memset'able. +struct MemsetRange { + // Start/End - A semi range that describes the span that this range covers. + // The range is closed at the start and open at the end: [Start, End). + int64_t Start, End; + + /// StartPtr - The getelementptr instruction that points to the start of the + /// range. + Value *StartPtr; + + /// Alignment - The known alignment of the first store. + unsigned Alignment; + + /// TheStores - The actual stores that make up this range. + SmallVector<Instruction*, 16> TheStores; + + bool isProfitableToUseMemset(const DataLayout &DL) const; +}; + +} // end anonymous namespace + +bool MemsetRange::isProfitableToUseMemset(const DataLayout &DL) const { + // If we found more than 4 stores to merge or 16 bytes, use memset. + if (TheStores.size() >= 4 || End-Start >= 16) return true; + + // If there is nothing to merge, don't do anything. + if (TheStores.size() < 2) return false; + + // If any of the stores are a memset, then it is always good to extend the + // memset. + for (Instruction *SI : TheStores) + if (!isa<StoreInst>(SI)) + return true; + + // Assume that the code generator is capable of merging pairs of stores + // together if it wants to. + if (TheStores.size() == 2) return false; + + // If we have fewer than 8 stores, it can still be worthwhile to do this. + // For example, merging 4 i8 stores into an i32 store is useful almost always. 
+ // However, merging 2 32-bit stores isn't useful on a 32-bit architecture (the + // memset will be split into 2 32-bit stores anyway) and doing so can + // pessimize the llvm optimizer. + // + // Since we don't have perfect knowledge here, make some assumptions: assume + // the maximum GPR width is the same size as the largest legal integer + // size. If so, check to see whether we will end up actually reducing the + // number of stores used. + unsigned Bytes = unsigned(End-Start); + unsigned MaxIntSize = DL.getLargestLegalIntTypeSizeInBits() / 8; + if (MaxIntSize == 0) + MaxIntSize = 1; + unsigned NumPointerStores = Bytes / MaxIntSize; + + // Assume the remaining bytes if any are done a byte at a time. + unsigned NumByteStores = Bytes % MaxIntSize; + + // If we will reduce the # stores (according to this heuristic), do the + // transformation. This encourages merging 4 x i8 -> i32 and 2 x i16 -> i32 + // etc. + return TheStores.size() > NumPointerStores+NumByteStores; +} + +namespace { + +class MemsetRanges { + using range_iterator = SmallVectorImpl<MemsetRange>::iterator; + + /// A sorted list of the memset ranges. + SmallVector<MemsetRange, 8> Ranges; + + const DataLayout &DL; + +public: + MemsetRanges(const DataLayout &DL) : DL(DL) {} + + using const_iterator = SmallVectorImpl<MemsetRange>::const_iterator; + + const_iterator begin() const { return Ranges.begin(); } + const_iterator end() const { return Ranges.end(); } + bool empty() const { return Ranges.empty(); } + + void addInst(int64_t OffsetFromFirst, Instruction *Inst) { + if (StoreInst *SI = dyn_cast<StoreInst>(Inst)) + addStore(OffsetFromFirst, SI); + else + addMemSet(OffsetFromFirst, cast<MemSetInst>(Inst)); + } + + void addStore(int64_t OffsetFromFirst, StoreInst *SI) { + int64_t StoreSize = DL.getTypeStoreSize(SI->getOperand(0)->getType()); + + addRange(OffsetFromFirst, StoreSize, + SI->getPointerOperand(), SI->getAlignment(), SI); + } + + void addMemSet(int64_t OffsetFromFirst, MemSetInst *MSI) { + int64_t Size = cast<ConstantInt>(MSI->getLength())->getZExtValue(); + addRange(OffsetFromFirst, Size, MSI->getDest(), MSI->getAlignment(), MSI); + } + + void addRange(int64_t Start, int64_t Size, Value *Ptr, + unsigned Alignment, Instruction *Inst); +}; + +} // end anonymous namespace + +/// Add a new store to the MemsetRanges data structure. This adds a +/// new range for the specified store at the specified offset, merging into +/// existing ranges as appropriate. +void MemsetRanges::addRange(int64_t Start, int64_t Size, Value *Ptr, + unsigned Alignment, Instruction *Inst) { + int64_t End = Start+Size; + + range_iterator I = std::lower_bound(Ranges.begin(), Ranges.end(), Start, + [](const MemsetRange &LHS, int64_t RHS) { return LHS.End < RHS; }); + + // We now know that I == E, in which case we didn't find anything to merge + // with, or that Start <= I->End. If End < I->Start or I == E, then we need + // to insert a new range. Handle this now. + if (I == Ranges.end() || End < I->Start) { + MemsetRange &R = *Ranges.insert(I, MemsetRange()); + R.Start = Start; + R.End = End; + R.StartPtr = Ptr; + R.Alignment = Alignment; + R.TheStores.push_back(Inst); + return; + } + + // This store overlaps with I, add it. + I->TheStores.push_back(Inst); + + // At this point, we may have an interval that completely contains our store. + // If so, just add it to the interval and return. 
+ if (I->Start <= Start && I->End >= End) + return; + + // Now we know that Start <= I->End and End >= I->Start so the range overlaps + // but is not entirely contained within the range. + + // See if the range extends the start of the range. In this case, it couldn't + // possibly cause it to join the prior range, because otherwise we would have + // stopped on *it*. + if (Start < I->Start) { + I->Start = Start; + I->StartPtr = Ptr; + I->Alignment = Alignment; + } + + // Now we know that Start <= I->End and Start >= I->Start (so the startpoint + // is in or right at the end of I), and that End >= I->Start. Extend I out to + // End. + if (End > I->End) { + I->End = End; + range_iterator NextI = I; + while (++NextI != Ranges.end() && End >= NextI->Start) { + // Merge the range in. + I->TheStores.append(NextI->TheStores.begin(), NextI->TheStores.end()); + if (NextI->End > I->End) + I->End = NextI->End; + Ranges.erase(NextI); + NextI = I; + } + } +} + +//===----------------------------------------------------------------------===// +// MemCpyOptLegacyPass Pass +//===----------------------------------------------------------------------===// + +namespace { + +class MemCpyOptLegacyPass : public FunctionPass { + MemCpyOptPass Impl; + +public: + static char ID; // Pass identification, replacement for typeid + + MemCpyOptLegacyPass() : FunctionPass(ID) { + initializeMemCpyOptLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override; + +private: + // This transformation requires dominator postdominator info + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<MemoryDependenceWrapperPass>(); + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<MemoryDependenceWrapperPass>(); + } +}; + +} // end anonymous namespace + +char MemCpyOptLegacyPass::ID = 0; + +/// The public interface to this file... +FunctionPass *llvm::createMemCpyOptPass() { return new MemCpyOptLegacyPass(); } + +INITIALIZE_PASS_BEGIN(MemCpyOptLegacyPass, "memcpyopt", "MemCpy Optimization", + false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass) +INITIALIZE_PASS_END(MemCpyOptLegacyPass, "memcpyopt", "MemCpy Optimization", + false, false) + +/// When scanning forward over instructions, we look for some other patterns to +/// fold away. In particular, this looks for stores to neighboring locations of +/// memory. If it sees enough consecutive ones, it attempts to merge them +/// together into a memcpy/memset. +Instruction *MemCpyOptPass::tryMergingIntoMemset(Instruction *StartInst, + Value *StartPtr, + Value *ByteVal) { + const DataLayout &DL = StartInst->getModule()->getDataLayout(); + + // Okay, so we now have a single store that can be splatable. Scan to find + // all subsequent stores of the same value to offset from the same pointer. + // Join these together into ranges, so we can decide whether contiguous blocks + // are stored. 
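+ // A value is byte-splattable when every byte of it is the same, e.g.
+ // (illustrative) 0, -1, or 0xA0A0A0A0 for an i32 store; isBytewiseValue
+ // recovers that common byte for each candidate store below.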
+ MemsetRanges Ranges(DL);
+
+ BasicBlock::iterator BI(StartInst);
+ for (++BI; !isa<TerminatorInst>(BI); ++BI) {
+ if (!isa<StoreInst>(BI) && !isa<MemSetInst>(BI)) {
+ // If the instruction is readnone, ignore it, otherwise bail out. We
+ // don't even allow readonly here because we don't want something like:
+ // A[1] = 2; strlen(A); A[2] = 2; -> memcpy(A, ...); strlen(A).
+ if (BI->mayWriteToMemory() || BI->mayReadFromMemory())
+ break;
+ continue;
+ }
+
+ if (StoreInst *NextStore = dyn_cast<StoreInst>(BI)) {
+ // If this is a store, see if we can merge it in.
+ if (!NextStore->isSimple()) break;
+
+ // Check to see if this stored value is the same byte-splattable value.
+ if (ByteVal != isBytewiseValue(NextStore->getOperand(0)))
+ break;
+
+ // Check to see if this store is to a constant offset from the start ptr.
+ int64_t Offset;
+ if (!IsPointerOffset(StartPtr, NextStore->getPointerOperand(), Offset,
+ DL))
+ break;
+
+ Ranges.addStore(Offset, NextStore);
+ } else {
+ MemSetInst *MSI = cast<MemSetInst>(BI);
+
+ if (MSI->isVolatile() || ByteVal != MSI->getValue() ||
+ !isa<ConstantInt>(MSI->getLength()))
+ break;
+
+ // Check to see if this store is to a constant offset from the start ptr.
+ int64_t Offset;
+ if (!IsPointerOffset(StartPtr, MSI->getDest(), Offset, DL))
+ break;
+
+ Ranges.addMemSet(Offset, MSI);
+ }
+ }
+
+ // If we have no ranges, then we just had a single store with nothing that
+ // could be merged in. This is a very common case of course.
+ if (Ranges.empty())
+ return nullptr;
+
+ // If we had at least one store that could be merged in, add the starting
+ // store as well. We try to avoid this unless there is at least something
+ // interesting as a small compile-time optimization.
+ Ranges.addInst(0, StartInst);
+
+ // If we create any memsets, we put them right before the first instruction
+ // that isn't part of the memset block. This ensures that the memset is
+ // dominated by any addressing instruction needed by the start of the block.
+ IRBuilder<> Builder(&*BI);
+
+ // Now that we have full information about ranges, loop over the ranges and
+ // emit memsets for anything big enough to be worthwhile.
+ Instruction *AMemSet = nullptr;
+ for (const MemsetRange &Range : Ranges) {
+ if (Range.TheStores.size() == 1) continue;
+
+ // If it is profitable to lower this range to memset, do so now.
+ if (!Range.isProfitableToUseMemset(DL))
+ continue;
+
+ // Otherwise, we do want to transform this! Create a new memset.
+ // Get the starting pointer of the block.
+ StartPtr = Range.StartPtr;
+
+ // Determine the alignment.
+ unsigned Alignment = Range.Alignment;
+ if (Alignment == 0) {
+ Type *EltType =
+ cast<PointerType>(StartPtr->getType())->getElementType();
+ Alignment = DL.getABITypeAlignment(EltType);
+ }
+
+ AMemSet =
+ Builder.CreateMemSet(StartPtr, ByteVal, Range.End-Range.Start, Alignment);
+
+ DEBUG(dbgs() << "Replace stores:\n";
+ for (Instruction *SI : Range.TheStores)
+ dbgs() << *SI << '\n';
+ dbgs() << "With: " << *AMemSet << '\n');
+
+ if (!Range.TheStores.empty())
+ AMemSet->setDebugLoc(Range.TheStores[0]->getDebugLoc());
+
+ // Zap all the stores.
+ for (Instruction *SI : Range.TheStores) {
+ MD->removeInstruction(SI);
+ SI->eraseFromParent();
+ }
+ ++NumMemSetInfer;
+ }
+
+ return AMemSet;
+}
+
+static unsigned findCommonAlignment(const DataLayout &DL, const StoreInst *SI,
+ const LoadInst *LI) {
+ unsigned StoreAlign = SI->getAlignment();
+ if (!StoreAlign)
+ StoreAlign = DL.getABITypeAlignment(SI->getOperand(0)->getType());
+ unsigned LoadAlign = LI->getAlignment();
+ if (!LoadAlign)
+ LoadAlign = DL.getABITypeAlignment(LI->getType());
+
+ return std::min(StoreAlign, LoadAlign);
+}
+
+// This method tries to lift a store instruction to just before position P.
+// It lifts the store and its arguments, plus anything else that may alias
+// with them.
+// The method returns true if it was successful.
+static bool moveUp(AliasAnalysis &AA, StoreInst *SI, Instruction *P,
+ const LoadInst *LI) {
+ // If the store aliases this position, bail out early.
+ MemoryLocation StoreLoc = MemoryLocation::get(SI);
+ if (isModOrRefSet(AA.getModRefInfo(P, StoreLoc)))
+ return false;
+
+ // Keep track of the arguments of all instructions we plan to lift
+ // so we can make sure to lift them as well if appropriate.
+ DenseSet<Instruction*> Args;
+ if (auto *Ptr = dyn_cast<Instruction>(SI->getPointerOperand()))
+ if (Ptr->getParent() == SI->getParent())
+ Args.insert(Ptr);
+
+ // Instructions to lift before P.
+ SmallVector<Instruction*, 8> ToLift;
+
+ // Memory locations of lifted instructions.
+ SmallVector<MemoryLocation, 8> MemLocs{StoreLoc};
+
+ // Lifted callsites.
+ SmallVector<ImmutableCallSite, 8> CallSites;
+
+ const MemoryLocation LoadLoc = MemoryLocation::get(LI);
+
+ for (auto I = --SI->getIterator(), E = P->getIterator(); I != E; --I) {
+ auto *C = &*I;
+
+ bool MayAlias = isModOrRefSet(AA.getModRefInfo(C, None));
+
+ bool NeedLift = false;
+ if (Args.erase(C))
+ NeedLift = true;
+ else if (MayAlias) {
+ NeedLift = llvm::any_of(MemLocs, [C, &AA](const MemoryLocation &ML) {
+ return isModOrRefSet(AA.getModRefInfo(C, ML));
+ });
+
+ if (!NeedLift)
+ NeedLift =
+ llvm::any_of(CallSites, [C, &AA](const ImmutableCallSite &CS) {
+ return isModOrRefSet(AA.getModRefInfo(C, CS));
+ });
+ }
+
+ if (!NeedLift)
+ continue;
+
+ if (MayAlias) {
+ // Since LI is implicitly moved downwards past the lifted instructions,
+ // none of them may modify its source.
+ if (isModSet(AA.getModRefInfo(C, LoadLoc)))
+ return false;
+ else if (auto CS = ImmutableCallSite(C)) {
+ // If we can't lift this before P, it's game over.
+ if (isModOrRefSet(AA.getModRefInfo(P, CS)))
+ return false;
+
+ CallSites.push_back(CS);
+ } else if (isa<LoadInst>(C) || isa<StoreInst>(C) || isa<VAArgInst>(C)) {
+ // If we can't lift this before P, it's game over.
+ auto ML = MemoryLocation::get(C);
+ if (isModOrRefSet(AA.getModRefInfo(P, ML)))
+ return false;
+
+ MemLocs.push_back(ML);
+ } else
+ // We don't know how to lift this instruction.
+ return false;
+ }
+
+ ToLift.push_back(C);
+ for (unsigned k = 0, e = C->getNumOperands(); k != e; ++k)
+ if (auto *A = dyn_cast<Instruction>(C->getOperand(k)))
+ if (A->getParent() == SI->getParent())
+ Args.insert(A);
+ }
+
+ // We made it; now lift everything we collected.
+ for (auto *I : llvm::reverse(ToLift)) {
+ DEBUG(dbgs() << "Lifting " << *I << " before " << *P << "\n");
+ I->moveBefore(P);
+ }
+
+ return true;
+}
+
+bool MemCpyOptPass::processStore(StoreInst *SI, BasicBlock::iterator &BBI) {
+ if (!SI->isSimple()) return false;
+
+ // Avoid merging nontemporal stores since the resulting
+ // memcpy/memset would not be able to preserve the nontemporal hint.
+ // In theory we could propagate the !nontemporal metadata to the generated
+ // memset calls. However, that change would force the backend to
+ // conservatively expand !nontemporal memset calls back to sequences of
+ // store instructions (effectively undoing the merging).
+ if (SI->getMetadata(LLVMContext::MD_nontemporal))
+ return false;
+
+ const DataLayout &DL = SI->getModule()->getDataLayout();
+
+ // Load to store forwarding can be interpreted as memcpy.
+ if (LoadInst *LI = dyn_cast<LoadInst>(SI->getOperand(0))) {
+ if (LI->isSimple() && LI->hasOneUse() &&
+ LI->getParent() == SI->getParent()) {
+
+ auto *T = LI->getType();
+ if (T->isAggregateType()) {
+ AliasAnalysis &AA = LookupAliasAnalysis();
+ MemoryLocation LoadLoc = MemoryLocation::get(LI);
+
+ // We use alias analysis to check if an instruction may store to
+ // the memory we load from in between the load and the store. If
+ // such an instruction is found, we try to promote there instead
+ // of at the store position.
+ Instruction *P = SI;
+ for (auto &I : make_range(++LI->getIterator(), SI->getIterator())) {
+ if (isModSet(AA.getModRefInfo(&I, LoadLoc))) {
+ P = &I;
+ break;
+ }
+ }
+
+ // We found an instruction that may write to the loaded memory.
+ // We can try to promote at this position instead of the store
+ // position if nothing aliases the store memory after this and the
+ // store destination is not in the range.
+ if (P && P != SI) {
+ if (!moveUp(AA, SI, P, LI))
+ P = nullptr;
+ }
+
+ // If a valid insertion position is found, then we can promote
+ // the load/store pair to a memcpy.
+ if (P) {
+ // If we load from memory that may alias the memory we store to,
+ // memmove must be used to preserve the semantics. If not, memcpy
+ // can be used.
+ bool UseMemMove = false;
+ if (!AA.isNoAlias(MemoryLocation::get(SI), LoadLoc))
+ UseMemMove = true;
+
+ unsigned Align = findCommonAlignment(DL, SI, LI);
+ uint64_t Size = DL.getTypeStoreSize(T);
+
+ IRBuilder<> Builder(P);
+ Instruction *M;
+ if (UseMemMove)
+ M = Builder.CreateMemMove(SI->getPointerOperand(),
+ LI->getPointerOperand(), Size,
+ Align, SI->isVolatile());
+ else
+ M = Builder.CreateMemCpy(SI->getPointerOperand(),
+ LI->getPointerOperand(), Size,
+ Align, SI->isVolatile());
+
+ DEBUG(dbgs() << "Promoting " << *LI << " to " << *SI
+ << " => " << *M << "\n");
+
+ MD->removeInstruction(SI);
+ SI->eraseFromParent();
+ MD->removeInstruction(LI);
+ LI->eraseFromParent();
+ ++NumMemCpyInstr;
+
+ // Make sure we do not invalidate the iterator.
+ BBI = M->getIterator();
+ return true;
+ }
+ }
+
+ // Detect cases where we're performing call slot forwarding, but
+ // happen to be using a load-store pair to implement it, rather than
+ // a memcpy.
+ MemDepResult ldep = MD->getDependency(LI);
+ CallInst *C = nullptr;
+ if (ldep.isClobber() && !isa<MemCpyInst>(ldep.getInst()))
+ C = dyn_cast<CallInst>(ldep.getInst());
+
+ if (C) {
+ // Check that nothing touches the dest of the "copy" between
+ // the call and the store.
+ Value *CpyDest = SI->getPointerOperand()->stripPointerCasts();
+ bool CpyDestIsLocal = isa<AllocaInst>(CpyDest);
+ AliasAnalysis &AA = LookupAliasAnalysis();
+ MemoryLocation StoreLoc = MemoryLocation::get(SI);
+ for (BasicBlock::iterator I = --SI->getIterator(), E = C->getIterator();
+ I != E; --I) {
+ if (isModOrRefSet(AA.getModRefInfo(&*I, StoreLoc))) {
+ C = nullptr;
+ break;
+ }
+ // The store to dest may never happen if an exception can be thrown
+ // between the load and the store.
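+ // E.g. (illustrative):
+ //   %v = load %src
+ //   call void @g()      ; may throw: the store below would never execute
+ //   store %v, %dest
+ // so dest must not be observably modified; a local alloca is fine since
+ // its contents cannot be observed once the function unwinds.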
+ if (I->mayThrow() && !CpyDestIsLocal) { + C = nullptr; + break; + } + } + } + + if (C) { + bool changed = performCallSlotOptzn( + LI, SI->getPointerOperand()->stripPointerCasts(), + LI->getPointerOperand()->stripPointerCasts(), + DL.getTypeStoreSize(SI->getOperand(0)->getType()), + findCommonAlignment(DL, SI, LI), C); + if (changed) { + MD->removeInstruction(SI); + SI->eraseFromParent(); + MD->removeInstruction(LI); + LI->eraseFromParent(); + ++NumMemCpyInstr; + return true; + } + } + } + } + + // There are two cases that are interesting for this code to handle: memcpy + // and memset. Right now we only handle memset. + + // Ensure that the value being stored is something that can be memset'able a + // byte at a time like "0" or "-1" or any width, as well as things like + // 0xA0A0A0A0 and 0.0. + auto *V = SI->getOperand(0); + if (Value *ByteVal = isBytewiseValue(V)) { + if (Instruction *I = tryMergingIntoMemset(SI, SI->getPointerOperand(), + ByteVal)) { + BBI = I->getIterator(); // Don't invalidate iterator. + return true; + } + + // If we have an aggregate, we try to promote it to memset regardless + // of opportunity for merging as it can expose optimization opportunities + // in subsequent passes. + auto *T = V->getType(); + if (T->isAggregateType()) { + uint64_t Size = DL.getTypeStoreSize(T); + unsigned Align = SI->getAlignment(); + if (!Align) + Align = DL.getABITypeAlignment(T); + IRBuilder<> Builder(SI); + auto *M = Builder.CreateMemSet(SI->getPointerOperand(), ByteVal, + Size, Align, SI->isVolatile()); + + DEBUG(dbgs() << "Promoting " << *SI << " to " << *M << "\n"); + + MD->removeInstruction(SI); + SI->eraseFromParent(); + NumMemSetInfer++; + + // Make sure we do not invalidate the iterator. + BBI = M->getIterator(); + return true; + } + } + + return false; +} + +bool MemCpyOptPass::processMemSet(MemSetInst *MSI, BasicBlock::iterator &BBI) { + // See if there is another memset or store neighboring this memset which + // allows us to widen out the memset to do a single larger store. + if (isa<ConstantInt>(MSI->getLength()) && !MSI->isVolatile()) + if (Instruction *I = tryMergingIntoMemset(MSI, MSI->getDest(), + MSI->getValue())) { + BBI = I->getIterator(); // Don't invalidate iterator. + return true; + } + return false; +} + +/// Takes a memcpy and a call that it depends on, +/// and checks for the possibility of a call slot optimization by having +/// the call write its result directly into the destination of the memcpy. +bool MemCpyOptPass::performCallSlotOptzn(Instruction *cpy, Value *cpyDest, + Value *cpySrc, uint64_t cpyLen, + unsigned cpyAlign, CallInst *C) { + // The general transformation to keep in mind is + // + // call @func(..., src, ...) + // memcpy(dest, src, ...) + // + // -> + // + // memcpy(dest, src, ...) + // call @func(..., dest, ...) + // + // Since moving the memcpy is technically awkward, we additionally check that + // src only holds uninitialized values at the moment of the call, meaning that + // the memcpy can be discarded rather than moved. + + // Lifetime marks shouldn't be operated on. + if (Function *F = C->getCalledFunction()) + if (F->isIntrinsic() && F->getIntrinsicID() == Intrinsic::lifetime_start) + return false; + + // Deliberately get the source and destination with bitcasts stripped away, + // because we'll need to do type comparisons based on the underlying type. + CallSite CS(C); + + // Require that src be an alloca. This simplifies the reasoning considerably. 
+ AllocaInst *srcAlloca = dyn_cast<AllocaInst>(cpySrc); + if (!srcAlloca) + return false; + + ConstantInt *srcArraySize = dyn_cast<ConstantInt>(srcAlloca->getArraySize()); + if (!srcArraySize) + return false; + + const DataLayout &DL = cpy->getModule()->getDataLayout(); + uint64_t srcSize = DL.getTypeAllocSize(srcAlloca->getAllocatedType()) * + srcArraySize->getZExtValue(); + + if (cpyLen < srcSize) + return false; + + // Check that accessing the first srcSize bytes of dest will not cause a + // trap. Otherwise the transform is invalid since it might cause a trap + // to occur earlier than it otherwise would. + if (AllocaInst *A = dyn_cast<AllocaInst>(cpyDest)) { + // The destination is an alloca. Check it is larger than srcSize. + ConstantInt *destArraySize = dyn_cast<ConstantInt>(A->getArraySize()); + if (!destArraySize) + return false; + + uint64_t destSize = DL.getTypeAllocSize(A->getAllocatedType()) * + destArraySize->getZExtValue(); + + if (destSize < srcSize) + return false; + } else if (Argument *A = dyn_cast<Argument>(cpyDest)) { + // The store to dest may never happen if the call can throw. + if (C->mayThrow()) + return false; + + if (A->getDereferenceableBytes() < srcSize) { + // If the destination is an sret parameter then only accesses that are + // outside of the returned struct type can trap. + if (!A->hasStructRetAttr()) + return false; + + Type *StructTy = cast<PointerType>(A->getType())->getElementType(); + if (!StructTy->isSized()) { + // The call may never return and hence the copy-instruction may never + // be executed, and therefore it's not safe to say "the destination + // has at least <cpyLen> bytes, as implied by the copy-instruction", + return false; + } + + uint64_t destSize = DL.getTypeAllocSize(StructTy); + if (destSize < srcSize) + return false; + } + } else { + return false; + } + + // Check that dest points to memory that is at least as aligned as src. + unsigned srcAlign = srcAlloca->getAlignment(); + if (!srcAlign) + srcAlign = DL.getABITypeAlignment(srcAlloca->getAllocatedType()); + bool isDestSufficientlyAligned = srcAlign <= cpyAlign; + // If dest is not aligned enough and we can't increase its alignment then + // bail out. + if (!isDestSufficientlyAligned && !isa<AllocaInst>(cpyDest)) + return false; + + // Check that src is not accessed except via the call and the memcpy. This + // guarantees that it holds only undefined values when passed in (so the final + // memcpy can be dropped), that it is not read or written between the call and + // the memcpy, and that writing beyond the end of it is undefined. + SmallVector<User*, 8> srcUseList(srcAlloca->user_begin(), + srcAlloca->user_end()); + while (!srcUseList.empty()) { + User *U = srcUseList.pop_back_val(); + + if (isa<BitCastInst>(U) || isa<AddrSpaceCastInst>(U)) { + for (User *UU : U->users()) + srcUseList.push_back(UU); + continue; + } + if (GetElementPtrInst *G = dyn_cast<GetElementPtrInst>(U)) { + if (!G->hasAllZeroIndices()) + return false; + + for (User *UU : U->users()) + srcUseList.push_back(UU); + continue; + } + if (const IntrinsicInst *IT = dyn_cast<IntrinsicInst>(U)) + if (IT->getIntrinsicID() == Intrinsic::lifetime_start || + IT->getIntrinsicID() == Intrinsic::lifetime_end) + continue; + + if (U != C && U != cpy) + return false; + } + + // Check that src isn't captured by the called function since the + // transformation can cause aliasing issues in that case. 
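+ // E.g. (illustrative) if the callee stashes its src argument in a global,
+ // loads through that global after the rewrite would observe dest rather
+ // than the original src.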
+ for (unsigned i = 0, e = CS.arg_size(); i != e; ++i)
+ if (CS.getArgument(i) == cpySrc && !CS.doesNotCapture(i))
+ return false;
+
+ // Since we're changing the parameter to the callsite, we need to make sure
+ // that what would be the new parameter dominates the callsite.
+ DominatorTree &DT = LookupDomTree();
+ if (Instruction *cpyDestInst = dyn_cast<Instruction>(cpyDest))
+ if (!DT.dominates(cpyDestInst, C))
+ return false;
+
+ // In addition to knowing that the call does not access src in some
+ // unexpected manner, for example via a global, which we deduce from
+ // the use analysis, we also need to know that it does not sneakily
+ // access dest. We rely on AA to figure this out for us.
+ AliasAnalysis &AA = LookupAliasAnalysis();
+ ModRefInfo MR = AA.getModRefInfo(C, cpyDest, srcSize);
+ // If necessary, perform additional analysis.
+ if (isModOrRefSet(MR))
+ MR = AA.callCapturesBefore(C, cpyDest, srcSize, &DT);
+ if (isModOrRefSet(MR))
+ return false;
+
+ // We can't create address space casts here because we don't know if they're
+ // safe for the target.
+ if (cpySrc->getType()->getPointerAddressSpace() !=
+ cpyDest->getType()->getPointerAddressSpace())
+ return false;
+ for (unsigned i = 0; i < CS.arg_size(); ++i)
+ if (CS.getArgument(i)->stripPointerCasts() == cpySrc &&
+ cpySrc->getType()->getPointerAddressSpace() !=
+ CS.getArgument(i)->getType()->getPointerAddressSpace())
+ return false;
+
+ // All the checks have passed, so do the transformation.
+ bool changedArgument = false;
+ for (unsigned i = 0; i < CS.arg_size(); ++i)
+ if (CS.getArgument(i)->stripPointerCasts() == cpySrc) {
+ Value *Dest = cpySrc->getType() == cpyDest->getType() ? cpyDest
+ : CastInst::CreatePointerCast(cpyDest, cpySrc->getType(),
+ cpyDest->getName(), C);
+ changedArgument = true;
+ if (CS.getArgument(i)->getType() == Dest->getType())
+ CS.setArgument(i, Dest);
+ else
+ CS.setArgument(i, CastInst::CreatePointerCast(Dest,
+ CS.getArgument(i)->getType(), Dest->getName(), C));
+ }
+
+ if (!changedArgument)
+ return false;
+
+ // If the destination wasn't sufficiently aligned then increase its alignment.
+ if (!isDestSufficientlyAligned) {
+ assert(isa<AllocaInst>(cpyDest) && "Can only increase alloca alignment!");
+ cast<AllocaInst>(cpyDest)->setAlignment(srcAlign);
+ }
+
+ // Drop any cached information about the call, because we may have changed
+ // its dependence information by changing its parameter.
+ MD->removeInstruction(C);
+
+ // Update AA metadata.
+ // FIXME: MD_tbaa_struct and MD_mem_parallel_loop_access should also be
+ // handled here, but combineMetadata doesn't support them yet.
+ unsigned KnownIDs[] = {LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope,
+ LLVMContext::MD_noalias,
+ LLVMContext::MD_invariant_group};
+ combineMetadata(C, cpy, KnownIDs);
+
+ // Remove the memcpy.
+ MD->removeInstruction(cpy);
+ ++NumMemCpyInstr;
+
+ return true;
+}
+
+/// We've found that the (upward scanning) memory dependence of memcpy 'M' is
+/// the memcpy 'MDep'. Try to simplify M to copy from MDep's input if we can.
+bool MemCpyOptPass::processMemCpyMemCpyDependence(MemCpyInst *M,
+ MemCpyInst *MDep) {
+ // We can only transform memcpys where the dest of one is the source of the
+ // other.
+ if (M->getSource() != MDep->getDest() || MDep->isVolatile())
+ return false;
+
+ // If the dep instruction is reading from our current input, then it is a noop
+ // transfer and substituting the input won't change this instruction. Just
+ // ignore the input and let someone else zap MDep.
This handles cases like: + // memcpy(a <- a) + // memcpy(b <- a) + if (M->getSource() == MDep->getSource()) + return false; + + // Second, the length of the memcpy's must be the same, or the preceding one + // must be larger than the following one. + ConstantInt *MDepLen = dyn_cast<ConstantInt>(MDep->getLength()); + ConstantInt *MLen = dyn_cast<ConstantInt>(M->getLength()); + if (!MDepLen || !MLen || MDepLen->getZExtValue() < MLen->getZExtValue()) + return false; + + AliasAnalysis &AA = LookupAliasAnalysis(); + + // Verify that the copied-from memory doesn't change in between the two + // transfers. For example, in: + // memcpy(a <- b) + // *b = 42; + // memcpy(c <- a) + // It would be invalid to transform the second memcpy into memcpy(c <- b). + // + // TODO: If the code between M and MDep is transparent to the destination "c", + // then we could still perform the xform by moving M up to the first memcpy. + // + // NOTE: This is conservative, it will stop on any read from the source loc, + // not just the defining memcpy. + MemDepResult SourceDep = + MD->getPointerDependencyFrom(MemoryLocation::getForSource(MDep), false, + M->getIterator(), M->getParent()); + if (!SourceDep.isClobber() || SourceDep.getInst() != MDep) + return false; + + // If the dest of the second might alias the source of the first, then the + // source and dest might overlap. We still want to eliminate the intermediate + // value, but we have to generate a memmove instead of memcpy. + bool UseMemMove = false; + if (!AA.isNoAlias(MemoryLocation::getForDest(M), + MemoryLocation::getForSource(MDep))) + UseMemMove = true; + + // If all checks passed, then we can transform M. + + // Make sure to use the lesser of the alignment of the source and the dest + // since we're changing where we're reading from, but don't want to increase + // the alignment past what can be read from or written to. + // TODO: Is this worth it if we're creating a less aligned memcpy? For + // example we could be moving from movaps -> movq on x86. + unsigned Align = std::min(MDep->getAlignment(), M->getAlignment()); + + IRBuilder<> Builder(M); + if (UseMemMove) + Builder.CreateMemMove(M->getRawDest(), MDep->getRawSource(), M->getLength(), + Align, M->isVolatile()); + else + Builder.CreateMemCpy(M->getRawDest(), MDep->getRawSource(), M->getLength(), + Align, M->isVolatile()); + + // Remove the instruction we're replacing. + MD->removeInstruction(M); + M->eraseFromParent(); + ++NumMemCpyInstr; + return true; +} + +/// We've found that the (upward scanning) memory dependence of \p MemCpy is +/// \p MemSet. Try to simplify \p MemSet to only set the trailing bytes that +/// weren't copied over by \p MemCpy. +/// +/// In other words, transform: +/// \code +/// memset(dst, c, dst_size); +/// memcpy(dst, src, src_size); +/// \endcode +/// into: +/// \code +/// memcpy(dst, src, src_size); +/// memset(dst + src_size, c, dst_size <= src_size ? 0 : dst_size - src_size); +/// \endcode +bool MemCpyOptPass::processMemSetMemCpyDependence(MemCpyInst *MemCpy, + MemSetInst *MemSet) { + // We can only transform memset/memcpy with the same destination. + if (MemSet->getDest() != MemCpy->getDest()) + return false; + + // Check that there are no other dependencies on the memset destination. + MemDepResult DstDepInfo = + MD->getPointerDependencyFrom(MemoryLocation::getForDest(MemSet), false, + MemCpy->getIterator(), MemCpy->getParent()); + if (DstDepInfo.getInst() != MemSet) + return false; + + // Use the same i8* dest as the memcpy, killing the memset dest if different. 
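+ // Worked example (illustrative): with dst_size = 16 and src_size = 8 the
+ // rewrite produces
+ //   memcpy(dst, src, 8);
+ //   memset(dst + 8, c, 8);
+ // and when dst_size <= src_size the select below folds the memset length
+ // to 0.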
+ Value *Dest = MemCpy->getRawDest(); + Value *DestSize = MemSet->getLength(); + Value *SrcSize = MemCpy->getLength(); + + // By default, create an unaligned memset. + unsigned Align = 1; + // If Dest is aligned, and SrcSize is constant, use the minimum alignment + // of the sum. + const unsigned DestAlign = + std::max(MemSet->getAlignment(), MemCpy->getAlignment()); + if (DestAlign > 1) + if (ConstantInt *SrcSizeC = dyn_cast<ConstantInt>(SrcSize)) + Align = MinAlign(SrcSizeC->getZExtValue(), DestAlign); + + IRBuilder<> Builder(MemCpy); + + // If the sizes have different types, zext the smaller one. + if (DestSize->getType() != SrcSize->getType()) { + if (DestSize->getType()->getIntegerBitWidth() > + SrcSize->getType()->getIntegerBitWidth()) + SrcSize = Builder.CreateZExt(SrcSize, DestSize->getType()); + else + DestSize = Builder.CreateZExt(DestSize, SrcSize->getType()); + } + + Value *Ule = Builder.CreateICmpULE(DestSize, SrcSize); + Value *SizeDiff = Builder.CreateSub(DestSize, SrcSize); + Value *MemsetLen = Builder.CreateSelect( + Ule, ConstantInt::getNullValue(DestSize->getType()), SizeDiff); + Builder.CreateMemSet(Builder.CreateGEP(Dest, SrcSize), MemSet->getOperand(1), + MemsetLen, Align); + + MD->removeInstruction(MemSet); + MemSet->eraseFromParent(); + return true; +} + +/// Transform memcpy to memset when its source was just memset. +/// In other words, turn: +/// \code +/// memset(dst1, c, dst1_size); +/// memcpy(dst2, dst1, dst2_size); +/// \endcode +/// into: +/// \code +/// memset(dst1, c, dst1_size); +/// memset(dst2, c, dst2_size); +/// \endcode +/// When dst2_size <= dst1_size. +/// +/// The \p MemCpy must have a Constant length. +bool MemCpyOptPass::performMemCpyToMemSetOptzn(MemCpyInst *MemCpy, + MemSetInst *MemSet) { + AliasAnalysis &AA = LookupAliasAnalysis(); + + // Make sure that memcpy(..., memset(...), ...), that is we are memsetting and + // memcpying from the same address. Otherwise it is hard to reason about. + if (!AA.isMustAlias(MemSet->getRawDest(), MemCpy->getRawSource())) + return false; + + ConstantInt *CopySize = cast<ConstantInt>(MemCpy->getLength()); + ConstantInt *MemSetSize = dyn_cast<ConstantInt>(MemSet->getLength()); + // Make sure the memcpy doesn't read any more than what the memset wrote. + // Don't worry about sizes larger than i64. + if (!MemSetSize || CopySize->getZExtValue() > MemSetSize->getZExtValue()) + return false; + + IRBuilder<> Builder(MemCpy); + Builder.CreateMemSet(MemCpy->getRawDest(), MemSet->getOperand(1), + CopySize, MemCpy->getAlignment()); + return true; +} + +/// Perform simplification of memcpy's. If we have memcpy A +/// which copies X to Y, and memcpy B which copies Y to Z, then we can rewrite +/// B to be a memcpy from X to Z (or potentially a memmove, depending on +/// circumstances). This allows later passes to remove the first memcpy +/// altogether. +bool MemCpyOptPass::processMemCpy(MemCpyInst *M) { + // We can only optimize non-volatile memcpy's. + if (M->isVolatile()) return false; + + // If the source and destination of the memcpy are the same, then zap it. + if (M->getSource() == M->getDest()) { + MD->removeInstruction(M); + M->eraseFromParent(); + return false; + } + + // If copying from a constant, try to turn the memcpy into a memset. 
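+ // E.g. (illustrative):
+ //   @G = constant [8 x i8] zeroinitializer
+ //   memcpy(%dst, @G, 8)  -->  memset(%dst, 0, 8)
+ // This relies on every byte of the initializer being identical.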
+ if (GlobalVariable *GV = dyn_cast<GlobalVariable>(M->getSource())) + if (GV->isConstant() && GV->hasDefinitiveInitializer()) + if (Value *ByteVal = isBytewiseValue(GV->getInitializer())) { + IRBuilder<> Builder(M); + Builder.CreateMemSet(M->getRawDest(), ByteVal, M->getLength(), + M->getAlignment(), false); + MD->removeInstruction(M); + M->eraseFromParent(); + ++NumCpyToSet; + return true; + } + + MemDepResult DepInfo = MD->getDependency(M); + + // Try to turn a partially redundant memset + memcpy into + // memcpy + smaller memset. We don't need the memcpy size for this. + if (DepInfo.isClobber()) + if (MemSetInst *MDep = dyn_cast<MemSetInst>(DepInfo.getInst())) + if (processMemSetMemCpyDependence(M, MDep)) + return true; + + // The optimizations after this point require the memcpy size. + ConstantInt *CopySize = dyn_cast<ConstantInt>(M->getLength()); + if (!CopySize) return false; + + // There are four possible optimizations we can do for memcpy: + // a) memcpy-memcpy xform which exposes redundance for DSE. + // b) call-memcpy xform for return slot optimization. + // c) memcpy from freshly alloca'd space or space that has just started its + // lifetime copies undefined data, and we can therefore eliminate the + // memcpy in favor of the data that was already at the destination. + // d) memcpy from a just-memset'd source can be turned into memset. + if (DepInfo.isClobber()) { + if (CallInst *C = dyn_cast<CallInst>(DepInfo.getInst())) { + if (performCallSlotOptzn(M, M->getDest(), M->getSource(), + CopySize->getZExtValue(), M->getAlignment(), + C)) { + MD->removeInstruction(M); + M->eraseFromParent(); + return true; + } + } + } + + MemoryLocation SrcLoc = MemoryLocation::getForSource(M); + MemDepResult SrcDepInfo = MD->getPointerDependencyFrom( + SrcLoc, true, M->getIterator(), M->getParent()); + + if (SrcDepInfo.isClobber()) { + if (MemCpyInst *MDep = dyn_cast<MemCpyInst>(SrcDepInfo.getInst())) + return processMemCpyMemCpyDependence(M, MDep); + } else if (SrcDepInfo.isDef()) { + Instruction *I = SrcDepInfo.getInst(); + bool hasUndefContents = false; + + if (isa<AllocaInst>(I)) { + hasUndefContents = true; + } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) { + if (II->getIntrinsicID() == Intrinsic::lifetime_start) + if (ConstantInt *LTSize = dyn_cast<ConstantInt>(II->getArgOperand(0))) + if (LTSize->getZExtValue() >= CopySize->getZExtValue()) + hasUndefContents = true; + } + + if (hasUndefContents) { + MD->removeInstruction(M); + M->eraseFromParent(); + ++NumMemCpyInstr; + return true; + } + } + + if (SrcDepInfo.isClobber()) + if (MemSetInst *MDep = dyn_cast<MemSetInst>(SrcDepInfo.getInst())) + if (performMemCpyToMemSetOptzn(M, MDep)) { + MD->removeInstruction(M); + M->eraseFromParent(); + ++NumCpyToSet; + return true; + } + + return false; +} + +/// Transforms memmove calls to memcpy calls when the src/dst are guaranteed +/// not to alias. +bool MemCpyOptPass::processMemMove(MemMoveInst *M) { + AliasAnalysis &AA = LookupAliasAnalysis(); + + if (!TLI->has(LibFunc_memmove)) + return false; + + // See if the pointers alias. + if (!AA.isNoAlias(MemoryLocation::getForDest(M), + MemoryLocation::getForSource(M))) + return false; + + DEBUG(dbgs() << "MemCpyOptPass: Optimizing memmove -> memcpy: " << *M + << "\n"); + + // If not, then we know we can transform this. 
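+ // Rewrite the callee in place; memmove and memcpy take the same operands,
+ // so only the intrinsic declaration needs to change.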
+ Type *ArgTys[3] = { M->getRawDest()->getType(), + M->getRawSource()->getType(), + M->getLength()->getType() }; + M->setCalledFunction(Intrinsic::getDeclaration(M->getModule(), + Intrinsic::memcpy, ArgTys)); + + // MemDep may have over conservative information about this instruction, just + // conservatively flush it from the cache. + MD->removeInstruction(M); + + ++NumMoveToCpy; + return true; +} + +/// This is called on every byval argument in call sites. +bool MemCpyOptPass::processByValArgument(CallSite CS, unsigned ArgNo) { + const DataLayout &DL = CS.getCaller()->getParent()->getDataLayout(); + // Find out what feeds this byval argument. + Value *ByValArg = CS.getArgument(ArgNo); + Type *ByValTy = cast<PointerType>(ByValArg->getType())->getElementType(); + uint64_t ByValSize = DL.getTypeAllocSize(ByValTy); + MemDepResult DepInfo = MD->getPointerDependencyFrom( + MemoryLocation(ByValArg, ByValSize), true, + CS.getInstruction()->getIterator(), CS.getInstruction()->getParent()); + if (!DepInfo.isClobber()) + return false; + + // If the byval argument isn't fed by a memcpy, ignore it. If it is fed by + // a memcpy, see if we can byval from the source of the memcpy instead of the + // result. + MemCpyInst *MDep = dyn_cast<MemCpyInst>(DepInfo.getInst()); + if (!MDep || MDep->isVolatile() || + ByValArg->stripPointerCasts() != MDep->getDest()) + return false; + + // The length of the memcpy must be larger or equal to the size of the byval. + ConstantInt *C1 = dyn_cast<ConstantInt>(MDep->getLength()); + if (!C1 || C1->getValue().getZExtValue() < ByValSize) + return false; + + // Get the alignment of the byval. If the call doesn't specify the alignment, + // then it is some target specific value that we can't know. + unsigned ByValAlign = CS.getParamAlignment(ArgNo); + if (ByValAlign == 0) return false; + + // If it is greater than the memcpy, then we check to see if we can force the + // source of the memcpy to the alignment we need. If we fail, we bail out. + AssumptionCache &AC = LookupAssumptionCache(); + DominatorTree &DT = LookupDomTree(); + if (MDep->getAlignment() < ByValAlign && + getOrEnforceKnownAlignment(MDep->getSource(), ByValAlign, DL, + CS.getInstruction(), &AC, &DT) < ByValAlign) + return false; + + // The address space of the memcpy source must match the byval argument + if (MDep->getSource()->getType()->getPointerAddressSpace() != + ByValArg->getType()->getPointerAddressSpace()) + return false; + + // Verify that the copied-from memory doesn't change in between the memcpy and + // the byval call. + // memcpy(a <- b) + // *b = 42; + // foo(*a) + // It would be invalid to transform the second memcpy into foo(*b). + // + // NOTE: This is conservative, it will stop on any read from the source loc, + // not just the defining memcpy. + MemDepResult SourceDep = MD->getPointerDependencyFrom( + MemoryLocation::getForSource(MDep), false, + CS.getInstruction()->getIterator(), MDep->getParent()); + if (!SourceDep.isClobber() || SourceDep.getInst() != MDep) + return false; + + Value *TmpCast = MDep->getSource(); + if (MDep->getSource()->getType() != ByValArg->getType()) + TmpCast = new BitCastInst(MDep->getSource(), ByValArg->getType(), + "tmpcast", CS.getInstruction()); + + DEBUG(dbgs() << "MemCpyOptPass: Forwarding memcpy to byval:\n" + << " " << *MDep << "\n" + << " " << *CS.getInstruction() << "\n"); + + // Otherwise we're good! Update the byval argument. + CS.setArgument(ArgNo, TmpCast); + ++NumMemCpyInstr; + return true; +} + +/// Executes one iteration of MemCpyOptPass. 
+bool MemCpyOptPass::iterateOnFunction(Function &F) { + bool MadeChange = false; + + // Walk all instruction in the function. + for (BasicBlock &BB : F) { + for (BasicBlock::iterator BI = BB.begin(), BE = BB.end(); BI != BE;) { + // Avoid invalidating the iterator. + Instruction *I = &*BI++; + + bool RepeatInstruction = false; + + if (StoreInst *SI = dyn_cast<StoreInst>(I)) + MadeChange |= processStore(SI, BI); + else if (MemSetInst *M = dyn_cast<MemSetInst>(I)) + RepeatInstruction = processMemSet(M, BI); + else if (MemCpyInst *M = dyn_cast<MemCpyInst>(I)) + RepeatInstruction = processMemCpy(M); + else if (MemMoveInst *M = dyn_cast<MemMoveInst>(I)) + RepeatInstruction = processMemMove(M); + else if (auto CS = CallSite(I)) { + for (unsigned i = 0, e = CS.arg_size(); i != e; ++i) + if (CS.isByValArgument(i)) + MadeChange |= processByValArgument(CS, i); + } + + // Reprocess the instruction if desired. + if (RepeatInstruction) { + if (BI != BB.begin()) + --BI; + MadeChange = true; + } + } + } + + return MadeChange; +} + +PreservedAnalyses MemCpyOptPass::run(Function &F, FunctionAnalysisManager &AM) { + auto &MD = AM.getResult<MemoryDependenceAnalysis>(F); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + + auto LookupAliasAnalysis = [&]() -> AliasAnalysis & { + return AM.getResult<AAManager>(F); + }; + auto LookupAssumptionCache = [&]() -> AssumptionCache & { + return AM.getResult<AssumptionAnalysis>(F); + }; + auto LookupDomTree = [&]() -> DominatorTree & { + return AM.getResult<DominatorTreeAnalysis>(F); + }; + + bool MadeChange = runImpl(F, &MD, &TLI, LookupAliasAnalysis, + LookupAssumptionCache, LookupDomTree); + if (!MadeChange) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + PA.preserve<GlobalsAA>(); + PA.preserve<MemoryDependenceAnalysis>(); + return PA; +} + +bool MemCpyOptPass::runImpl( + Function &F, MemoryDependenceResults *MD_, TargetLibraryInfo *TLI_, + std::function<AliasAnalysis &()> LookupAliasAnalysis_, + std::function<AssumptionCache &()> LookupAssumptionCache_, + std::function<DominatorTree &()> LookupDomTree_) { + bool MadeChange = false; + MD = MD_; + TLI = TLI_; + LookupAliasAnalysis = std::move(LookupAliasAnalysis_); + LookupAssumptionCache = std::move(LookupAssumptionCache_); + LookupDomTree = std::move(LookupDomTree_); + + // If we don't have at least memset and memcpy, there is little point of doing + // anything here. These are required by a freestanding implementation, so if + // even they are disabled, there is no point in trying hard. + if (!TLI->has(LibFunc_memset) || !TLI->has(LibFunc_memcpy)) + return false; + + while (true) { + if (!iterateOnFunction(F)) + break; + MadeChange = true; + } + + MD = nullptr; + return MadeChange; +} + +/// This is the main transformation entry point for a function. 
+bool MemCpyOptLegacyPass::runOnFunction(Function &F) {
+ if (skipFunction(F))
+ return false;
+
+ auto *MD = &getAnalysis<MemoryDependenceWrapperPass>().getMemDep();
+ auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+
+ auto LookupAliasAnalysis = [this]() -> AliasAnalysis & {
+ return getAnalysis<AAResultsWrapperPass>().getAAResults();
+ };
+ auto LookupAssumptionCache = [this, &F]() -> AssumptionCache & {
+ return getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+ };
+ auto LookupDomTree = [this]() -> DominatorTree & {
+ return getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+ };
+
+ return Impl.runImpl(F, MD, TLI, LookupAliasAnalysis, LookupAssumptionCache,
+ LookupDomTree);
+}
diff --git a/contrib/llvm/lib/Transforms/Scalar/MergeICmps.cpp b/contrib/llvm/lib/Transforms/Scalar/MergeICmps.cpp
new file mode 100644
index 000000000000..9869a3fb96fa
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/MergeICmps.cpp
@@ -0,0 +1,650 @@
+//===- MergeICmps.cpp - Optimize chains of integer comparisons ------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass turns chains of integer comparisons into memcmp (the memcmp is
+// later typically inlined as a chain of efficient hardware comparisons). This
+// typically benefits C++ member or nonmember operator==().
+//
+// The basic idea is to replace a larger chain of integer comparisons loaded
+// from contiguous memory locations with a smaller chain of such integer
+// comparisons. The benefits are twofold:
+// - There are fewer jumps, and therefore fewer opportunities for
+// mispredictions and I-cache misses.
+// - Code size is smaller, both because jumps are removed and because the
+// encoding of a 2*n byte compare is smaller than that of two n-byte
+// compares.

+//===----------------------------------------------------------------------===//
+
+#include <algorithm>
+#include <numeric>
+#include <utility>
+#include <vector>
+#include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BuildLibCalls.h"
+
+using namespace llvm;
+
+namespace {
+
+#define DEBUG_TYPE "mergeicmps"
+
+// A BCE atom.
+struct BCEAtom {
+ BCEAtom() : GEP(nullptr), LoadI(nullptr), Offset() {}
+
+ const Value *Base() const { return GEP ? GEP->getPointerOperand() : nullptr; }
+
+ bool operator<(const BCEAtom &O) const {
+ assert(Base() && "invalid atom");
+ assert(O.Base() && "invalid atom");
+ // Ordering by (Base(), Offset) alone would be sufficient, but then the
+ // ordering would depend on the addresses of the base values, which are
+ // not reproducible from run to run. To guarantee stability, we use the
+ // names of the values if they exist; we sort by:
+ // (Base.getName(), Base(), Offset).
+ const int NameCmp = Base()->getName().compare(O.Base()->getName());
+ if (NameCmp == 0) {
+ if (Base() == O.Base()) {
+ return Offset.slt(O.Offset);
+ }
+ return Base() < O.Base();
+ }
+ return NameCmp < 0;
+ }
+
+ GetElementPtrInst *GEP;
+ LoadInst *LoadI;
+ APInt Offset;
+};
+
+// If this value is a load from a constant offset w.r.t.
a base address, and
+// there are no other users of the load or address, returns the base address and
+// the offset.
+BCEAtom visitICmpLoadOperand(Value *const Val) {
+ BCEAtom Result;
+ if (auto *const LoadI = dyn_cast<LoadInst>(Val)) {
+ DEBUG(dbgs() << "load\n");
+ if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) {
+ DEBUG(dbgs() << "used outside of block\n");
+ return {};
+ }
+ if (LoadI->isVolatile()) {
+ DEBUG(dbgs() << "volatile\n");
+ return {};
+ }
+ Value *const Addr = LoadI->getOperand(0);
+ if (auto *const GEP = dyn_cast<GetElementPtrInst>(Addr)) {
+ DEBUG(dbgs() << "GEP\n");
+ if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) {
+ DEBUG(dbgs() << "used outside of block\n");
+ return {};
+ }
+ const auto &DL = GEP->getModule()->getDataLayout();
+ if (!isDereferenceablePointer(GEP, DL)) {
+ DEBUG(dbgs() << "not dereferenceable\n");
+ // We need to make sure that we can do the comparisons in any order, so
+ // we require memory to be unconditionally dereferenceable.
+ return {};
+ }
+ Result.Offset = APInt(DL.getPointerTypeSizeInBits(GEP->getType()), 0);
+ if (GEP->accumulateConstantOffset(DL, Result.Offset)) {
+ Result.GEP = GEP;
+ Result.LoadI = LoadI;
+ }
+ }
+ }
+ return Result;
+}
+
+// A basic block with a comparison between two BCE atoms.
+// Note: the terminology is misleading: the comparison is symmetric, so there
+// is no real {l/r}hs. What we want though is to have the same base on the
+// left (resp. right), so that we can detect consecutive loads. To ensure this
+// we put the smallest atom on the left.
+class BCECmpBlock {
+ public:
+ BCECmpBlock() {}
+
+ BCECmpBlock(BCEAtom L, BCEAtom R, int SizeBits)
+ : Lhs_(L), Rhs_(R), SizeBits_(SizeBits) {
+ if (Rhs_ < Lhs_) std::swap(Rhs_, Lhs_);
+ }
+
+ bool IsValid() const {
+ return Lhs_.Base() != nullptr && Rhs_.Base() != nullptr;
+ }
+
+ // Assert that the block is consistent: if valid, it should also have
+ // non-null members besides Lhs_ and Rhs_.
+ void AssertConsistent() const {
+ if (IsValid()) {
+ assert(BB);
+ assert(CmpI);
+ assert(BranchI);
+ }
+ }
+
+ const BCEAtom &Lhs() const { return Lhs_; }
+ const BCEAtom &Rhs() const { return Rhs_; }
+ int SizeBits() const { return SizeBits_; }
+
+ // Returns true if the block does other work besides the comparison.
+ bool doesOtherWork() const;
+
+ // The basic block where this comparison happens.
+ BasicBlock *BB = nullptr;
+ // The ICMP for this comparison.
+ ICmpInst *CmpI = nullptr;
+ // The terminating branch.
+ BranchInst *BranchI = nullptr;
+
+ private:
+ BCEAtom Lhs_;
+ BCEAtom Rhs_;
+ int SizeBits_ = 0;
+};
+
+bool BCECmpBlock::doesOtherWork() const {
+ AssertConsistent();
+ // TODO(courbet): Can we allow some other things? This is very conservative.
+ // We might be able to get away with anything that does not have side
+ // effects outside of the basic block.
+ // Note: The GEPs and/or loads are not necessarily in the same block.
+ for (const Instruction &Inst : *BB) {
+ if (const auto *const GEP = dyn_cast<GetElementPtrInst>(&Inst)) {
+ if (!(Lhs_.GEP == GEP || Rhs_.GEP == GEP)) return true;
+ } else if (const auto *const L = dyn_cast<LoadInst>(&Inst)) {
+ if (!(Lhs_.LoadI == L || Rhs_.LoadI == L)) return true;
+ } else if (const auto *const C = dyn_cast<ICmpInst>(&Inst)) {
+ if (C != CmpI) return true;
+ } else if (const auto *const Br = dyn_cast<BranchInst>(&Inst)) {
+ if (Br != BranchI) return true;
+ } else {
+ return true;
+ }
+ }
+ return false;
+}
+
+// Visit the given comparison.
+
+// Visit the given comparison. If this is a comparison between two valid
+// BCE atoms, returns the comparison.
+BCECmpBlock visitICmp(const ICmpInst *const CmpI,
+                      const ICmpInst::Predicate ExpectedPredicate) {
+  if (CmpI->getPredicate() == ExpectedPredicate) {
+    DEBUG(dbgs() << "cmp "
+                 << (ExpectedPredicate == ICmpInst::ICMP_EQ ? "eq" : "ne")
+                 << "\n");
+    auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0));
+    if (!Lhs.Base()) return {};
+    auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1));
+    if (!Rhs.Base()) return {};
+    return BCECmpBlock(std::move(Lhs), std::move(Rhs),
+                       CmpI->getOperand(0)->getType()->getScalarSizeInBits());
+  }
+  return {};
+}
+
+// Visit the given comparison block. If this is a comparison between two valid
+// BCE atoms, returns the comparison.
+BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block,
+                          const BasicBlock *const PhiBlock) {
+  if (Block->empty()) return {};
+  auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator());
+  if (!BranchI) return {};
+  DEBUG(dbgs() << "branch\n");
+  if (BranchI->isUnconditional()) {
+    // In this case, we expect an incoming value which is the result of the
+    // comparison. This is the last link in the chain of comparisons (note
+    // that this does not mean that this is the last incoming value, blocks
+    // can be reordered).
+    auto *const CmpI = dyn_cast<ICmpInst>(Val);
+    if (!CmpI) return {};
+    DEBUG(dbgs() << "icmp\n");
+    auto Result = visitICmp(CmpI, ICmpInst::ICMP_EQ);
+    Result.CmpI = CmpI;
+    Result.BranchI = BranchI;
+    return Result;
+  } else {
+    // In this case, we expect a constant incoming value (the comparison is
+    // chained).
+    const auto *const Const = dyn_cast<ConstantInt>(Val);
+    DEBUG(dbgs() << "const\n");
+    // Bail out if the incoming value is not a null constant.
+    if (!Const || !Const->isZero()) return {};
+    DEBUG(dbgs() << "false\n");
+    auto *const CmpI = dyn_cast<ICmpInst>(BranchI->getCondition());
+    if (!CmpI) return {};
+    DEBUG(dbgs() << "icmp\n");
+    assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch");
+    BasicBlock *const FalseBlock = BranchI->getSuccessor(1);
+    auto Result = visitICmp(
+        CmpI, FalseBlock == PhiBlock ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE);
+    Result.CmpI = CmpI;
+    Result.BranchI = BranchI;
+    return Result;
+  }
+  return {};
+}
+
+// A chain of comparisons.
+class BCECmpChain {
+ public:
+  BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi);
+
+  int size() const { return Comparisons_.size(); }
+
+#ifdef MERGEICMPS_DOT_ON
+  void dump() const;
+#endif  // MERGEICMPS_DOT_ON
+
+  bool simplify(const TargetLibraryInfo *const TLI);
+
+ private:
+  static bool IsContiguous(const BCECmpBlock &First,
+                           const BCECmpBlock &Second) {
+    return First.Lhs().Base() == Second.Lhs().Base() &&
+           First.Rhs().Base() == Second.Rhs().Base() &&
+           First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset &&
+           First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset;
+  }
+
+  // Merges the given comparison blocks into one memcmp block and updates the
+  // branches. Comparisons are assumed to be contiguous. If NextBBInChain is
+  // null, the merged block will link to the phi block.
+  static void mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
+                               BasicBlock *const NextBBInChain, PHINode &Phi,
+                               const TargetLibraryInfo *const TLI);
+
+  PHINode &Phi_;
+  std::vector<BCECmpBlock> Comparisons_;
+  // The original entry block (before sorting).
+  BasicBlock *EntryBlock_;
+};
+
+BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi)
+    : Phi_(Phi) {
+  // Now look inside blocks to check for BCE comparisons.
+ std::vector<BCECmpBlock> Comparisons; + for (BasicBlock *Block : Blocks) { + BCECmpBlock Comparison = visitCmpBlock(Phi.getIncomingValueForBlock(Block), + Block, Phi.getParent()); + Comparison.BB = Block; + if (!Comparison.IsValid()) { + DEBUG(dbgs() << "skip: not a valid BCECmpBlock\n"); + return; + } + if (Comparison.doesOtherWork()) { + DEBUG(dbgs() << "block does extra work besides compare\n"); + if (Comparisons.empty()) { // First block. + // TODO(courbet): The first block can do other things, and we should + // split them apart in a separate block before the comparison chain. + // Right now we just discard it and make the chain shorter. + DEBUG(dbgs() + << "ignoring first block that does extra work besides compare\n"); + continue; + } + // TODO(courbet): Right now we abort the whole chain. We could be + // merging only the blocks that don't do other work and resume the + // chain from there. For example: + // if (a[0] == b[0]) { // bb1 + // if (a[1] == b[1]) { // bb2 + // some_value = 3; //bb3 + // if (a[2] == b[2]) { //bb3 + // do a ton of stuff //bb4 + // } + // } + // } + // + // This is: + // + // bb1 --eq--> bb2 --eq--> bb3* -eq--> bb4 --+ + // \ \ \ \ + // ne ne ne \ + // \ \ \ v + // +------------+-----------+----------> bb_phi + // + // We can only merge the first two comparisons, because bb3* does + // "other work" (setting some_value to 3). + // We could still merge bb1 and bb2 though. + return; + } + DEBUG(dbgs() << "*Found cmp of " << Comparison.SizeBits() + << " bits between " << Comparison.Lhs().Base() << " + " + << Comparison.Lhs().Offset << " and " + << Comparison.Rhs().Base() << " + " << Comparison.Rhs().Offset + << "\n"); + DEBUG(dbgs() << "\n"); + Comparisons.push_back(Comparison); + } + EntryBlock_ = Comparisons[0].BB; + Comparisons_ = std::move(Comparisons); +#ifdef MERGEICMPS_DOT_ON + errs() << "BEFORE REORDERING:\n\n"; + dump(); +#endif // MERGEICMPS_DOT_ON + // Reorder blocks by LHS. We can do that without changing the + // semantics because we are only accessing dereferencable memory. + std::sort(Comparisons_.begin(), Comparisons_.end(), + [](const BCECmpBlock &a, const BCECmpBlock &b) { + return a.Lhs() < b.Lhs(); + }); +#ifdef MERGEICMPS_DOT_ON + errs() << "AFTER REORDERING:\n\n"; + dump(); +#endif // MERGEICMPS_DOT_ON +} + +#ifdef MERGEICMPS_DOT_ON +void BCECmpChain::dump() const { + errs() << "digraph dag {\n"; + errs() << " graph [bgcolor=transparent];\n"; + errs() << " node [color=black,style=filled,fillcolor=lightyellow];\n"; + errs() << " edge [color=black];\n"; + for (size_t I = 0; I < Comparisons_.size(); ++I) { + const auto &Comparison = Comparisons_[I]; + errs() << " \"" << I << "\" [label=\"%" + << Comparison.Lhs().Base()->getName() << " + " + << Comparison.Lhs().Offset << " == %" + << Comparison.Rhs().Base()->getName() << " + " + << Comparison.Rhs().Offset << " (" << (Comparison.SizeBits() / 8) + << " bytes)\"];\n"; + const Value *const Val = Phi_.getIncomingValueForBlock(Comparison.BB); + if (I > 0) errs() << " \"" << (I - 1) << "\" -> \"" << I << "\";\n"; + errs() << " \"" << I << "\" -> \"Phi\" [label=\"" << *Val << "\"];\n"; + } + errs() << " \"Phi\" [label=\"Phi\"];\n"; + errs() << "}\n\n"; +} +#endif // MERGEICMPS_DOT_ON + +bool BCECmpChain::simplify(const TargetLibraryInfo *const TLI) { + // First pass to check if there is at least one merge. If not, we don't do + // anything and we keep analysis passes intact. 
+ { + bool AtLeastOneMerged = false; + for (size_t I = 1; I < Comparisons_.size(); ++I) { + if (IsContiguous(Comparisons_[I - 1], Comparisons_[I])) { + AtLeastOneMerged = true; + break; + } + } + if (!AtLeastOneMerged) return false; + } + + // Remove phi references to comparison blocks, they will be rebuilt as we + // merge the blocks. + for (const auto &Comparison : Comparisons_) { + Phi_.removeIncomingValue(Comparison.BB, false); + } + + // Point the predecessors of the chain to the first comparison block (which is + // the new entry point). + if (EntryBlock_ != Comparisons_[0].BB) + EntryBlock_->replaceAllUsesWith(Comparisons_[0].BB); + + // Effectively merge blocks. + int NumMerged = 1; + for (size_t I = 1; I < Comparisons_.size(); ++I) { + if (IsContiguous(Comparisons_[I - 1], Comparisons_[I])) { + ++NumMerged; + } else { + // Merge all previous comparisons and start a new merge block. + mergeComparisons( + makeArrayRef(Comparisons_).slice(I - NumMerged, NumMerged), + Comparisons_[I].BB, Phi_, TLI); + NumMerged = 1; + } + } + mergeComparisons(makeArrayRef(Comparisons_) + .slice(Comparisons_.size() - NumMerged, NumMerged), + nullptr, Phi_, TLI); + + return true; +} + +void BCECmpChain::mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, + BasicBlock *const NextBBInChain, + PHINode &Phi, + const TargetLibraryInfo *const TLI) { + assert(!Comparisons.empty()); + const auto &FirstComparison = *Comparisons.begin(); + BasicBlock *const BB = FirstComparison.BB; + LLVMContext &Context = BB->getContext(); + + if (Comparisons.size() >= 2) { + DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons\n"); + const auto TotalSize = + std::accumulate(Comparisons.begin(), Comparisons.end(), 0, + [](int Size, const BCECmpBlock &C) { + return Size + C.SizeBits(); + }) / + 8; + + // Incoming edges do not need to be updated, and both GEPs are already + // computing the right address, we just need to: + // - replace the two loads and the icmp with the memcmp + // - update the branch + // - update the incoming values in the phi. + FirstComparison.BranchI->eraseFromParent(); + FirstComparison.CmpI->eraseFromParent(); + FirstComparison.Lhs().LoadI->eraseFromParent(); + FirstComparison.Rhs().LoadI->eraseFromParent(); + + IRBuilder<> Builder(BB); + const auto &DL = Phi.getModule()->getDataLayout(); + Value *const MemCmpCall = emitMemCmp( + FirstComparison.Lhs().GEP, FirstComparison.Rhs().GEP, ConstantInt::get(DL.getIntPtrType(Context), TotalSize), + Builder, DL, TLI); + Value *const MemCmpIsZero = Builder.CreateICmpEQ( + MemCmpCall, ConstantInt::get(Type::getInt32Ty(Context), 0)); + + // Add a branch to the next basic block in the chain. + if (NextBBInChain) { + Builder.CreateCondBr(MemCmpIsZero, NextBBInChain, Phi.getParent()); + Phi.addIncoming(ConstantInt::getFalse(Context), BB); + } else { + Builder.CreateBr(Phi.getParent()); + Phi.addIncoming(MemCmpIsZero, BB); + } + + // Delete merged blocks. + for (size_t I = 1; I < Comparisons.size(); ++I) { + BasicBlock *CBB = Comparisons[I].BB; + CBB->replaceAllUsesWith(BB); + CBB->eraseFromParent(); + } + } else { + assert(Comparisons.size() == 1); + // There are no blocks to merge, but we still need to update the branches. + DEBUG(dbgs() << "Only one comparison, updating branches\n"); + if (NextBBInChain) { + if (FirstComparison.BranchI->isConditional()) { + DEBUG(dbgs() << "conditional -> conditional\n"); + // Just update the "true" target, the "false" target should already be + // the phi block. 
+      assert(FirstComparison.BranchI->getSuccessor(1) == Phi.getParent());
+      FirstComparison.BranchI->setSuccessor(0, NextBBInChain);
+      Phi.addIncoming(ConstantInt::getFalse(Context), BB);
+    } else {
+      DEBUG(dbgs() << "unconditional -> conditional\n");
+      // Replace the unconditional branch by a conditional one.
+      FirstComparison.BranchI->eraseFromParent();
+      IRBuilder<> Builder(BB);
+      Builder.CreateCondBr(FirstComparison.CmpI, NextBBInChain,
+                           Phi.getParent());
+      Phi.addIncoming(FirstComparison.CmpI, BB);
+    }
+  } else {
+    if (FirstComparison.BranchI->isConditional()) {
+      DEBUG(dbgs() << "conditional -> unconditional\n");
+      // Replace the conditional branch by an unconditional one.
+      FirstComparison.BranchI->eraseFromParent();
+      IRBuilder<> Builder(BB);
+      Builder.CreateBr(Phi.getParent());
+      Phi.addIncoming(FirstComparison.CmpI, BB);
+    } else {
+      DEBUG(dbgs() << "unconditional -> unconditional\n");
+      Phi.addIncoming(FirstComparison.CmpI, BB);
+    }
+  }
+}
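+
+// For illustration only: after a merge, a chain that compared 12 contiguous
+// bytes collapses to something like the following (hypothetical IR):
+//
+//   %memcmp = call i32 @memcmp(i8* %lhs, i8* %rhs, i64 12)
+//   %iszero = icmp eq i32 %memcmp, 0
+//   br i1 %iszero, label %next_bb_in_chain, label %phi_block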
+
+std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi,
+                                           BasicBlock *const LastBlock,
+                                           int NumBlocks) {
+  // Walk up from the last block to find other blocks.
+  std::vector<BasicBlock *> Blocks(NumBlocks);
+  BasicBlock *CurBlock = LastBlock;
+  for (int BlockIndex = NumBlocks - 1; BlockIndex > 0; --BlockIndex) {
+    if (CurBlock->hasAddressTaken()) {
+      // Somebody is jumping to the block through an address, all bets are
+      // off.
+      DEBUG(dbgs() << "skip: block " << BlockIndex
+                   << " has its address taken\n");
+      return {};
+    }
+    Blocks[BlockIndex] = CurBlock;
+    auto *SinglePredecessor = CurBlock->getSinglePredecessor();
+    if (!SinglePredecessor) {
+      // The block has two or more predecessors.
+      DEBUG(dbgs() << "skip: block " << BlockIndex
+                   << " has two or more predecessors\n");
+      return {};
+    }
+    if (Phi.getBasicBlockIndex(SinglePredecessor) < 0) {
+      // The block does not link back to the phi.
+      DEBUG(dbgs() << "skip: block " << BlockIndex
+                   << " does not link back to the phi\n");
+      return {};
+    }
+    CurBlock = SinglePredecessor;
+  }
+  Blocks[0] = CurBlock;
+  return Blocks;
+}
+
+bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI) {
+  DEBUG(dbgs() << "processPhi()\n");
+  if (Phi.getNumIncomingValues() <= 1) {
+    DEBUG(dbgs() << "skip: only one incoming value in phi\n");
+    return false;
+  }
+  // We are looking for something that has the following structure:
+  //   bb1 --eq--> bb2 --eq--> bb3 --eq--> bb4 --+
+  //    \            \           \               \
+  //     ne           ne          ne              \
+  //      \            \           \               v
+  //       +------------+-----------+----------> bb_phi
+  //
+  //  - The last basic block (bb4 here) must branch unconditionally to bb_phi.
+  //    It's the only block that contributes a non-constant value to the Phi.
+  //  - All other blocks (bb1, bb2, bb3) must have exactly two successors, one
+  //    of them being the phi block.
+  //  - All intermediate blocks (bb2, bb3) must have only one predecessor.
+  //  - Blocks cannot do other work besides the comparison, see
+  //    doesOtherWork().
+
+  // The blocks are not necessarily ordered in the phi, so we start from the
+  // last block and reconstruct the order.
+  BasicBlock *LastBlock = nullptr;
+  for (unsigned I = 0; I < Phi.getNumIncomingValues(); ++I) {
+    if (isa<ConstantInt>(Phi.getIncomingValue(I))) continue;
+    if (LastBlock) {
+      // There are several non-constant values.
+      DEBUG(dbgs() << "skip: several non-constant values\n");
+      return false;
+    }
+    LastBlock = Phi.getIncomingBlock(I);
+  }
+  if (!LastBlock) {
+    // There is no non-constant block.
+    DEBUG(dbgs() << "skip: no non-constant block\n");
+    return false;
+  }
+  if (LastBlock->getSingleSuccessor() != Phi.getParent()) {
+    DEBUG(dbgs() << "skip: last block non-phi successor\n");
+    return false;
+  }
+
+  const auto Blocks =
+      getOrderedBlocks(Phi, LastBlock, Phi.getNumIncomingValues());
+  if (Blocks.empty()) return false;
+  BCECmpChain CmpChain(Blocks, Phi);
+
+  if (CmpChain.size() < 2) {
+    DEBUG(dbgs() << "skip: only one compare block\n");
+    return false;
+  }
+
+  return CmpChain.simplify(TLI);
+}
+
+class MergeICmps : public FunctionPass {
+ public:
+  static char ID;
+
+  MergeICmps() : FunctionPass(ID) {
+    initializeMergeICmpsPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    if (skipFunction(F)) return false;
+    const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+    const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+    auto PA = runImpl(F, &TLI, &TTI);
+    return !PA.areAllPreserved();
+  }
+
+ private:
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+  }
+
+  PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
+                            const TargetTransformInfo *TTI);
+};
+
+PreservedAnalyses MergeICmps::runImpl(Function &F, const TargetLibraryInfo *TLI,
+                                      const TargetTransformInfo *TTI) {
+  DEBUG(dbgs() << "MergeICmpsPass: " << F.getName() << "\n");
+
+  // We only try merging comparisons if the target wants to expand memcmp
+  // later. The rationale is to avoid turning small chains into memcmp calls.
+  if (!TTI->enableMemCmpExpansion(true)) return PreservedAnalyses::all();
+
+  bool MadeChange = false;
+
+  for (auto BBIt = ++F.begin(); BBIt != F.end(); ++BBIt) {
+    // A phi node is always first in a basic block.
+    if (auto *const Phi = dyn_cast<PHINode>(&*BBIt->begin()))
+      MadeChange |= processPhi(*Phi, TLI);
+  }
+
+  if (MadeChange) return PreservedAnalyses::none();
+  return PreservedAnalyses::all();
+}
+
+} // namespace
+
+char MergeICmps::ID = 0;
+INITIALIZE_PASS_BEGIN(MergeICmps, "mergeicmps",
+                      "Merge contiguous icmps into a memcmp", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(MergeICmps, "mergeicmps",
+                    "Merge contiguous icmps into a memcmp", false, false)
+
+Pass *llvm::createMergeICmpsPass() { return new MergeICmps(); }
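+
+// For illustration only: a typical beneficiary of this pass, assuming a
+// hypothetical struct with three contiguous i32 members:
+//
+//   struct S { int a, b, c; };
+//   bool operator==(const S &Lhs, const S &Rhs) {
+//     return Lhs.a == Rhs.a && Lhs.b == Rhs.b && Lhs.c == Rhs.c;
+//   }
+//
+// The three chained integer comparisons load from contiguous memory, so the
+// chain can be merged into a single 12-byte memcmp.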
diff --git a/contrib/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp b/contrib/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
new file mode 100644
index 000000000000..f2f615cb9b0f
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/MergedLoadStoreMotion.cpp
@@ -0,0 +1,429 @@
+//===- MergedLoadStoreMotion.cpp - merge and hoist/sink load/stores -------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+//! \file
+//! \brief This pass performs merges of loads and stores on both sides of a
+//  diamond (hammock). It hoists the loads and sinks the stores.
+//
+// The algorithm iteratively hoists two loads to the same address out of a
+// diamond (hammock) and merges them into a single load in the header.
+// Similarly, it sinks and merges two stores to the tail block (footer). The
+// algorithm iterates over the instructions of one side of the diamond and
+// attempts to find a matching load/store on the other side. It hoists or
+// sinks when it thinks it is safe to do so. This optimization helps with,
+// e.g., hiding load latencies, triggering if-conversion, and reducing static
+// code size.
+//
+// NOTE: This code no longer performs load hoisting; it is subsumed by
+// GVNHoist.
+//
+//===----------------------------------------------------------------------===//
+//
+//
+// Example:
+// Diamond shaped code before merge:
+//
+//            header:
+//                     br %cond, label %if.then, label %if.else
+//                        +                    +
+//                       +                      +
+//                      +                        +
+//            if.then:                         if.else:
+//               %lt = load %addr_l               %le = load %addr_l
+//               <use %lt>                        <use %le>
+//               <...>                            <...>
+//               store %st, %addr_s               store %se, %addr_s
+//               br label %if.end                 br label %if.end
+//                      +                         +
+//                       +                       +
+//                        +                     +
+//            if.end ("footer"):
+//                     <...>
+//
+// Diamond shaped code after merge:
+//
+//            header:
+//                     %l = load %addr_l
+//                     br %cond, label %if.then, label %if.else
+//                        +                    +
+//                       +                      +
+//                      +                        +
+//            if.then:                         if.else:
+//               <use %l>                         <use %l>
+//               <...>                            <...>
+//               br label %if.end                 br label %if.end
+//                      +                        +
+//                       +                      +
+//                        +                    +
+//            if.end ("footer"):
+//                     %s.sink = phi [%st, if.then], [%se, if.else]
+//                     <...>
+//                     store %s.sink, %addr_s
+//                     <...>
+//
+//
+//===----------------------- TODO -----------------------------------------===//
+//
+// 1) Generalize to regions other than diamonds
+// 2) Be more aggressive merging memory operations
+// Note that both changes require register pressure control
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/Loads.h"
+#include "llvm/Analysis/MemoryDependenceAnalysis.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Metadata.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "mldst-motion"
+
+namespace {
+//===----------------------------------------------------------------------===//
+//                       MergedLoadStoreMotion Pass
+//===----------------------------------------------------------------------===//
+class MergedLoadStoreMotion {
+  MemoryDependenceResults *MD = nullptr;
+  AliasAnalysis *AA = nullptr;
+
+  // The mergeLoad/Store algorithms could have Size0 * Size1 complexity,
+  // where Size0 and Size1 are the #instructions on the two sides of
+  // the diamond. The constant chosen here is arbitrary. Compile-time
+  // control is enforced by the check Size0 * Size1 < MagicCompileTimeControl.
+  const int MagicCompileTimeControl = 250;
+
+public:
+  bool run(Function &F, MemoryDependenceResults *MD, AliasAnalysis &AA);
+
+private:
+  ///
+  /// \brief Remove instruction from parent and update memory dependence
+  /// analysis.
+  ///
+  void removeInstruction(Instruction *Inst);
+  BasicBlock *getDiamondTail(BasicBlock *BB);
+  bool isDiamondHead(BasicBlock *BB);
+  // Routines for sinking stores
+  StoreInst *canSinkFromBlock(BasicBlock *BB, StoreInst *SI);
+  PHINode *getPHIOperand(BasicBlock *BB, StoreInst *S0, StoreInst *S1);
+  bool isStoreSinkBarrierInRange(const Instruction &Start,
+                                 const Instruction &End, MemoryLocation Loc);
+  bool sinkStore(BasicBlock *BB, StoreInst *SinkCand, StoreInst *ElseInst);
+  bool mergeStores(BasicBlock *BB);
+};
+} // end anonymous namespace
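+
+// For illustration only: source of roughly this shape (hypothetical names)
+// produces the store diamond shown in the header comment:
+//
+//   if (Cond)
+//     *AddrS = ValThen;  // store in if.then
+//   else
+//     *AddrS = ValElse;  // store in if.else
+//
+// Both stores can be merged into a single store of a phi in the footer.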
+
+///
+/// \brief Remove instruction from parent and update memory dependence
+/// analysis.
+///
+void MergedLoadStoreMotion::removeInstruction(Instruction *Inst) {
+  // Notify the memory dependence analysis.
+  if (MD) {
+    MD->removeInstruction(Inst);
+    if (auto *LI = dyn_cast<LoadInst>(Inst))
+      MD->invalidateCachedPointerInfo(LI->getPointerOperand());
+    if (Inst->getType()->isPtrOrPtrVectorTy()) {
+      MD->invalidateCachedPointerInfo(Inst);
+    }
+  }
+  Inst->eraseFromParent();
+}
+
+///
+/// \brief Return tail block of a diamond.
+///
+BasicBlock *MergedLoadStoreMotion::getDiamondTail(BasicBlock *BB) {
+  assert(isDiamondHead(BB) && "Basic block is not head of a diamond");
+  return BB->getTerminator()->getSuccessor(0)->getSingleSuccessor();
+}
+
+///
+/// \brief True when BB is the head of a diamond (hammock)
+///
+bool MergedLoadStoreMotion::isDiamondHead(BasicBlock *BB) {
+  if (!BB)
+    return false;
+  auto *BI = dyn_cast<BranchInst>(BB->getTerminator());
+  if (!BI || !BI->isConditional())
+    return false;
+
+  BasicBlock *Succ0 = BI->getSuccessor(0);
+  BasicBlock *Succ1 = BI->getSuccessor(1);
+
+  if (!Succ0->getSinglePredecessor())
+    return false;
+  if (!Succ1->getSinglePredecessor())
+    return false;
+
+  BasicBlock *Succ0Succ = Succ0->getSingleSuccessor();
+  BasicBlock *Succ1Succ = Succ1->getSingleSuccessor();
+  // Ignore triangles.
+  if (!Succ0Succ || !Succ1Succ || Succ0Succ != Succ1Succ)
+    return false;
+  return true;
+}
+
+///
+/// \brief True when instruction is a sink barrier for a store
+/// located in Loc
+///
+/// Whenever an instruction could possibly read or modify the
+/// value being stored, or otherwise prevent the store from
+/// happening, it is considered a sink barrier.
+///
+bool MergedLoadStoreMotion::isStoreSinkBarrierInRange(const Instruction &Start,
+                                                      const Instruction &End,
+                                                      MemoryLocation Loc) {
+  for (const Instruction &Inst :
+       make_range(Start.getIterator(), End.getIterator()))
+    if (Inst.mayThrow())
+      return true;
+  return AA->canInstructionRangeModRef(Start, End, Loc, ModRefInfo::ModRef);
+}
+
+///
+/// \brief Check if \p BB1 contains a store to the same address as \p Store0
+///
+/// \return The store in \p BB1 when it is safe to sink. Otherwise return null.
+///
+StoreInst *MergedLoadStoreMotion::canSinkFromBlock(BasicBlock *BB1,
+                                                   StoreInst *Store0) {
+  DEBUG(dbgs() << "can Sink?
: "; Store0->dump(); dbgs() << "\n"); + BasicBlock *BB0 = Store0->getParent(); + for (Instruction &Inst : reverse(*BB1)) { + auto *Store1 = dyn_cast<StoreInst>(&Inst); + if (!Store1) + continue; + + MemoryLocation Loc0 = MemoryLocation::get(Store0); + MemoryLocation Loc1 = MemoryLocation::get(Store1); + if (AA->isMustAlias(Loc0, Loc1) && Store0->isSameOperationAs(Store1) && + !isStoreSinkBarrierInRange(*Store1->getNextNode(), BB1->back(), Loc1) && + !isStoreSinkBarrierInRange(*Store0->getNextNode(), BB0->back(), Loc0)) { + return Store1; + } + } + return nullptr; +} + +/// +/// \brief Create a PHI node in BB for the operands of S0 and S1 +/// +PHINode *MergedLoadStoreMotion::getPHIOperand(BasicBlock *BB, StoreInst *S0, + StoreInst *S1) { + // Create a phi if the values mismatch. + Value *Opd1 = S0->getValueOperand(); + Value *Opd2 = S1->getValueOperand(); + if (Opd1 == Opd2) + return nullptr; + + auto *NewPN = PHINode::Create(Opd1->getType(), 2, Opd2->getName() + ".sink", + &BB->front()); + NewPN->addIncoming(Opd1, S0->getParent()); + NewPN->addIncoming(Opd2, S1->getParent()); + if (MD && NewPN->getType()->isPtrOrPtrVectorTy()) + MD->invalidateCachedPointerInfo(NewPN); + return NewPN; +} + +/// +/// \brief Merge two stores to same address and sink into \p BB +/// +/// Also sinks GEP instruction computing the store address +/// +bool MergedLoadStoreMotion::sinkStore(BasicBlock *BB, StoreInst *S0, + StoreInst *S1) { + // Only one definition? + auto *A0 = dyn_cast<Instruction>(S0->getPointerOperand()); + auto *A1 = dyn_cast<Instruction>(S1->getPointerOperand()); + if (A0 && A1 && A0->isIdenticalTo(A1) && A0->hasOneUse() && + (A0->getParent() == S0->getParent()) && A1->hasOneUse() && + (A1->getParent() == S1->getParent()) && isa<GetElementPtrInst>(A0)) { + DEBUG(dbgs() << "Sink Instruction into BB \n"; BB->dump(); + dbgs() << "Instruction Left\n"; S0->dump(); dbgs() << "\n"; + dbgs() << "Instruction Right\n"; S1->dump(); dbgs() << "\n"); + // Hoist the instruction. + BasicBlock::iterator InsertPt = BB->getFirstInsertionPt(); + // Intersect optional metadata. + S0->andIRFlags(S1); + S0->dropUnknownNonDebugMetadata(); + + // Create the new store to be inserted at the join point. + StoreInst *SNew = cast<StoreInst>(S0->clone()); + Instruction *ANew = A0->clone(); + SNew->insertBefore(&*InsertPt); + ANew->insertBefore(SNew); + + assert(S0->getParent() == A0->getParent()); + assert(S1->getParent() == A1->getParent()); + + // New PHI operand? Use it. + if (PHINode *NewPN = getPHIOperand(BB, S0, S1)) + SNew->setOperand(0, NewPN); + removeInstruction(S0); + removeInstruction(S1); + A0->replaceAllUsesWith(ANew); + removeInstruction(A0); + A1->replaceAllUsesWith(ANew); + removeInstruction(A1); + return true; + } + return false; +} + +/// +/// \brief True when two stores are equivalent and can sink into the footer +/// +/// Starting from a diamond tail block, iterate over the instructions in one +/// predecessor block and try to match a store in the second predecessor. +/// +bool MergedLoadStoreMotion::mergeStores(BasicBlock *T) { + + bool MergedStores = false; + assert(T && "Footer of a diamond cannot be empty"); + + pred_iterator PI = pred_begin(T), E = pred_end(T); + assert(PI != E); + BasicBlock *Pred0 = *PI; + ++PI; + BasicBlock *Pred1 = *PI; + ++PI; + // tail block of a diamond/hammock? + if (Pred0 == Pred1) + return false; // No. + if (PI != E) + return false; // No. More than 2 predecessors. 
+
+  // Number of instructions in Pred1, for compile-time control.
+  int Size1 = Pred1->size();
+  int NStores = 0;
+
+  for (BasicBlock::reverse_iterator RBI = Pred0->rbegin(), RBE = Pred0->rend();
+       RBI != RBE;) {
+
+    Instruction *I = &*RBI;
+    ++RBI;
+
+    // Don't sink non-simple (atomic, volatile) stores.
+    auto *S0 = dyn_cast<StoreInst>(I);
+    if (!S0 || !S0->isSimple())
+      continue;
+
+    ++NStores;
+    if (NStores * Size1 >= MagicCompileTimeControl)
+      break;
+    if (StoreInst *S1 = canSinkFromBlock(Pred1, S0)) {
+      bool Res = sinkStore(T, S0, S1);
+      MergedStores |= Res;
+      // Don't attempt to sink below stores that had to stick around. After
+      // removal of a store and some of its feeding instructions, search again
+      // from the beginning, since the iterator is likely stale at this point.
+      if (!Res)
+        break;
+      RBI = Pred0->rbegin();
+      RBE = Pred0->rend();
+      DEBUG(dbgs() << "Search again\n"; Instruction *I = &*RBI; I->dump());
+    }
+  }
+  return MergedStores;
+}
+
+bool MergedLoadStoreMotion::run(Function &F, MemoryDependenceResults *MD,
+                                AliasAnalysis &AA) {
+  this->MD = MD;
+  this->AA = &AA;
+
+  bool Changed = false;
+  DEBUG(dbgs() << "Instruction Merger\n");
+
+  // Walk all basic blocks, looking for diamond heads whose arms contain
+  // stores that can be merged and sunk into the footer.
+  for (Function::iterator FI = F.begin(), FE = F.end(); FI != FE;) {
+    BasicBlock *BB = &*FI++;
+
+    // Hoist equivalent loads and sink stores
+    // outside diamonds when possible
+    if (isDiamondHead(BB)) {
+      Changed |= mergeStores(getDiamondTail(BB));
+    }
+  }
+  return Changed;
+}
+
+namespace {
+class MergedLoadStoreMotionLegacyPass : public FunctionPass {
+public:
+  static char ID; // Pass identification, replacement for typeid
+  MergedLoadStoreMotionLegacyPass() : FunctionPass(ID) {
+    initializeMergedLoadStoreMotionLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  ///
+  /// \brief Run the transformation for each function
+  ///
+  bool runOnFunction(Function &F) override {
+    if (skipFunction(F))
+      return false;
+    MergedLoadStoreMotion Impl;
+    auto *MDWP = getAnalysisIfAvailable<MemoryDependenceWrapperPass>();
+    return Impl.run(F, MDWP ? &MDWP->getMemDep() : nullptr,
+                    getAnalysis<AAResultsWrapperPass>().getAAResults());
+  }
+
+private:
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+    AU.addRequired<AAResultsWrapperPass>();
+    AU.addPreserved<GlobalsAAWrapperPass>();
+    AU.addPreserved<MemoryDependenceWrapperPass>();
+  }
+};
+
+char MergedLoadStoreMotionLegacyPass::ID = 0;
+} // anonymous namespace
+
+///
+/// \brief createMergedLoadStoreMotionPass - The public interface to this file.
+/// +FunctionPass *llvm::createMergedLoadStoreMotionPass() { + return new MergedLoadStoreMotionLegacyPass(); +} + +INITIALIZE_PASS_BEGIN(MergedLoadStoreMotionLegacyPass, "mldst-motion", + "MergedLoadStoreMotion", false, false) +INITIALIZE_PASS_DEPENDENCY(MemoryDependenceWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(MergedLoadStoreMotionLegacyPass, "mldst-motion", + "MergedLoadStoreMotion", false, false) + +PreservedAnalyses +MergedLoadStoreMotionPass::run(Function &F, FunctionAnalysisManager &AM) { + MergedLoadStoreMotion Impl; + auto *MD = AM.getCachedResult<MemoryDependenceAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + if (!Impl.run(F, MD, AA)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + PA.preserve<GlobalsAA>(); + PA.preserve<MemoryDependenceAnalysis>(); + return PA; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/NaryReassociate.cpp b/contrib/llvm/lib/Transforms/Scalar/NaryReassociate.cpp new file mode 100644 index 000000000000..b026c8d692c3 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/NaryReassociate.cpp @@ -0,0 +1,538 @@ +//===- NaryReassociate.cpp - Reassociate n-ary expressions ----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass reassociates n-ary add expressions and eliminates the redundancy +// exposed by the reassociation. +// +// A motivating example: +// +// void foo(int a, int b) { +// bar(a + b); +// bar((a + 2) + b); +// } +// +// An ideal compiler should reassociate (a + 2) + b to (a + b) + 2 and simplify +// the above code to +// +// int t = a + b; +// bar(t); +// bar(t + 2); +// +// However, the Reassociate pass is unable to do that because it processes each +// instruction individually and believes (a + 2) + b is the best form according +// to its rank system. +// +// To address this limitation, NaryReassociate reassociates an expression in a +// form that reuses existing instructions. As a result, NaryReassociate can +// reassociate (a + 2) + b in the example to (a + b) + 2 because it detects that +// (a + b) is computed before. +// +// NaryReassociate works as follows. For every instruction in the form of (a + +// b) + c, it checks whether a + c or b + c is already computed by a dominating +// instruction. If so, it then reassociates (a + b) + c into (a + c) + b or (b + +// c) + a and removes the redundancy accordingly. To efficiently look up whether +// an expression is computed before, we store each instruction seen and its SCEV +// into an SCEV-to-instruction map. +// +// Although the algorithm pattern-matches only ternary additions, it +// automatically handles many >3-ary expressions by walking through the function +// in the depth-first order. For example, given +// +// (a + c) + d +// ((a + b) + c) + d +// +// NaryReassociate first rewrites (a + b) + c to (a + c) + b, and then rewrites +// ((a + c) + b) + d into ((a + c) + d) + b. +// +// Finally, the above dominator-based algorithm may need to be run multiple +// iterations before emitting optimal code. One source of this need is that we +// only split an operand when it is used only once. The above algorithm can +// eliminate an instruction and decrease the usage count of its operands. 
As a
+// result, an instruction that previously had multiple uses may become a
+// single-use instruction and thus eligible for split consideration. For
+// example,
+//
+//   ac = a + c
+//   ab = a + b
+//   abc = ab + c
+//   ab2 = ab + b
+//   ab2c = ab2 + c
+//
+// In the first iteration, we cannot reassociate abc to ac+b because ab is used
+// twice. However, we can reassociate ab2c to abc+b in the first iteration. As a
+// result, ab2 becomes dead and ab will be used only once in the second
+// iteration.
+//
+// Limitations and TODO items:
+//
+// 1) We only consider n-ary adds and muls for now. This should be extended
+//    and generalized.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/NaryReassociate.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/IR/ValueHandle.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <cassert>
+#include <cstdint>
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "nary-reassociate"
+
+namespace {
+
+class NaryReassociateLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  NaryReassociateLegacyPass() : FunctionPass(ID) {
+    initializeNaryReassociateLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool doInitialization(Module &M) override {
+    return false;
+  }
+
+  bool runOnFunction(Function &F) override;
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addPreserved<DominatorTreeWrapperPass>();
+    AU.addPreserved<ScalarEvolutionWrapperPass>();
+    AU.addPreserved<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<AssumptionCacheTracker>();
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<ScalarEvolutionWrapperPass>();
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.setPreservesCFG();
+  }
+
+private:
+  NaryReassociatePass Impl;
+};
+
+} // end anonymous namespace
+
+char NaryReassociateLegacyPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(NaryReassociateLegacyPass, "nary-reassociate",
+                      "Nary reassociation", false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(NaryReassociateLegacyPass, "nary-reassociate",
+                    "Nary reassociation", false, false)
+
+FunctionPass *llvm::createNaryReassociatePass() {
+  return new NaryReassociateLegacyPass();
+}
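+
+// For illustration only, in hypothetical IR: given that %ac dominates %u and
+// is already computed, the pass rewrites (a + b) + c to reuse it:
+//
+//   %ac = add i32 %a, %c
+//   %t  = add i32 %a, %b
+//   %u  = add i32 %t, %c     ; (a + b) + c
+//
+// becomes
+//
+//   %u2 = add i32 %ac, %b    ; (a + c) + b, and %t may then become dead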
+
+bool NaryReassociateLegacyPass::runOnFunction(Function &F) {
+  if (skipFunction(F))
+    return false;
+
+  auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
+  auto *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+  auto *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+  auto *TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+  auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+
+  return Impl.runImpl(F, AC, DT, SE, TLI, TTI);
+}
+
+PreservedAnalyses NaryReassociatePass::run(Function &F,
+                                           FunctionAnalysisManager &AM) {
+  auto *AC = &AM.getResult<AssumptionAnalysis>(F);
+  auto *DT = &AM.getResult<DominatorTreeAnalysis>(F);
+  auto *SE = &AM.getResult<ScalarEvolutionAnalysis>(F);
+  auto *TLI = &AM.getResult<TargetLibraryAnalysis>(F);
+  auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
+
+  if (!runImpl(F, AC, DT, SE, TLI, TTI))
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses PA;
+  PA.preserveSet<CFGAnalyses>();
+  PA.preserve<ScalarEvolutionAnalysis>();
+  return PA;
+}
+
+bool NaryReassociatePass::runImpl(Function &F, AssumptionCache *AC_,
+                                  DominatorTree *DT_, ScalarEvolution *SE_,
+                                  TargetLibraryInfo *TLI_,
+                                  TargetTransformInfo *TTI_) {
+  AC = AC_;
+  DT = DT_;
+  SE = SE_;
+  TLI = TLI_;
+  TTI = TTI_;
+  DL = &F.getParent()->getDataLayout();
+
+  bool Changed = false, ChangedInThisIteration;
+  do {
+    ChangedInThisIteration = doOneIteration(F);
+    Changed |= ChangedInThisIteration;
+  } while (ChangedInThisIteration);
+  return Changed;
+}
+
+// Whitelist the instruction types NaryReassociate handles for now.
+static bool isPotentiallyNaryReassociable(Instruction *I) {
+  switch (I->getOpcode()) {
+  case Instruction::Add:
+  case Instruction::GetElementPtr:
+  case Instruction::Mul:
+    return true;
+  default:
+    return false;
+  }
+}
+
+bool NaryReassociatePass::doOneIteration(Function &F) {
+  bool Changed = false;
+  SeenExprs.clear();
+  // Process the basic blocks in a depth first traversal of the dominator
+  // tree. This order ensures that all bases of a candidate are in Candidates
+  // when we process it.
+  for (const auto Node : depth_first(DT)) {
+    BasicBlock *BB = Node->getBlock();
+    for (auto I = BB->begin(); I != BB->end(); ++I) {
+      if (SE->isSCEVable(I->getType()) && isPotentiallyNaryReassociable(&*I)) {
+        const SCEV *OldSCEV = SE->getSCEV(&*I);
+        if (Instruction *NewI = tryReassociate(&*I)) {
+          Changed = true;
+          SE->forgetValue(&*I);
+          I->replaceAllUsesWith(NewI);
+          // If SeenExprs contains I's WeakTrackingVH, that entry will be
+          // replaced with nullptr.
+          RecursivelyDeleteTriviallyDeadInstructions(&*I, TLI);
+          I = NewI->getIterator();
+        }
+        // Add the rewritten instruction to SeenExprs; the original instruction
+        // is deleted.
+        const SCEV *NewSCEV = SE->getSCEV(&*I);
+        SeenExprs[NewSCEV].push_back(WeakTrackingVH(&*I));
+        // Ideally, NewSCEV should equal OldSCEV because tryReassociate(I)
+        // is equivalent to I. However, ScalarEvolution::getSCEV may
+        // weaken nsw, causing NewSCEV not to equal OldSCEV. For example,
+        // suppose we reassociate
+        //   I = &a[sext(i +nsw j)] // assuming sizeof(a[0]) = 4
+        // to
+        //   NewI = &a[sext(i)] + sext(j).
+        //
+        // ScalarEvolution computes
+        //   getSCEV(I)    = a + 4 * sext(i + j)
+        //   getSCEV(NewI) = a + 4 * sext(i) + 4 * sext(j)
+        // which are different SCEVs.
+        //
+        // To alleviate this issue of ScalarEvolution not always capturing
+        // equivalence, we add I to SeenExprs[OldSCEV] as well, so that we can
+        // map both the SCEV before and after tryReassociate(I) to I.
+ // + // This improvement is exercised in @reassociate_gep_nsw in nary-gep.ll. + if (NewSCEV != OldSCEV) + SeenExprs[OldSCEV].push_back(WeakTrackingVH(&*I)); + } + } + } + return Changed; +} + +Instruction *NaryReassociatePass::tryReassociate(Instruction *I) { + switch (I->getOpcode()) { + case Instruction::Add: + case Instruction::Mul: + return tryReassociateBinaryOp(cast<BinaryOperator>(I)); + case Instruction::GetElementPtr: + return tryReassociateGEP(cast<GetElementPtrInst>(I)); + default: + llvm_unreachable("should be filtered out by isPotentiallyNaryReassociable"); + } +} + +static bool isGEPFoldable(GetElementPtrInst *GEP, + const TargetTransformInfo *TTI) { + SmallVector<const Value*, 4> Indices; + for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I) + Indices.push_back(*I); + return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(), + Indices) == TargetTransformInfo::TCC_Free; +} + +Instruction *NaryReassociatePass::tryReassociateGEP(GetElementPtrInst *GEP) { + // Not worth reassociating GEP if it is foldable. + if (isGEPFoldable(GEP, TTI)) + return nullptr; + + gep_type_iterator GTI = gep_type_begin(*GEP); + for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { + if (GTI.isSequential()) { + if (auto *NewGEP = tryReassociateGEPAtIndex(GEP, I - 1, + GTI.getIndexedType())) { + return NewGEP; + } + } + } + return nullptr; +} + +bool NaryReassociatePass::requiresSignExtension(Value *Index, + GetElementPtrInst *GEP) { + unsigned PointerSizeInBits = + DL->getPointerSizeInBits(GEP->getType()->getPointerAddressSpace()); + return cast<IntegerType>(Index->getType())->getBitWidth() < PointerSizeInBits; +} + +GetElementPtrInst * +NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, + unsigned I, Type *IndexedType) { + Value *IndexToSplit = GEP->getOperand(I + 1); + if (SExtInst *SExt = dyn_cast<SExtInst>(IndexToSplit)) { + IndexToSplit = SExt->getOperand(0); + } else if (ZExtInst *ZExt = dyn_cast<ZExtInst>(IndexToSplit)) { + // zext can be treated as sext if the source is non-negative. + if (isKnownNonNegative(ZExt->getOperand(0), *DL, 0, AC, GEP, DT)) + IndexToSplit = ZExt->getOperand(0); + } + + if (AddOperator *AO = dyn_cast<AddOperator>(IndexToSplit)) { + // If the I-th index needs sext and the underlying add is not equipped with + // nsw, we cannot split the add because + // sext(LHS + RHS) != sext(LHS) + sext(RHS). + if (requiresSignExtension(IndexToSplit, GEP) && + computeOverflowForSignedAdd(AO, *DL, AC, GEP, DT) != + OverflowResult::NeverOverflows) + return nullptr; + + Value *LHS = AO->getOperand(0), *RHS = AO->getOperand(1); + // IndexToSplit = LHS + RHS. + if (auto *NewGEP = tryReassociateGEPAtIndex(GEP, I, LHS, RHS, IndexedType)) + return NewGEP; + // Symmetrically, try IndexToSplit = RHS + LHS. + if (LHS != RHS) { + if (auto *NewGEP = + tryReassociateGEPAtIndex(GEP, I, RHS, LHS, IndexedType)) + return NewGEP; + } + } + return nullptr; +} + +GetElementPtrInst * +NaryReassociatePass::tryReassociateGEPAtIndex(GetElementPtrInst *GEP, + unsigned I, Value *LHS, + Value *RHS, Type *IndexedType) { + // Look for GEP's closest dominator that has the same SCEV as GEP except that + // the I-th index is replaced with LHS. + SmallVector<const SCEV *, 4> IndexExprs; + for (auto Index = GEP->idx_begin(); Index != GEP->idx_end(); ++Index) + IndexExprs.push_back(SE->getSCEV(*Index)); + // Replace the I-th index with LHS. 
+ IndexExprs[I] = SE->getSCEV(LHS); + if (isKnownNonNegative(LHS, *DL, 0, AC, GEP, DT) && + DL->getTypeSizeInBits(LHS->getType()) < + DL->getTypeSizeInBits(GEP->getOperand(I)->getType())) { + // Zero-extend LHS if it is non-negative. InstCombine canonicalizes sext to + // zext if the source operand is proved non-negative. We should do that + // consistently so that CandidateExpr more likely appears before. See + // @reassociate_gep_assume for an example of this canonicalization. + IndexExprs[I] = + SE->getZeroExtendExpr(IndexExprs[I], GEP->getOperand(I)->getType()); + } + const SCEV *CandidateExpr = SE->getGEPExpr(cast<GEPOperator>(GEP), + IndexExprs); + + Value *Candidate = findClosestMatchingDominator(CandidateExpr, GEP); + if (Candidate == nullptr) + return nullptr; + + IRBuilder<> Builder(GEP); + // Candidate does not necessarily have the same pointer type as GEP. Use + // bitcast or pointer cast to make sure they have the same type, so that the + // later RAUW doesn't complain. + Candidate = Builder.CreateBitOrPointerCast(Candidate, GEP->getType()); + assert(Candidate->getType() == GEP->getType()); + + // NewGEP = (char *)Candidate + RHS * sizeof(IndexedType) + uint64_t IndexedSize = DL->getTypeAllocSize(IndexedType); + Type *ElementType = GEP->getResultElementType(); + uint64_t ElementSize = DL->getTypeAllocSize(ElementType); + // Another less rare case: because I is not necessarily the last index of the + // GEP, the size of the type at the I-th index (IndexedSize) is not + // necessarily divisible by ElementSize. For example, + // + // #pragma pack(1) + // struct S { + // int a[3]; + // int64 b[8]; + // }; + // #pragma pack() + // + // sizeof(S) = 100 is indivisible by sizeof(int64) = 8. + // + // TODO: bail out on this case for now. We could emit uglygep. + if (IndexedSize % ElementSize != 0) + return nullptr; + + // NewGEP = &Candidate[RHS * (sizeof(IndexedType) / sizeof(Candidate[0]))); + Type *IntPtrTy = DL->getIntPtrType(GEP->getType()); + if (RHS->getType() != IntPtrTy) + RHS = Builder.CreateSExtOrTrunc(RHS, IntPtrTy); + if (IndexedSize != ElementSize) { + RHS = Builder.CreateMul( + RHS, ConstantInt::get(IntPtrTy, IndexedSize / ElementSize)); + } + GetElementPtrInst *NewGEP = + cast<GetElementPtrInst>(Builder.CreateGEP(Candidate, RHS)); + NewGEP->setIsInBounds(GEP->isInBounds()); + NewGEP->takeName(GEP); + return NewGEP; +} + +Instruction *NaryReassociatePass::tryReassociateBinaryOp(BinaryOperator *I) { + Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); + if (auto *NewI = tryReassociateBinaryOp(LHS, RHS, I)) + return NewI; + if (auto *NewI = tryReassociateBinaryOp(RHS, LHS, I)) + return NewI; + return nullptr; +} + +Instruction *NaryReassociatePass::tryReassociateBinaryOp(Value *LHS, Value *RHS, + BinaryOperator *I) { + Value *A = nullptr, *B = nullptr; + // To be conservative, we reassociate I only when it is the only user of (A op + // B). 
+ if (LHS->hasOneUse() && matchTernaryOp(I, LHS, A, B)) { + // I = (A op B) op RHS + // = (A op RHS) op B or (B op RHS) op A + const SCEV *AExpr = SE->getSCEV(A), *BExpr = SE->getSCEV(B); + const SCEV *RHSExpr = SE->getSCEV(RHS); + if (BExpr != RHSExpr) { + if (auto *NewI = + tryReassociatedBinaryOp(getBinarySCEV(I, AExpr, RHSExpr), B, I)) + return NewI; + } + if (AExpr != RHSExpr) { + if (auto *NewI = + tryReassociatedBinaryOp(getBinarySCEV(I, BExpr, RHSExpr), A, I)) + return NewI; + } + } + return nullptr; +} + +Instruction *NaryReassociatePass::tryReassociatedBinaryOp(const SCEV *LHSExpr, + Value *RHS, + BinaryOperator *I) { + // Look for the closest dominator LHS of I that computes LHSExpr, and replace + // I with LHS op RHS. + auto *LHS = findClosestMatchingDominator(LHSExpr, I); + if (LHS == nullptr) + return nullptr; + + Instruction *NewI = nullptr; + switch (I->getOpcode()) { + case Instruction::Add: + NewI = BinaryOperator::CreateAdd(LHS, RHS, "", I); + break; + case Instruction::Mul: + NewI = BinaryOperator::CreateMul(LHS, RHS, "", I); + break; + default: + llvm_unreachable("Unexpected instruction."); + } + NewI->takeName(I); + return NewI; +} + +bool NaryReassociatePass::matchTernaryOp(BinaryOperator *I, Value *V, + Value *&Op1, Value *&Op2) { + switch (I->getOpcode()) { + case Instruction::Add: + return match(V, m_Add(m_Value(Op1), m_Value(Op2))); + case Instruction::Mul: + return match(V, m_Mul(m_Value(Op1), m_Value(Op2))); + default: + llvm_unreachable("Unexpected instruction."); + } + return false; +} + +const SCEV *NaryReassociatePass::getBinarySCEV(BinaryOperator *I, + const SCEV *LHS, + const SCEV *RHS) { + switch (I->getOpcode()) { + case Instruction::Add: + return SE->getAddExpr(LHS, RHS); + case Instruction::Mul: + return SE->getMulExpr(LHS, RHS); + default: + llvm_unreachable("Unexpected instruction."); + } + return nullptr; +} + +Instruction * +NaryReassociatePass::findClosestMatchingDominator(const SCEV *CandidateExpr, + Instruction *Dominatee) { + auto Pos = SeenExprs.find(CandidateExpr); + if (Pos == SeenExprs.end()) + return nullptr; + + auto &Candidates = Pos->second; + // Because we process the basic blocks in pre-order of the dominator tree, a + // candidate that doesn't dominate the current instruction won't dominate any + // future instruction either. Therefore, we pop it out of the stack. This + // optimization makes the algorithm O(n). + while (!Candidates.empty()) { + // Candidates stores WeakTrackingVHs, so a candidate can be nullptr if it's + // removed + // during rewriting. + if (Value *Candidate = Candidates.back()) { + Instruction *CandidateInstruction = cast<Instruction>(Candidate); + if (DT->dominates(CandidateInstruction, Dominatee)) + return CandidateInstruction; + } + Candidates.pop_back(); + } + return nullptr; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/NewGVN.cpp b/contrib/llvm/lib/Transforms/Scalar/NewGVN.cpp new file mode 100644 index 000000000000..9ebf2d769356 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/NewGVN.cpp @@ -0,0 +1,4252 @@ +//===- NewGVN.cpp - Global Value Numbering Pass ---------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +/// \file +/// This file implements the new LLVM's Global Value Numbering pass. +/// GVN partitions values computed by a function into congruence classes. 
+/// Values ending up in the same congruence class are guaranteed to be the same
+/// for every execution of the program. In that respect, congruency is a
+/// compile-time approximation of equivalence of values at runtime.
+/// The algorithm implemented here uses a sparse formulation and it's based
+/// on the ideas described in the paper
+/// "A Sparse Algorithm for Predicated Global Value Numbering" by
+/// Karthik Gargi.
+///
+/// A brief overview of the algorithm: The algorithm is essentially the same as
+/// the standard RPO value numbering algorithm (a good reference is the paper
+/// "SCC based value numbering" by L. Taylor Simpson) with one major difference:
+/// The RPO algorithm proceeds, on every iteration, to process every reachable
+/// block and every instruction in that block. This is because the standard RPO
+/// algorithm does not track what things have the same value number, it only
+/// tracks what the value number of a given operation is (the mapping is
+/// operation -> value number). Thus, when a value number of an operation
+/// changes, it must reprocess everything to ensure all uses of a value number
+/// get updated properly. In contrast, the sparse algorithm we use *also*
+/// tracks what operations have a given value number (i.e., it also tracks the
+/// reverse mapping from value number -> operations with that value number), so
+/// that it only needs to reprocess the instructions that are affected when
+/// something's value number changes. The vast majority of complexity and code
+/// in this file is devoted to tracking what value numbers could change for what
+/// instructions when various things happen. The rest of the algorithm is
+/// devoted to performing symbolic evaluation, forward propagation, and
+/// simplification of operations based on the value numbers deduced so far.
+///
+/// In order to make the GVN mostly-complete, we use a technique derived from
+/// "Detection of Redundant Expressions: A Complete and Polynomial-time
+/// Algorithm in SSA" by R.R. Pai. The source of incompleteness in most SSA
+/// based GVN algorithms is related to their inability to detect equivalence
+/// between phi of ops (i.e., phi(a+b, c+d)) and op of phis
+/// (phi(a,c) + phi(b,d)). We resolve this issue by generating the equivalent
+/// "phi of ops" form for each op of phis we see, in a way that only takes
+/// polynomial time to resolve.
+///
+/// We also do not perform elimination by using any published algorithm. All
+/// published algorithms are O(Instructions). Instead, we use a technique that
+/// is O(number of operations with the same value number), enabling us to skip
+/// trying to eliminate things that have unique value numbers.
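+///
+/// For illustration only, in hypothetical IR: the "op of phis"
+///   %x = phi i32 [ %a, %bb1 ], [ %c, %bb2 ]
+///   %y = phi i32 [ %b, %bb1 ], [ %d, %bb2 ]
+///   %s = add i32 %x, %y
+/// is detected as congruent to the "phi of ops"
+///   %p = phi i32 [ %ab, %bb1 ], [ %cd, %bb2 ]
+/// assuming %ab = add i32 %a, %b in bb1 and %cd = add i32 %c, %d in bb2.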
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/NewGVN.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/BitVector.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/GraphTraits.h"
+#include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/SparseBitVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/CFGPrinter.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/InstructionSimplify.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/MemorySSA.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Use.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/ArrayRecycler.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/DebugCounter.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/PointerLikeTypeTraits.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/GVNExpression.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include "llvm/Transforms/Utils/PredicateInfo.h"
+#include "llvm/Transforms/Utils/VNCoercion.h"
+#include <algorithm>
+#include <cassert>
+#include <cstdint>
+#include <iterator>
+#include <map>
+#include <memory>
+#include <set>
+#include <string>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+using namespace llvm::GVNExpression;
+using namespace llvm::VNCoercion;
+
+#define DEBUG_TYPE "newgvn"
+
+STATISTIC(NumGVNInstrDeleted, "Number of instructions deleted");
+STATISTIC(NumGVNBlocksDeleted, "Number of blocks deleted");
+STATISTIC(NumGVNOpsSimplified, "Number of Expressions simplified");
+STATISTIC(NumGVNPhisAllSame, "Number of PHIs whose arguments are all the same");
+STATISTIC(NumGVNMaxIterations,
+          "Maximum number of iterations it took to converge GVN");
+STATISTIC(NumGVNLeaderChanges, "Number of leader changes");
+STATISTIC(NumGVNSortedLeaderChanges, "Number of sorted leader changes");
+STATISTIC(NumGVNAvoidedSortedLeaderChanges,
+          "Number of avoided sorted leader changes");
+STATISTIC(NumGVNDeadStores, "Number of redundant/dead stores eliminated");
+STATISTIC(NumGVNPHIOfOpsCreated, "Number of PHI of ops created");
+STATISTIC(NumGVNPHIOfOpsEliminations,
+          "Number of things eliminated using PHI of ops");
+DEBUG_COUNTER(VNCounter, "newgvn-vn",
+              "Controls which instructions are value numbered");
+DEBUG_COUNTER(PHIOfOpsCounter, "newgvn-phi",
+              "Controls which instructions we create phi of ops for");
+// Currently store defining
access refinement is too slow due to basicaa being +// egregiously slow. This flag lets us keep it working while we work on this +// issue. +static cl::opt<bool> EnableStoreRefinement("enable-store-refinement", + cl::init(false), cl::Hidden); + +/// Currently, the generation "phi of ops" can result in correctness issues. +static cl::opt<bool> EnablePhiOfOps("enable-phi-of-ops", cl::init(true), + cl::Hidden); + +//===----------------------------------------------------------------------===// +// GVN Pass +//===----------------------------------------------------------------------===// + +// Anchor methods. +namespace llvm { +namespace GVNExpression { + +Expression::~Expression() = default; +BasicExpression::~BasicExpression() = default; +CallExpression::~CallExpression() = default; +LoadExpression::~LoadExpression() = default; +StoreExpression::~StoreExpression() = default; +AggregateValueExpression::~AggregateValueExpression() = default; +PHIExpression::~PHIExpression() = default; + +} // end namespace GVNExpression +} // end namespace llvm + +namespace { + +// Tarjan's SCC finding algorithm with Nuutila's improvements +// SCCIterator is actually fairly complex for the simple thing we want. +// It also wants to hand us SCC's that are unrelated to the phi node we ask +// about, and have us process them there or risk redoing work. +// Graph traits over a filter iterator also doesn't work that well here. +// This SCC finder is specialized to walk use-def chains, and only follows +// instructions, +// not generic values (arguments, etc). +struct TarjanSCC { + TarjanSCC() : Components(1) {} + + void Start(const Instruction *Start) { + if (Root.lookup(Start) == 0) + FindSCC(Start); + } + + const SmallPtrSetImpl<const Value *> &getComponentFor(const Value *V) const { + unsigned ComponentID = ValueToComponent.lookup(V); + + assert(ComponentID > 0 && + "Asking for a component for a value we never processed"); + return Components[ComponentID]; + } + +private: + void FindSCC(const Instruction *I) { + Root[I] = ++DFSNum; + // Store the DFS Number we had before it possibly gets incremented. + unsigned int OurDFS = DFSNum; + for (auto &Op : I->operands()) { + if (auto *InstOp = dyn_cast<Instruction>(Op)) { + if (Root.lookup(Op) == 0) + FindSCC(InstOp); + if (!InComponent.count(Op)) + Root[I] = std::min(Root.lookup(I), Root.lookup(Op)); + } + } + // See if we really were the root of a component, by seeing if we still have + // our DFSNumber. If we do, we are the root of the component, and we have + // completed a component. If we do not, we are not the root of a component, + // and belong on the component stack. + if (Root.lookup(I) == OurDFS) { + unsigned ComponentID = Components.size(); + Components.resize(Components.size() + 1); + auto &Component = Components.back(); + Component.insert(I); + DEBUG(dbgs() << "Component root is " << *I << "\n"); + InComponent.insert(I); + ValueToComponent[I] = ComponentID; + // Pop a component off the stack and label it. 
+ while (!Stack.empty() && Root.lookup(Stack.back()) >= OurDFS) { + auto *Member = Stack.back(); + DEBUG(dbgs() << "Component member is " << *Member << "\n"); + Component.insert(Member); + InComponent.insert(Member); + ValueToComponent[Member] = ComponentID; + Stack.pop_back(); + } + } else { + // Part of a component, push to stack + Stack.push_back(I); + } + } + + unsigned int DFSNum = 1; + SmallPtrSet<const Value *, 8> InComponent; + DenseMap<const Value *, unsigned int> Root; + SmallVector<const Value *, 8> Stack; + + // Store the components as vector of ptr sets, because we need the topo order + // of SCC's, but not individual member order + SmallVector<SmallPtrSet<const Value *, 8>, 8> Components; + + DenseMap<const Value *, unsigned> ValueToComponent; +}; + +// Congruence classes represent the set of expressions/instructions +// that are all the same *during some scope in the function*. +// That is, because of the way we perform equality propagation, and +// because of memory value numbering, it is not correct to assume +// you can willy-nilly replace any member with any other at any +// point in the function. +// +// For any Value in the Member set, it is valid to replace any dominated member +// with that Value. +// +// Every congruence class has a leader, and the leader is used to symbolize +// instructions in a canonical way (IE every operand of an instruction that is a +// member of the same congruence class will always be replaced with leader +// during symbolization). To simplify symbolization, we keep the leader as a +// constant if class can be proved to be a constant value. Otherwise, the +// leader is the member of the value set with the smallest DFS number. Each +// congruence class also has a defining expression, though the expression may be +// null. If it exists, it can be used for forward propagation and reassociation +// of values. + +// For memory, we also track a representative MemoryAccess, and a set of memory +// members for MemoryPhis (which have no real instructions). Note that for +// memory, it seems tempting to try to split the memory members into a +// MemoryCongruenceClass or something. Unfortunately, this does not work +// easily. The value numbering of a given memory expression depends on the +// leader of the memory congruence class, and the leader of memory congruence +// class depends on the value numbering of a given memory expression. This +// leads to wasted propagation, and in some cases, missed optimization. For +// example: If we had value numbered two stores together before, but now do not, +// we move them to a new value congruence class. This in turn will move at one +// of the memorydefs to a new memory congruence class. Which in turn, affects +// the value numbering of the stores we just value numbered (because the memory +// congruence class is part of the value number). So while theoretically +// possible to split them up, it turns out to be *incredibly* complicated to get +// it to work right, because of the interdependency. While structurally +// slightly messier, it is algorithmically much simpler and faster to do what we +// do here, and track them both at once in the same class. 
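+// As a small illustration: if %a = add i32 %x, 1 and %b = add i32 %x, 1 end up
+// in the same class with leader %a, then a use of %b may be rewritten to %a
+// only at program points dominated by %a.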
+// Note: The default iterators for this class iterate over values +class CongruenceClass { +public: + using MemberType = Value; + using MemberSet = SmallPtrSet<MemberType *, 4>; + using MemoryMemberType = MemoryPhi; + using MemoryMemberSet = SmallPtrSet<const MemoryMemberType *, 2>; + + explicit CongruenceClass(unsigned ID) : ID(ID) {} + CongruenceClass(unsigned ID, Value *Leader, const Expression *E) + : ID(ID), RepLeader(Leader), DefiningExpr(E) {} + + unsigned getID() const { return ID; } + + // True if this class has no members left. This is mainly used for assertion + // purposes, and for skipping empty classes. + bool isDead() const { + // If it's both dead from a value perspective, and dead from a memory + // perspective, it's really dead. + return empty() && memory_empty(); + } + + // Leader functions + Value *getLeader() const { return RepLeader; } + void setLeader(Value *Leader) { RepLeader = Leader; } + const std::pair<Value *, unsigned int> &getNextLeader() const { + return NextLeader; + } + void resetNextLeader() { NextLeader = {nullptr, ~0}; } + void addPossibleNextLeader(std::pair<Value *, unsigned int> LeaderPair) { + if (LeaderPair.second < NextLeader.second) + NextLeader = LeaderPair; + } + + Value *getStoredValue() const { return RepStoredValue; } + void setStoredValue(Value *Leader) { RepStoredValue = Leader; } + const MemoryAccess *getMemoryLeader() const { return RepMemoryAccess; } + void setMemoryLeader(const MemoryAccess *Leader) { RepMemoryAccess = Leader; } + + // Forward propagation info + const Expression *getDefiningExpr() const { return DefiningExpr; } + + // Value member set + bool empty() const { return Members.empty(); } + unsigned size() const { return Members.size(); } + MemberSet::const_iterator begin() const { return Members.begin(); } + MemberSet::const_iterator end() const { return Members.end(); } + void insert(MemberType *M) { Members.insert(M); } + void erase(MemberType *M) { Members.erase(M); } + void swap(MemberSet &Other) { Members.swap(Other); } + + // Memory member set + bool memory_empty() const { return MemoryMembers.empty(); } + unsigned memory_size() const { return MemoryMembers.size(); } + MemoryMemberSet::const_iterator memory_begin() const { + return MemoryMembers.begin(); + } + MemoryMemberSet::const_iterator memory_end() const { + return MemoryMembers.end(); + } + iterator_range<MemoryMemberSet::const_iterator> memory() const { + return make_range(memory_begin(), memory_end()); + } + + void memory_insert(const MemoryMemberType *M) { MemoryMembers.insert(M); } + void memory_erase(const MemoryMemberType *M) { MemoryMembers.erase(M); } + + // Store count + unsigned getStoreCount() const { return StoreCount; } + void incStoreCount() { ++StoreCount; } + void decStoreCount() { + assert(StoreCount != 0 && "Store count went negative"); + --StoreCount; + } + + // True if this class has no memory members. + bool definesNoMemory() const { return StoreCount == 0 && memory_empty(); } + + // Return true if two congruence classes are equivalent to each other. This + // means + // that every field but the ID number and the dead field are equivalent. 
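+  // (Used by the iteration verifier to check that a second run of value
+  // numbering settles into the same classes.)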
+  bool isEquivalentTo(const CongruenceClass *Other) const {
+    if (!Other)
+      return false;
+    if (this == Other)
+      return true;
+
+    if (std::tie(StoreCount, RepLeader, RepStoredValue, RepMemoryAccess) !=
+        std::tie(Other->StoreCount, Other->RepLeader, Other->RepStoredValue,
+                 Other->RepMemoryAccess))
+      return false;
+    if (DefiningExpr != Other->DefiningExpr)
+      if (!DefiningExpr || !Other->DefiningExpr ||
+          *DefiningExpr != *Other->DefiningExpr)
+        return false;
+    // We need some ordered set
+    std::set<Value *> AMembers(Members.begin(), Members.end());
+    std::set<Value *> BMembers(Other->Members.begin(), Other->Members.end());
+    return AMembers == BMembers;
+  }
+
+private:
+  unsigned ID;
+
+  // Representative leader.
+  Value *RepLeader = nullptr;
+
+  // The most dominating leader after our current leader, because the member set
+  // is not sorted and is expensive to keep sorted all the time.
+  std::pair<Value *, unsigned int> NextLeader = {nullptr, ~0U};
+
+  // If this is represented by a store, the value of the store.
+  Value *RepStoredValue = nullptr;
+
+  // If this class contains MemoryDefs or MemoryPhis, this is the leading memory
+  // access.
+  const MemoryAccess *RepMemoryAccess = nullptr;
+
+  // Defining Expression.
+  const Expression *DefiningExpr = nullptr;
+
+  // Actual members of this class.
+  MemberSet Members;
+
+  // This is the set of MemoryPhis that exist in the class. MemoryDefs and
+  // MemoryUses have real instructions representing them, so we only need to
+  // track MemoryPhis here.
+  MemoryMemberSet MemoryMembers;
+
+  // Number of stores in this congruence class.
+  // This is used so we can detect store equivalence changes properly.
+  int StoreCount = 0;
+};
+
+} // end anonymous namespace
+
+namespace llvm {
+
+struct ExactEqualsExpression {
+  const Expression &E;
+
+  explicit ExactEqualsExpression(const Expression &E) : E(E) {}
+
+  hash_code getComputedHash() const { return E.getComputedHash(); }
+
+  bool operator==(const Expression &Other) const {
+    return E.exactlyEquals(Other);
+  }
+};
+
+template <> struct DenseMapInfo<const Expression *> {
+  static const Expression *getEmptyKey() {
+    auto Val = static_cast<uintptr_t>(-1);
+    Val <<= PointerLikeTypeTraits<const Expression *>::NumLowBitsAvailable;
+    return reinterpret_cast<const Expression *>(Val);
+  }
+
+  static const Expression *getTombstoneKey() {
+    auto Val = static_cast<uintptr_t>(~1U);
+    Val <<= PointerLikeTypeTraits<const Expression *>::NumLowBitsAvailable;
+    return reinterpret_cast<const Expression *>(Val);
+  }
+
+  static unsigned getHashValue(const Expression *E) {
+    return E->getComputedHash();
+  }
+
+  static unsigned getHashValue(const ExactEqualsExpression &E) {
+    return E.getComputedHash();
+  }
+
+  static bool isEqual(const ExactEqualsExpression &LHS, const Expression *RHS) {
+    if (RHS == getTombstoneKey() || RHS == getEmptyKey())
+      return false;
+    return LHS == *RHS;
+  }
+
+  static bool isEqual(const Expression *LHS, const Expression *RHS) {
+    if (LHS == RHS)
+      return true;
+    if (LHS == getTombstoneKey() || RHS == getTombstoneKey() ||
+        LHS == getEmptyKey() || RHS == getEmptyKey())
+      return false;
+    // Compare hashes before equality. This is *not* what the hashtable does,
+    // since it is computing it modulo the number of buckets, whereas we are
+    // using the full hash keyspace. Since the hashes are precomputed, this
+    // check is *much* faster than equality.
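+    // Equal hashes are necessary but not sufficient, so we only fall through
+    // to the deep comparison when they match.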
+    if (LHS->getComputedHash() != RHS->getComputedHash())
+      return false;
+    return *LHS == *RHS;
+  }
+};
+
+} // end namespace llvm
+
+namespace {
+
+class NewGVN {
+  Function &F;
+  DominatorTree *DT;
+  const TargetLibraryInfo *TLI;
+  AliasAnalysis *AA;
+  MemorySSA *MSSA;
+  MemorySSAWalker *MSSAWalker;
+  const DataLayout &DL;
+  std::unique_ptr<PredicateInfo> PredInfo;
+
+  // These are the only two things the create* functions should have
+  // side-effects on due to allocating memory.
+  mutable BumpPtrAllocator ExpressionAllocator;
+  mutable ArrayRecycler<Value *> ArgRecycler;
+  mutable TarjanSCC SCCFinder;
+  const SimplifyQuery SQ;
+
+  // Number of function arguments, used by ranking
+  unsigned int NumFuncArgs;
+
+  // RPOOrdering of basic blocks
+  DenseMap<const DomTreeNode *, unsigned> RPOOrdering;
+
+  // Congruence class info.
+
+  // This class is called INITIAL in the paper. It is the class everything
+  // starts out in, and represents any value. Being an optimistic analysis,
+  // anything in the TOP class has the value TOP, which is indeterminate and
+  // equivalent to everything.
+  CongruenceClass *TOPClass;
+  std::vector<CongruenceClass *> CongruenceClasses;
+  unsigned NextCongruenceNum;
+
+  // Value Mappings.
+  DenseMap<Value *, CongruenceClass *> ValueToClass;
+  DenseMap<Value *, const Expression *> ValueToExpression;
+
+  // Value PHI handling, used to make equivalence between phi(op, op) and
+  // op(phi, phi).
+  // These mappings just store various data that would normally be part of the
+  // IR.
+  SmallPtrSet<const Instruction *, 8> PHINodeUses;
+
+  DenseMap<const Value *, bool> OpSafeForPHIOfOps;
+
+  // Map a temporary instruction we created to a parent block.
+  DenseMap<const Value *, BasicBlock *> TempToBlock;
+
+  // Map between the already in-program instructions and the temporary phis we
+  // created that they are known equivalent to.
+  DenseMap<const Value *, PHINode *> RealToTemp;
+
+  // In order to know when we should re-process instructions that have
+  // phi-of-ops, we track the set of expressions that they needed as
+  // leaders. When we discover new leaders for those expressions, we process the
+  // associated phi-of-op instructions again in case they have changed. The
+  // other way they may change is if they had leaders, and those leaders
+  // disappear. However, at the point they have leaders, there are uses of the
+  // relevant operands in the created phi node, and so they will get reprocessed
+  // through the normal user marking we perform.
+  mutable DenseMap<const Value *, SmallPtrSet<Value *, 2>> AdditionalUsers;
+  DenseMap<const Expression *, SmallPtrSet<Instruction *, 2>>
+      ExpressionToPhiOfOps;
+
+  // Map from temporary operation to MemoryAccess.
+  DenseMap<const Instruction *, MemoryUseOrDef *> TempToMemory;
+
+  // Set of all temporary instructions we created.
+  // Note: This will include instructions that were just created during value
+  // numbering. The way to test if something is using them is to check
+  // RealToTemp.
+  DenseSet<Instruction *> AllTempInstructions;
+
+  // This is the set of instructions to revisit on a reachability change. At
+  // the end of the main iteration loop it will contain at least all the phi of
+  // ops instructions that will be changed to phis, as well as regular phis.
+  // During the iteration loop, it may contain other things, such as phi of ops
+  // instructions that used edge reachability to reach a result, and so need to
+  // be revisited when the edge changes, independent of whether the phi they
+  // depended on changes.
+  DenseMap<BasicBlock *, SparseBitVector<>> RevisitOnReachabilityChange;
+
+  // Mapping from predicate info we used to the instructions we used it with.
+  // In order to correctly ensure propagation, we must keep track of what
+  // comparisons we used, so that when the values of the comparisons change, we
+  // propagate the information to the places we used the comparison.
+  mutable DenseMap<const Value *, SmallPtrSet<Instruction *, 2>>
+      PredicateToUsers;
+
+  // Mapping from a MemoryAccess to its users, kept for the same reasoning as
+  // PredicateToUsers. When we skip MemoryAccesses for stores, we no longer can
+  // rely solely on the def-use chains of MemorySSA.
+  mutable DenseMap<const MemoryAccess *, SmallPtrSet<MemoryAccess *, 2>>
+      MemoryToUsers;
+
+  // A table storing which memorydefs/phis represent a memory state provably
+  // equivalent to another memory state.
+  // We could use the congruence class machinery, but the MemoryAccess's are
+  // abstract memory states, so they can only ever be equivalent to each other,
+  // and not to constants, etc.
+  DenseMap<const MemoryAccess *, CongruenceClass *> MemoryAccessToClass;
+
+  // We could, if we wanted, build MemoryPhiExpressions and
+  // MemoryVariableExpressions, etc, and value number them the same way we value
+  // number phi expressions. For the moment, this seems like overkill. They
+  // can only exist in one of three states: they can be TOP (equal to
+  // everything), Equivalent to something else, or unique. Because we do not
+  // create expressions for them, we need to simulate leader change not just
+  // when they change class, but when they change state. Note: We can do the
+  // same thing for phis, and avoid having phi expressions if we wanted. We
+  // should eventually unify in one direction or the other, so this is a little
+  // bit of an experiment to see which turns out easier to maintain.
+  enum MemoryPhiState { MPS_Invalid, MPS_TOP, MPS_Equivalent, MPS_Unique };
+  DenseMap<const MemoryPhi *, MemoryPhiState> MemoryPhiState;
+
+  enum InstCycleState { ICS_Unknown, ICS_CycleFree, ICS_Cycle };
+  mutable DenseMap<const Instruction *, InstCycleState> InstCycleState;
+
+  // Expression to class mapping.
+  using ExpressionClassMap = DenseMap<const Expression *, CongruenceClass *>;
+  ExpressionClassMap ExpressionToClass;
+
+  // We have a single expression that represents currently DeadExpressions.
+  // For dead expressions we can prove will stay dead, we mark them with
+  // DFS number zero. However, it's possible in the case of phi nodes
+  // for us to assume/prove all arguments are dead during fixpointing.
+  // We use DeadExpression for that case.
+  DeadExpression *SingletonDeadExpression = nullptr;
+
+  // Which values have changed as a result of leader changes.
+  SmallPtrSet<Value *, 8> LeaderChanges;
+
+  // Reachability info.
+  using BlockEdge = BasicBlockEdge;
+  DenseSet<BlockEdge> ReachableEdges;
+  SmallPtrSet<const BasicBlock *, 8> ReachableBlocks;
+
+  // This is a bitvector because, on larger functions, we may have
+  // thousands of touched instructions at once (entire blocks,
+  // instructions with hundreds of uses, etc). Even with optimization
+  // for when we mark whole blocks as touched, when this was a
+  // SmallPtrSet or DenseSet, for some functions, we spent >20% of all
+  // the time in GVN just managing this list. The bitvector, on the
+  // other hand, efficiently supports test/set/clear of both
+  // individual bits and ranges, as well as "find next element". This
+  // enables us to use it as a worklist with essentially 0 cost.
+ BitVector TouchedInstructions; + + DenseMap<const BasicBlock *, std::pair<unsigned, unsigned>> BlockInstRange; + +#ifndef NDEBUG + // Debugging for how many times each block and instruction got processed. + DenseMap<const Value *, unsigned> ProcessedCount; +#endif + + // DFS info. + // This contains a mapping from Instructions to DFS numbers. + // The numbering starts at 1. An instruction with DFS number zero + // means that the instruction is dead. + DenseMap<const Value *, unsigned> InstrDFS; + + // This contains the mapping DFS numbers to instructions. + SmallVector<Value *, 32> DFSToInstr; + + // Deletion info. + SmallPtrSet<Instruction *, 8> InstructionsToErase; + +public: + NewGVN(Function &F, DominatorTree *DT, AssumptionCache *AC, + TargetLibraryInfo *TLI, AliasAnalysis *AA, MemorySSA *MSSA, + const DataLayout &DL) + : F(F), DT(DT), TLI(TLI), AA(AA), MSSA(MSSA), DL(DL), + PredInfo(make_unique<PredicateInfo>(F, *DT, *AC)), SQ(DL, TLI, DT, AC) { + } + + bool runGVN(); + +private: + // Expression handling. + const Expression *createExpression(Instruction *) const; + const Expression *createBinaryExpression(unsigned, Type *, Value *, Value *, + Instruction *) const; + + // Our canonical form for phi arguments is a pair of incoming value, incoming + // basic block. + using ValPair = std::pair<Value *, BasicBlock *>; + + PHIExpression *createPHIExpression(ArrayRef<ValPair>, const Instruction *, + BasicBlock *, bool &HasBackEdge, + bool &OriginalOpsConstant) const; + const DeadExpression *createDeadExpression() const; + const VariableExpression *createVariableExpression(Value *) const; + const ConstantExpression *createConstantExpression(Constant *) const; + const Expression *createVariableOrConstant(Value *V) const; + const UnknownExpression *createUnknownExpression(Instruction *) const; + const StoreExpression *createStoreExpression(StoreInst *, + const MemoryAccess *) const; + LoadExpression *createLoadExpression(Type *, Value *, LoadInst *, + const MemoryAccess *) const; + const CallExpression *createCallExpression(CallInst *, + const MemoryAccess *) const; + const AggregateValueExpression * + createAggregateValueExpression(Instruction *) const; + bool setBasicExpressionInfo(Instruction *, BasicExpression *) const; + + // Congruence class handling. 
+ CongruenceClass *createCongruenceClass(Value *Leader, const Expression *E) { + auto *result = new CongruenceClass(NextCongruenceNum++, Leader, E); + CongruenceClasses.emplace_back(result); + return result; + } + + CongruenceClass *createMemoryClass(MemoryAccess *MA) { + auto *CC = createCongruenceClass(nullptr, nullptr); + CC->setMemoryLeader(MA); + return CC; + } + + CongruenceClass *ensureLeaderOfMemoryClass(MemoryAccess *MA) { + auto *CC = getMemoryClass(MA); + if (CC->getMemoryLeader() != MA) + CC = createMemoryClass(MA); + return CC; + } + + CongruenceClass *createSingletonCongruenceClass(Value *Member) { + CongruenceClass *CClass = createCongruenceClass(Member, nullptr); + CClass->insert(Member); + ValueToClass[Member] = CClass; + return CClass; + } + + void initializeCongruenceClasses(Function &F); + const Expression *makePossiblePHIOfOps(Instruction *, + SmallPtrSetImpl<Value *> &); + Value *findLeaderForInst(Instruction *ValueOp, + SmallPtrSetImpl<Value *> &Visited, + MemoryAccess *MemAccess, Instruction *OrigInst, + BasicBlock *PredBB); + bool OpIsSafeForPHIOfOpsHelper(Value *V, const BasicBlock *PHIBlock, + SmallPtrSetImpl<const Value *> &Visited, + SmallVectorImpl<Instruction *> &Worklist); + bool OpIsSafeForPHIOfOps(Value *Op, const BasicBlock *PHIBlock, + SmallPtrSetImpl<const Value *> &); + void addPhiOfOps(PHINode *Op, BasicBlock *BB, Instruction *ExistingValue); + void removePhiOfOps(Instruction *I, PHINode *PHITemp); + + // Value number an Instruction or MemoryPhi. + void valueNumberMemoryPhi(MemoryPhi *); + void valueNumberInstruction(Instruction *); + + // Symbolic evaluation. + const Expression *checkSimplificationResults(Expression *, Instruction *, + Value *) const; + const Expression *performSymbolicEvaluation(Value *, + SmallPtrSetImpl<Value *> &) const; + const Expression *performSymbolicLoadCoercion(Type *, Value *, LoadInst *, + Instruction *, + MemoryAccess *) const; + const Expression *performSymbolicLoadEvaluation(Instruction *) const; + const Expression *performSymbolicStoreEvaluation(Instruction *) const; + const Expression *performSymbolicCallEvaluation(Instruction *) const; + void sortPHIOps(MutableArrayRef<ValPair> Ops) const; + const Expression *performSymbolicPHIEvaluation(ArrayRef<ValPair>, + Instruction *I, + BasicBlock *PHIBlock) const; + const Expression *performSymbolicAggrValueEvaluation(Instruction *) const; + const Expression *performSymbolicCmpEvaluation(Instruction *) const; + const Expression *performSymbolicPredicateInfoEvaluation(Instruction *) const; + + // Congruence finding. 
+ bool someEquivalentDominates(const Instruction *, const Instruction *) const; + Value *lookupOperandLeader(Value *) const; + CongruenceClass *getClassForExpression(const Expression *E) const; + void performCongruenceFinding(Instruction *, const Expression *); + void moveValueToNewCongruenceClass(Instruction *, const Expression *, + CongruenceClass *, CongruenceClass *); + void moveMemoryToNewCongruenceClass(Instruction *, MemoryAccess *, + CongruenceClass *, CongruenceClass *); + Value *getNextValueLeader(CongruenceClass *) const; + const MemoryAccess *getNextMemoryLeader(CongruenceClass *) const; + bool setMemoryClass(const MemoryAccess *From, CongruenceClass *To); + CongruenceClass *getMemoryClass(const MemoryAccess *MA) const; + const MemoryAccess *lookupMemoryLeader(const MemoryAccess *) const; + bool isMemoryAccessTOP(const MemoryAccess *) const; + + // Ranking + unsigned int getRank(const Value *) const; + bool shouldSwapOperands(const Value *, const Value *) const; + + // Reachability handling. + void updateReachableEdge(BasicBlock *, BasicBlock *); + void processOutgoingEdges(TerminatorInst *, BasicBlock *); + Value *findConditionEquivalence(Value *) const; + + // Elimination. + struct ValueDFS; + void convertClassToDFSOrdered(const CongruenceClass &, + SmallVectorImpl<ValueDFS> &, + DenseMap<const Value *, unsigned int> &, + SmallPtrSetImpl<Instruction *> &) const; + void convertClassToLoadsAndStores(const CongruenceClass &, + SmallVectorImpl<ValueDFS> &) const; + + bool eliminateInstructions(Function &); + void replaceInstruction(Instruction *, Value *); + void markInstructionForDeletion(Instruction *); + void deleteInstructionsInBlock(BasicBlock *); + Value *findPHIOfOpsLeader(const Expression *, const Instruction *, + const BasicBlock *) const; + + // New instruction creation. + void handleNewInstruction(Instruction *) {} + + // Various instruction touch utilities + template <typename Map, typename KeyType, typename Func> + void for_each_found(Map &, const KeyType &, Func); + template <typename Map, typename KeyType> + void touchAndErase(Map &, const KeyType &); + void markUsersTouched(Value *); + void markMemoryUsersTouched(const MemoryAccess *); + void markMemoryDefTouched(const MemoryAccess *); + void markPredicateUsersTouched(Instruction *); + void markValueLeaderChangeTouched(CongruenceClass *CC); + void markMemoryLeaderChangeTouched(CongruenceClass *CC); + void markPhiOfOpsChanged(const Expression *E); + void addPredicateUsers(const PredicateBase *, Instruction *) const; + void addMemoryUsers(const MemoryAccess *To, MemoryAccess *U) const; + void addAdditionalUsers(Value *To, Value *User) const; + + // Main loop of value numbering + void iterateTouchedInstructions(); + + // Utilities. 
+ void cleanupTables(); + std::pair<unsigned, unsigned> assignDFSNumbers(BasicBlock *, unsigned); + void updateProcessedCount(const Value *V); + void verifyMemoryCongruency() const; + void verifyIterationSettled(Function &F); + void verifyStoreExpressions() const; + bool singleReachablePHIPath(SmallPtrSet<const MemoryAccess *, 8> &, + const MemoryAccess *, const MemoryAccess *) const; + BasicBlock *getBlockForValue(Value *V) const; + void deleteExpression(const Expression *E) const; + MemoryUseOrDef *getMemoryAccess(const Instruction *) const; + MemoryAccess *getDefiningAccess(const MemoryAccess *) const; + MemoryPhi *getMemoryAccess(const BasicBlock *) const; + template <class T, class Range> T *getMinDFSOfRange(const Range &) const; + + unsigned InstrToDFSNum(const Value *V) const { + assert(isa<Instruction>(V) && "This should not be used for MemoryAccesses"); + return InstrDFS.lookup(V); + } + + unsigned InstrToDFSNum(const MemoryAccess *MA) const { + return MemoryToDFSNum(MA); + } + + Value *InstrFromDFSNum(unsigned DFSNum) { return DFSToInstr[DFSNum]; } + + // Given a MemoryAccess, return the relevant instruction DFS number. Note: + // This deliberately takes a value so it can be used with Use's, which will + // auto-convert to Value's but not to MemoryAccess's. + unsigned MemoryToDFSNum(const Value *MA) const { + assert(isa<MemoryAccess>(MA) && + "This should not be used with instructions"); + return isa<MemoryUseOrDef>(MA) + ? InstrToDFSNum(cast<MemoryUseOrDef>(MA)->getMemoryInst()) + : InstrDFS.lookup(MA); + } + + bool isCycleFree(const Instruction *) const; + bool isBackedge(BasicBlock *From, BasicBlock *To) const; + + // Debug counter info. When verifying, we have to reset the value numbering + // debug counter to the same state it started in to get the same results. + std::pair<int, int> StartingVNCounter; +}; + +} // end anonymous namespace + +template <typename T> +static bool equalsLoadStoreHelper(const T &LHS, const Expression &RHS) { + if (!isa<LoadExpression>(RHS) && !isa<StoreExpression>(RHS)) + return false; + return LHS.MemoryExpression::equals(RHS); +} + +bool LoadExpression::equals(const Expression &Other) const { + return equalsLoadStoreHelper(*this, Other); +} + +bool StoreExpression::equals(const Expression &Other) const { + if (!equalsLoadStoreHelper(*this, Other)) + return false; + // Make sure that store vs store includes the value operand. + if (const auto *S = dyn_cast<StoreExpression>(&Other)) + if (getStoredValue() != S->getStoredValue()) + return false; + return true; +} + +// Determine if the edge From->To is a backedge +bool NewGVN::isBackedge(BasicBlock *From, BasicBlock *To) const { + return From == To || + RPOOrdering.lookup(DT->getNode(From)) >= + RPOOrdering.lookup(DT->getNode(To)); +} + +#ifndef NDEBUG +static std::string getBlockName(const BasicBlock *B) { + return DOTGraphTraits<const Function *>::getSimpleNodeLabel(B, nullptr); +} +#endif + +// Get a MemoryAccess for an instruction, fake or real. +MemoryUseOrDef *NewGVN::getMemoryAccess(const Instruction *I) const { + auto *Result = MSSA->getMemoryAccess(I); + return Result ? Result : TempToMemory.lookup(I); +} + +// Get a MemoryPhi for a basic block. These are all real. +MemoryPhi *NewGVN::getMemoryAccess(const BasicBlock *BB) const { + return MSSA->getMemoryAccess(BB); +} + +// Get the basic block from an instruction/memory value. 
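+// Temporary instructions created for phi-of-ops have no parent block of their
+// own; for those, the block is recorded in TempToBlock instead.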
+BasicBlock *NewGVN::getBlockForValue(Value *V) const { + if (auto *I = dyn_cast<Instruction>(V)) { + auto *Parent = I->getParent(); + if (Parent) + return Parent; + Parent = TempToBlock.lookup(V); + assert(Parent && "Every fake instruction should have a block"); + return Parent; + } + + auto *MP = dyn_cast<MemoryPhi>(V); + assert(MP && "Should have been an instruction or a MemoryPhi"); + return MP->getBlock(); +} + +// Delete a definitely dead expression, so it can be reused by the expression +// allocator. Some of these are not in creation functions, so we have to accept +// const versions. +void NewGVN::deleteExpression(const Expression *E) const { + assert(isa<BasicExpression>(E)); + auto *BE = cast<BasicExpression>(E); + const_cast<BasicExpression *>(BE)->deallocateOperands(ArgRecycler); + ExpressionAllocator.Deallocate(E); +} + +// If V is a predicateinfo copy, get the thing it is a copy of. +static Value *getCopyOf(const Value *V) { + if (auto *II = dyn_cast<IntrinsicInst>(V)) + if (II->getIntrinsicID() == Intrinsic::ssa_copy) + return II->getOperand(0); + return nullptr; +} + +// Return true if V is really PN, even accounting for predicateinfo copies. +static bool isCopyOfPHI(const Value *V, const PHINode *PN) { + return V == PN || getCopyOf(V) == PN; +} + +static bool isCopyOfAPHI(const Value *V) { + auto *CO = getCopyOf(V); + return CO && isa<PHINode>(CO); +} + +// Sort PHI Operands into a canonical order. What we use here is an RPO +// order. The BlockInstRange numbers are generated in an RPO walk of the basic +// blocks. +void NewGVN::sortPHIOps(MutableArrayRef<ValPair> Ops) const { + std::sort(Ops.begin(), Ops.end(), [&](const ValPair &P1, const ValPair &P2) { + return BlockInstRange.lookup(P1.second).first < + BlockInstRange.lookup(P2.second).first; + }); +} + +// Return true if V is a value that will always be available (IE can +// be placed anywhere) in the function. We don't do globals here +// because they are often worse to put in place. +static bool alwaysAvailable(Value *V) { + return isa<Constant>(V) || isa<Argument>(V); +} + +// Create a PHIExpression from an array of {incoming edge, value} pairs. I is +// the original instruction we are creating a PHIExpression for (but may not be +// a phi node). We require, as an invariant, that all the PHIOperands in the +// same block are sorted the same way. sortPHIOps will sort them into a +// canonical order. +PHIExpression *NewGVN::createPHIExpression(ArrayRef<ValPair> PHIOperands, + const Instruction *I, + BasicBlock *PHIBlock, + bool &HasBackedge, + bool &OriginalOpsConstant) const { + unsigned NumOps = PHIOperands.size(); + auto *E = new (ExpressionAllocator) PHIExpression(NumOps, PHIBlock); + + E->allocateOperands(ArgRecycler, ExpressionAllocator); + E->setType(PHIOperands.begin()->first->getType()); + E->setOpcode(Instruction::PHI); + + // Filter out unreachable phi operands. + auto Filtered = make_filter_range(PHIOperands, [&](const ValPair &P) { + auto *BB = P.second; + if (auto *PHIOp = dyn_cast<PHINode>(I)) + if (isCopyOfPHI(P.first, PHIOp)) + return false; + if (!ReachableEdges.count({BB, PHIBlock})) + return false; + // Things in TOPClass are equivalent to everything. 
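+    // They are represented by undef in lookupOperandLeader, so, like undefs,
+    // they cannot constrain the phi and are filtered out here.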
+ if (ValueToClass.lookup(P.first) == TOPClass) + return false; + OriginalOpsConstant = OriginalOpsConstant && isa<Constant>(P.first); + HasBackedge = HasBackedge || isBackedge(BB, PHIBlock); + return lookupOperandLeader(P.first) != I; + }); + std::transform(Filtered.begin(), Filtered.end(), op_inserter(E), + [&](const ValPair &P) -> Value * { + return lookupOperandLeader(P.first); + }); + return E; +} + +// Set basic expression info (Arguments, type, opcode) for Expression +// E from Instruction I in block B. +bool NewGVN::setBasicExpressionInfo(Instruction *I, BasicExpression *E) const { + bool AllConstant = true; + if (auto *GEP = dyn_cast<GetElementPtrInst>(I)) + E->setType(GEP->getSourceElementType()); + else + E->setType(I->getType()); + E->setOpcode(I->getOpcode()); + E->allocateOperands(ArgRecycler, ExpressionAllocator); + + // Transform the operand array into an operand leader array, and keep track of + // whether all members are constant. + std::transform(I->op_begin(), I->op_end(), op_inserter(E), [&](Value *O) { + auto Operand = lookupOperandLeader(O); + AllConstant = AllConstant && isa<Constant>(Operand); + return Operand; + }); + + return AllConstant; +} + +const Expression *NewGVN::createBinaryExpression(unsigned Opcode, Type *T, + Value *Arg1, Value *Arg2, + Instruction *I) const { + auto *E = new (ExpressionAllocator) BasicExpression(2); + + E->setType(T); + E->setOpcode(Opcode); + E->allocateOperands(ArgRecycler, ExpressionAllocator); + if (Instruction::isCommutative(Opcode)) { + // Ensure that commutative instructions that only differ by a permutation + // of their operands get the same value number by sorting the operand value + // numbers. Since all commutative instructions have two operands it is more + // efficient to sort by hand rather than using, say, std::sort. + if (shouldSwapOperands(Arg1, Arg2)) + std::swap(Arg1, Arg2); + } + E->op_push_back(lookupOperandLeader(Arg1)); + E->op_push_back(lookupOperandLeader(Arg2)); + + Value *V = SimplifyBinOp(Opcode, E->getOperand(0), E->getOperand(1), SQ); + if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) + return SimplifiedE; + return E; +} + +// Take a Value returned by simplification of Expression E/Instruction +// I, and see if it resulted in a simpler expression. If so, return +// that expression. +const Expression *NewGVN::checkSimplificationResults(Expression *E, + Instruction *I, + Value *V) const { + if (!V) + return nullptr; + if (auto *C = dyn_cast<Constant>(V)) { + if (I) + DEBUG(dbgs() << "Simplified " << *I << " to " + << " constant " << *C << "\n"); + NumGVNOpsSimplified++; + assert(isa<BasicExpression>(E) && + "We should always have had a basic expression here"); + deleteExpression(E); + return createConstantExpression(C); + } else if (isa<Argument>(V) || isa<GlobalVariable>(V)) { + if (I) + DEBUG(dbgs() << "Simplified " << *I << " to " + << " variable " << *V << "\n"); + deleteExpression(E); + return createVariableExpression(V); + } + + CongruenceClass *CC = ValueToClass.lookup(V); + if (CC) { + if (CC->getLeader() && CC->getLeader() != I) { + // Don't add temporary instructions to the user lists. + if (!AllTempInstructions.count(I)) + addAdditionalUsers(V, I); + return createVariableOrConstant(CC->getLeader()); + } + if (CC->getDefiningExpr()) { + // If we simplified to something else, we need to communicate + // that we're users of the value we simplified to. + if (I != V) { + // Don't add temporary instructions to the user lists. 
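+        // (they are torn down wholesale in cleanupTables, so recording them
+        // as users would leave dangling entries behind).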
+ if (!AllTempInstructions.count(I)) + addAdditionalUsers(V, I); + } + + if (I) + DEBUG(dbgs() << "Simplified " << *I << " to " + << " expression " << *CC->getDefiningExpr() << "\n"); + NumGVNOpsSimplified++; + deleteExpression(E); + return CC->getDefiningExpr(); + } + } + + return nullptr; +} + +// Create a value expression from the instruction I, replacing operands with +// their leaders. + +const Expression *NewGVN::createExpression(Instruction *I) const { + auto *E = new (ExpressionAllocator) BasicExpression(I->getNumOperands()); + + bool AllConstant = setBasicExpressionInfo(I, E); + + if (I->isCommutative()) { + // Ensure that commutative instructions that only differ by a permutation + // of their operands get the same value number by sorting the operand value + // numbers. Since all commutative instructions have two operands it is more + // efficient to sort by hand rather than using, say, std::sort. + assert(I->getNumOperands() == 2 && "Unsupported commutative instruction!"); + if (shouldSwapOperands(E->getOperand(0), E->getOperand(1))) + E->swapOperands(0, 1); + } + // Perform simplification. + if (auto *CI = dyn_cast<CmpInst>(I)) { + // Sort the operand value numbers so x<y and y>x get the same value + // number. + CmpInst::Predicate Predicate = CI->getPredicate(); + if (shouldSwapOperands(E->getOperand(0), E->getOperand(1))) { + E->swapOperands(0, 1); + Predicate = CmpInst::getSwappedPredicate(Predicate); + } + E->setOpcode((CI->getOpcode() << 8) | Predicate); + // TODO: 25% of our time is spent in SimplifyCmpInst with pointer operands + assert(I->getOperand(0)->getType() == I->getOperand(1)->getType() && + "Wrong types on cmp instruction"); + assert((E->getOperand(0)->getType() == I->getOperand(0)->getType() && + E->getOperand(1)->getType() == I->getOperand(1)->getType())); + Value *V = + SimplifyCmpInst(Predicate, E->getOperand(0), E->getOperand(1), SQ); + if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) + return SimplifiedE; + } else if (isa<SelectInst>(I)) { + if (isa<Constant>(E->getOperand(0)) || + E->getOperand(1) == E->getOperand(2)) { + assert(E->getOperand(1)->getType() == I->getOperand(1)->getType() && + E->getOperand(2)->getType() == I->getOperand(2)->getType()); + Value *V = SimplifySelectInst(E->getOperand(0), E->getOperand(1), + E->getOperand(2), SQ); + if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) + return SimplifiedE; + } + } else if (I->isBinaryOp()) { + Value *V = + SimplifyBinOp(E->getOpcode(), E->getOperand(0), E->getOperand(1), SQ); + if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) + return SimplifiedE; + } else if (auto *BI = dyn_cast<BitCastInst>(I)) { + Value *V = + SimplifyCastInst(BI->getOpcode(), BI->getOperand(0), BI->getType(), SQ); + if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) + return SimplifiedE; + } else if (isa<GetElementPtrInst>(I)) { + Value *V = SimplifyGEPInst( + E->getType(), ArrayRef<Value *>(E->op_begin(), E->op_end()), SQ); + if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V)) + return SimplifiedE; + } else if (AllConstant) { + // We don't bother trying to simplify unless all of the operands + // were constant. + // TODO: There are a lot of Simplify*'s we could call here, if we + // wanted to. The original motivating case for this code was a + // zext i1 false to i8, which we don't have an interface to + // simplify (IE there is no SimplifyZExt). 
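+    // Fall back to generic constant folding over the operand leaders; the
+    // motivating zext example above, for instance, folds to i8 0.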
+
+    SmallVector<Constant *, 8> C;
+    for (Value *Arg : E->operands())
+      C.emplace_back(cast<Constant>(Arg));
+
+    if (Value *V = ConstantFoldInstOperands(I, C, DL, TLI))
+      if (const Expression *SimplifiedE = checkSimplificationResults(E, I, V))
+        return SimplifiedE;
+  }
+  return E;
+}
+
+const AggregateValueExpression *
+NewGVN::createAggregateValueExpression(Instruction *I) const {
+  if (auto *II = dyn_cast<InsertValueInst>(I)) {
+    auto *E = new (ExpressionAllocator)
+        AggregateValueExpression(I->getNumOperands(), II->getNumIndices());
+    setBasicExpressionInfo(I, E);
+    E->allocateIntOperands(ExpressionAllocator);
+    std::copy(II->idx_begin(), II->idx_end(), int_op_inserter(E));
+    return E;
+  } else if (auto *EI = dyn_cast<ExtractValueInst>(I)) {
+    auto *E = new (ExpressionAllocator)
+        AggregateValueExpression(I->getNumOperands(), EI->getNumIndices());
+    setBasicExpressionInfo(EI, E);
+    E->allocateIntOperands(ExpressionAllocator);
+    std::copy(EI->idx_begin(), EI->idx_end(), int_op_inserter(E));
+    return E;
+  }
+  llvm_unreachable("Unhandled type of aggregate value operation");
+}
+
+const DeadExpression *NewGVN::createDeadExpression() const {
+  // DeadExpression has no arguments and all DeadExpression's are the same,
+  // so we only need one of them.
+  return SingletonDeadExpression;
+}
+
+const VariableExpression *NewGVN::createVariableExpression(Value *V) const {
+  auto *E = new (ExpressionAllocator) VariableExpression(V);
+  E->setOpcode(V->getValueID());
+  return E;
+}
+
+const Expression *NewGVN::createVariableOrConstant(Value *V) const {
+  if (auto *C = dyn_cast<Constant>(V))
+    return createConstantExpression(C);
+  return createVariableExpression(V);
+}
+
+const ConstantExpression *NewGVN::createConstantExpression(Constant *C) const {
+  auto *E = new (ExpressionAllocator) ConstantExpression(C);
+  E->setOpcode(C->getValueID());
+  return E;
+}
+
+const UnknownExpression *NewGVN::createUnknownExpression(Instruction *I) const {
+  auto *E = new (ExpressionAllocator) UnknownExpression(I);
+  E->setOpcode(I->getOpcode());
+  return E;
+}
+
+const CallExpression *
+NewGVN::createCallExpression(CallInst *CI, const MemoryAccess *MA) const {
+  // FIXME: Add operand bundles for calls.
+  auto *E =
+      new (ExpressionAllocator) CallExpression(CI->getNumOperands(), CI, MA);
+  setBasicExpressionInfo(CI, E);
+  return E;
+}
+
+// Return true if some equivalent of instruction Inst dominates instruction U.
+bool NewGVN::someEquivalentDominates(const Instruction *Inst,
+                                     const Instruction *U) const {
+  auto *CC = ValueToClass.lookup(Inst);
+  // This must be an instruction because we are only called from phi nodes
+  // in the case that the value it needs to check against is an instruction.
+
+  // The most likely candidates for dominance are the leader and the next
+  // leader. The leader or next leader will dominate in all cases where there
+  // is an equivalent that is higher up in the dom tree.
+  // We can't *only* check them, however, because the
+  // dominator tree could have an infinite number of non-dominating siblings
+  // with instructions that are in the right congruence class.
+  // A
+  // B C D E F G
+  // |
+  // H
+  // Instruction U could be in H, with equivalents in every other sibling.
+  // Depending on the rpo order picked, the leader could be the equivalent in
+  // any of these siblings.
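+  // Hence, after checking the leader and the next leader, we fall back to
+  // scanning the whole member set.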
+ if (!CC) + return false; + if (alwaysAvailable(CC->getLeader())) + return true; + if (DT->dominates(cast<Instruction>(CC->getLeader()), U)) + return true; + if (CC->getNextLeader().first && + DT->dominates(cast<Instruction>(CC->getNextLeader().first), U)) + return true; + return llvm::any_of(*CC, [&](const Value *Member) { + return Member != CC->getLeader() && + DT->dominates(cast<Instruction>(Member), U); + }); +} + +// See if we have a congruence class and leader for this operand, and if so, +// return it. Otherwise, return the operand itself. +Value *NewGVN::lookupOperandLeader(Value *V) const { + CongruenceClass *CC = ValueToClass.lookup(V); + if (CC) { + // Everything in TOP is represented by undef, as it can be any value. + // We do have to make sure we get the type right though, so we can't set the + // RepLeader to undef. + if (CC == TOPClass) + return UndefValue::get(V->getType()); + return CC->getStoredValue() ? CC->getStoredValue() : CC->getLeader(); + } + + return V; +} + +const MemoryAccess *NewGVN::lookupMemoryLeader(const MemoryAccess *MA) const { + auto *CC = getMemoryClass(MA); + assert(CC->getMemoryLeader() && + "Every MemoryAccess should be mapped to a congruence class with a " + "representative memory access"); + return CC->getMemoryLeader(); +} + +// Return true if the MemoryAccess is really equivalent to everything. This is +// equivalent to the lattice value "TOP" in most lattices. This is the initial +// state of all MemoryAccesses. +bool NewGVN::isMemoryAccessTOP(const MemoryAccess *MA) const { + return getMemoryClass(MA) == TOPClass; +} + +LoadExpression *NewGVN::createLoadExpression(Type *LoadType, Value *PointerOp, + LoadInst *LI, + const MemoryAccess *MA) const { + auto *E = + new (ExpressionAllocator) LoadExpression(1, LI, lookupMemoryLeader(MA)); + E->allocateOperands(ArgRecycler, ExpressionAllocator); + E->setType(LoadType); + + // Give store and loads same opcode so they value number together. + E->setOpcode(0); + E->op_push_back(PointerOp); + if (LI) + E->setAlignment(LI->getAlignment()); + + // TODO: Value number heap versions. We may be able to discover + // things alias analysis can't on it's own (IE that a store and a + // load have the same value, and thus, it isn't clobbering the load). + return E; +} + +const StoreExpression * +NewGVN::createStoreExpression(StoreInst *SI, const MemoryAccess *MA) const { + auto *StoredValueLeader = lookupOperandLeader(SI->getValueOperand()); + auto *E = new (ExpressionAllocator) + StoreExpression(SI->getNumOperands(), SI, StoredValueLeader, MA); + E->allocateOperands(ArgRecycler, ExpressionAllocator); + E->setType(SI->getValueOperand()->getType()); + + // Give store and loads same opcode so they value number together. + E->setOpcode(0); + E->op_push_back(lookupOperandLeader(SI->getPointerOperand())); + + // TODO: Value number heap versions. We may be able to discover + // things alias analysis can't on it's own (IE that a store and a + // load have the same value, and thus, it isn't clobbering the load). + return E; +} + +const Expression *NewGVN::performSymbolicStoreEvaluation(Instruction *I) const { + // Unlike loads, we never try to eliminate stores, so we do not check if they + // are simple and avoid value numbering them. + auto *SI = cast<StoreInst>(I); + auto *StoreAccess = getMemoryAccess(SI); + // Get the expression, if any, for the RHS of the MemoryDef. 
+ const MemoryAccess *StoreRHS = StoreAccess->getDefiningAccess(); + if (EnableStoreRefinement) + StoreRHS = MSSAWalker->getClobberingMemoryAccess(StoreAccess); + // If we bypassed the use-def chains, make sure we add a use. + StoreRHS = lookupMemoryLeader(StoreRHS); + if (StoreRHS != StoreAccess->getDefiningAccess()) + addMemoryUsers(StoreRHS, StoreAccess); + // If we are defined by ourselves, use the live on entry def. + if (StoreRHS == StoreAccess) + StoreRHS = MSSA->getLiveOnEntryDef(); + + if (SI->isSimple()) { + // See if we are defined by a previous store expression, it already has a + // value, and it's the same value as our current store. FIXME: Right now, we + // only do this for simple stores, we should expand to cover memcpys, etc. + const auto *LastStore = createStoreExpression(SI, StoreRHS); + const auto *LastCC = ExpressionToClass.lookup(LastStore); + // We really want to check whether the expression we matched was a store. No + // easy way to do that. However, we can check that the class we found has a + // store, which, assuming the value numbering state is not corrupt, is + // sufficient, because we must also be equivalent to that store's expression + // for it to be in the same class as the load. + if (LastCC && LastCC->getStoredValue() == LastStore->getStoredValue()) + return LastStore; + // Also check if our value operand is defined by a load of the same memory + // location, and the memory state is the same as it was then (otherwise, it + // could have been overwritten later. See test32 in + // transforms/DeadStoreElimination/simple.ll). + if (auto *LI = dyn_cast<LoadInst>(LastStore->getStoredValue())) + if ((lookupOperandLeader(LI->getPointerOperand()) == + LastStore->getOperand(0)) && + (lookupMemoryLeader(getMemoryAccess(LI)->getDefiningAccess()) == + StoreRHS)) + return LastStore; + deleteExpression(LastStore); + } + + // If the store is not equivalent to anything, value number it as a store that + // produces a unique memory state (instead of using it's MemoryUse, we use + // it's MemoryDef). + return createStoreExpression(SI, StoreAccess); +} + +// See if we can extract the value of a loaded pointer from a load, a store, or +// a memory instruction. +const Expression * +NewGVN::performSymbolicLoadCoercion(Type *LoadType, Value *LoadPtr, + LoadInst *LI, Instruction *DepInst, + MemoryAccess *DefiningAccess) const { + assert((!LI || LI->isSimple()) && "Not a simple load"); + if (auto *DepSI = dyn_cast<StoreInst>(DepInst)) { + // Can't forward from non-atomic to atomic without violating memory model. + // Also don't need to coerce if they are the same type, we will just + // propagate. + if (LI->isAtomic() > DepSI->isAtomic() || + LoadType == DepSI->getValueOperand()->getType()) + return nullptr; + int Offset = analyzeLoadFromClobberingStore(LoadType, LoadPtr, DepSI, DL); + if (Offset >= 0) { + if (auto *C = dyn_cast<Constant>( + lookupOperandLeader(DepSI->getValueOperand()))) { + DEBUG(dbgs() << "Coercing load from store " << *DepSI << " to constant " + << *C << "\n"); + return createConstantExpression( + getConstantStoreValueForLoad(C, Offset, LoadType, DL)); + } + } + } else if (auto *DepLI = dyn_cast<LoadInst>(DepInst)) { + // Can't forward from non-atomic to atomic without violating memory model. + if (LI->isAtomic() > DepLI->isAtomic()) + return nullptr; + int Offset = analyzeLoadFromClobberingLoad(LoadType, LoadPtr, DepLI, DL); + if (Offset >= 0) { + // We can coerce a constant load into a load. 
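+      // (e.g. if the earlier load's leader is a known constant, the bytes for
+      // this load can be extracted from it at the computed offset).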
+      if (auto *C = dyn_cast<Constant>(lookupOperandLeader(DepLI)))
+        if (auto *PossibleConstant =
+                getConstantLoadValueForLoad(C, Offset, LoadType, DL)) {
+          DEBUG(dbgs() << "Coercing load from load " << *LI << " to constant "
+                       << *PossibleConstant << "\n");
+          return createConstantExpression(PossibleConstant);
+        }
+    }
+  } else if (auto *DepMI = dyn_cast<MemIntrinsic>(DepInst)) {
+    int Offset = analyzeLoadFromClobberingMemInst(LoadType, LoadPtr, DepMI, DL);
+    if (Offset >= 0) {
+      if (auto *PossibleConstant =
+              getConstantMemInstValueForLoad(DepMI, Offset, LoadType, DL)) {
+        DEBUG(dbgs() << "Coercing load from meminst " << *DepMI
+                     << " to constant " << *PossibleConstant << "\n");
+        return createConstantExpression(PossibleConstant);
+      }
+    }
+  }
+
+  // All of the below are only true if the loaded pointer is produced
+  // by the dependent instruction.
+  if (LoadPtr != lookupOperandLeader(DepInst) &&
+      !AA->isMustAlias(LoadPtr, DepInst))
+    return nullptr;
+  // If this load really doesn't depend on anything, then we must be loading an
+  // undef value. This can happen when loading for a fresh allocation with no
+  // intervening stores, for example. Note that this is only true in the case
+  // that the result of the allocation is pointer equal to the load ptr.
+  if (isa<AllocaInst>(DepInst) || isMallocLikeFn(DepInst, TLI)) {
+    return createConstantExpression(UndefValue::get(LoadType));
+  }
+  // If this load occurs right after a lifetime begin,
+  // then the loaded value is undefined.
+  else if (auto *II = dyn_cast<IntrinsicInst>(DepInst)) {
+    if (II->getIntrinsicID() == Intrinsic::lifetime_start)
+      return createConstantExpression(UndefValue::get(LoadType));
+  }
+  // If this load follows a calloc (which zero initializes memory),
+  // then the loaded value is zero.
+  else if (isCallocLikeFn(DepInst, TLI)) {
+    return createConstantExpression(Constant::getNullValue(LoadType));
+  }
+
+  return nullptr;
+}
+
+const Expression *NewGVN::performSymbolicLoadEvaluation(Instruction *I) const {
+  auto *LI = cast<LoadInst>(I);
+
+  // We can eliminate in favor of non-simple loads, but we won't be able to
+  // eliminate the loads themselves.
+  if (!LI->isSimple())
+    return nullptr;
+
+  Value *LoadAddressLeader = lookupOperandLeader(LI->getPointerOperand());
+  // Load of undef is undef.
+  if (isa<UndefValue>(LoadAddressLeader))
+    return createConstantExpression(UndefValue::get(LI->getType()));
+  MemoryAccess *OriginalAccess = getMemoryAccess(I);
+  MemoryAccess *DefiningAccess =
+      MSSAWalker->getClobberingMemoryAccess(OriginalAccess);
+
+  if (!MSSA->isLiveOnEntryDef(DefiningAccess)) {
+    if (auto *MD = dyn_cast<MemoryDef>(DefiningAccess)) {
+      Instruction *DefiningInst = MD->getMemoryInst();
+      // If the defining instruction is not reachable, replace with undef.
+      if (!ReachableBlocks.count(DefiningInst->getParent()))
+        return createConstantExpression(UndefValue::get(LI->getType()));
+      // This will handle stores and memory insts. We only do this if the
+      // defining access has a different type, or it is a pointer produced by
+      // certain memory operations that cause the memory to have a fixed value
+      // (IE things like calloc).
+      if (const auto *CoercionResult =
+              performSymbolicLoadCoercion(LI->getType(), LoadAddressLeader, LI,
+                                          DefiningInst, DefiningAccess))
+        return CoercionResult;
+    }
+  }
+
+  const auto *LE = createLoadExpression(LI->getType(), LoadAddressLeader, LI,
+                                        DefiningAccess);
+  // If our MemoryLeader is not our defining access, add a use to the
+  // MemoryLeader, so that we get reprocessed when it changes.
+  if (LE->getMemoryLeader() != DefiningAccess)
+    addMemoryUsers(LE->getMemoryLeader(), OriginalAccess);
+  return LE;
+}
+
+const Expression *
+NewGVN::performSymbolicPredicateInfoEvaluation(Instruction *I) const {
+  auto *PI = PredInfo->getPredicateInfoFor(I);
+  if (!PI)
+    return nullptr;
+
+  DEBUG(dbgs() << "Found predicate info from instruction!\n");
+
+  auto *PWC = dyn_cast<PredicateWithCondition>(PI);
+  if (!PWC)
+    return nullptr;
+
+  auto *CopyOf = I->getOperand(0);
+  auto *Cond = PWC->Condition;
+
+  // If this is a copy of the condition, it must be either true or false
+  // depending on the predicate info type and edge.
+  if (CopyOf == Cond) {
+    // We should not need to add predicate users because the predicate info is
+    // already a use of this operand.
+    if (isa<PredicateAssume>(PI))
+      return createConstantExpression(ConstantInt::getTrue(Cond->getType()));
+    if (auto *PBranch = dyn_cast<PredicateBranch>(PI)) {
+      if (PBranch->TrueEdge)
+        return createConstantExpression(ConstantInt::getTrue(Cond->getType()));
+      return createConstantExpression(ConstantInt::getFalse(Cond->getType()));
+    }
+    if (auto *PSwitch = dyn_cast<PredicateSwitch>(PI))
+      return createConstantExpression(cast<Constant>(PSwitch->CaseValue));
+  }
+
+  // Not a copy of the condition, so see what the predicates tell us about this
+  // value. First, though, we check to make sure the value is actually a copy
+  // of one of the condition operands. It's possible, in certain cases, for it
+  // to be a copy of a predicateinfo copy. In particular, if two branch
+  // operations use the same condition, and one branch dominates the other, we
+  // will end up with a copy of a copy. This is currently a small deficiency in
+  // predicateinfo. What will end up happening here is that we will value
+  // number both copies the same anyway.
+
+  // Everything below relies on the condition being a comparison.
+  auto *Cmp = dyn_cast<CmpInst>(Cond);
+  if (!Cmp)
+    return nullptr;
+
+  if (CopyOf != Cmp->getOperand(0) && CopyOf != Cmp->getOperand(1)) {
+    DEBUG(dbgs() << "Copy is not of any condition operands!\n");
+    return nullptr;
+  }
+  Value *FirstOp = lookupOperandLeader(Cmp->getOperand(0));
+  Value *SecondOp = lookupOperandLeader(Cmp->getOperand(1));
+  bool SwappedOps = false;
+  // Sort the ops.
+  if (shouldSwapOperands(FirstOp, SecondOp)) {
+    std::swap(FirstOp, SecondOp);
+    SwappedOps = true;
+  }
+  CmpInst::Predicate Predicate =
+      SwappedOps ? Cmp->getSwappedPredicate() : Cmp->getPredicate();
+
+  if (isa<PredicateAssume>(PI)) {
+    // If the comparison is true when the operands are equal, then we know the
+    // operands are equal, because assumes must always be true.
+    if (CmpInst::isTrueWhenEqual(Predicate)) {
+      addPredicateUsers(PI, I);
+      addAdditionalUsers(Cmp->getOperand(0), I);
+      return createVariableOrConstant(FirstOp);
+    }
+  }
+  if (const auto *PBranch = dyn_cast<PredicateBranch>(PI)) {
+    // If we are *not* a copy of the comparison, we may be equal to the other
+    // operand when the predicate implies something about equality of
+    // operations. In particular, if the comparison is true/false when the
+    // operands are equal, and we are on the right edge, we know this operation
+    // is equal to something.
+    if ((PBranch->TrueEdge && Predicate == CmpInst::ICMP_EQ) ||
+        (!PBranch->TrueEdge && Predicate == CmpInst::ICMP_NE)) {
+      addPredicateUsers(PI, I);
+      addAdditionalUsers(SwappedOps ? Cmp->getOperand(1) : Cmp->getOperand(0),
+                         I);
+      return createVariableOrConstant(FirstOp);
+    }
+    // Handle the special case of floating point.
+    if (((PBranch->TrueEdge && Predicate == CmpInst::FCMP_OEQ) ||
+         (!PBranch->TrueEdge && Predicate == CmpInst::FCMP_UNE)) &&
+        isa<ConstantFP>(FirstOp) && !cast<ConstantFP>(FirstOp)->isZero()) {
+      addPredicateUsers(PI, I);
+      addAdditionalUsers(SwappedOps ? Cmp->getOperand(1) : Cmp->getOperand(0),
+                         I);
+      return createConstantExpression(cast<Constant>(FirstOp));
+    }
+  }
+  return nullptr;
+}
+
+// Evaluate read only and pure calls, and create an expression result.
+const Expression *NewGVN::performSymbolicCallEvaluation(Instruction *I) const {
+  auto *CI = cast<CallInst>(I);
+  if (auto *II = dyn_cast<IntrinsicInst>(I)) {
+    // Intrinsics with the returned attribute are copies of arguments.
+    if (auto *ReturnedValue = II->getReturnedArgOperand()) {
+      if (II->getIntrinsicID() == Intrinsic::ssa_copy)
+        if (const auto *Result = performSymbolicPredicateInfoEvaluation(I))
+          return Result;
+      return createVariableOrConstant(ReturnedValue);
+    }
+  }
+  if (AA->doesNotAccessMemory(CI)) {
+    return createCallExpression(CI, TOPClass->getMemoryLeader());
+  } else if (AA->onlyReadsMemory(CI)) {
+    MemoryAccess *DefiningAccess = MSSAWalker->getClobberingMemoryAccess(CI);
+    return createCallExpression(CI, DefiningAccess);
+  }
+  return nullptr;
+}
+
+// Retrieve the memory class for a given MemoryAccess.
+CongruenceClass *NewGVN::getMemoryClass(const MemoryAccess *MA) const {
+  auto *Result = MemoryAccessToClass.lookup(MA);
+  assert(Result && "Should have found memory class");
+  return Result;
+}
+
+// Update the MemoryAccess equivalence table to say that From is equal to To,
+// and return true if this is different from what already existed in the table.
+bool NewGVN::setMemoryClass(const MemoryAccess *From,
+                            CongruenceClass *NewClass) {
+  assert(NewClass &&
+         "Every MemoryAccess should be getting mapped to a non-null class");
+  DEBUG(dbgs() << "Setting " << *From);
+  DEBUG(dbgs() << " equivalent to congruence class ");
+  DEBUG(dbgs() << NewClass->getID() << " with current MemoryAccess leader ");
+  DEBUG(dbgs() << *NewClass->getMemoryLeader() << "\n");
+
+  auto LookupResult = MemoryAccessToClass.find(From);
+  bool Changed = false;
+  // If it's already in the table, see if the value changed.
+  if (LookupResult != MemoryAccessToClass.end()) {
+    auto *OldClass = LookupResult->second;
+    if (OldClass != NewClass) {
+      // If this is a phi, we have to handle memory member updates.
+      if (auto *MP = dyn_cast<MemoryPhi>(From)) {
+        OldClass->memory_erase(MP);
+        NewClass->memory_insert(MP);
+        // This may have killed the class if it had no non-memory members
+        if (OldClass->getMemoryLeader() == From) {
+          if (OldClass->definesNoMemory()) {
+            OldClass->setMemoryLeader(nullptr);
+          } else {
+            OldClass->setMemoryLeader(getNextMemoryLeader(OldClass));
+            DEBUG(dbgs() << "Memory class leader change for class "
+                         << OldClass->getID() << " to "
+                         << *OldClass->getMemoryLeader()
+                         << " due to removal of a memory member " << *From
+                         << "\n");
+            markMemoryLeaderChangeTouched(OldClass);
+          }
+        }
+      }
+      // It wasn't equivalent before, and now it is.
+      LookupResult->second = NewClass;
+      Changed = true;
+    }
+  }
+
+  return Changed;
+}
+
+// Determine if an instruction is cycle-free. That means the values in the
+// instruction don't depend on any expressions that can change value as a result
+// of the instruction. For example, a non-cycle-free instruction would be v =
+// phi(0, v+1).
+bool NewGVN::isCycleFree(const Instruction *I) const {
+  // In order to compute cycle-freeness, we do SCC finding on the instruction,
+  // and see what kind of SCC it ends up in. If it is a singleton, it is
+  // cycle-free. If it is not in a singleton, it is only cycle free if the
+  // other members are all phi nodes (as they do not compute anything, they are
+  // copies).
+  auto ICS = InstCycleState.lookup(I);
+  if (ICS == ICS_Unknown) {
+    SCCFinder.Start(I);
+    auto &SCC = SCCFinder.getComponentFor(I);
+    // It's cycle free if it's size 1 or the SCC is *only* phi nodes.
+    if (SCC.size() == 1)
+      InstCycleState.insert({I, ICS_CycleFree});
+    else {
+      bool AllPhis = llvm::all_of(SCC, [](const Value *V) {
+        return isa<PHINode>(V) || isCopyOfAPHI(V);
+      });
+      ICS = AllPhis ? ICS_CycleFree : ICS_Cycle;
+      for (auto *Member : SCC)
+        if (auto *MemberPhi = dyn_cast<PHINode>(Member))
+          InstCycleState.insert({MemberPhi, ICS});
+    }
+  }
+  if (ICS == ICS_Cycle)
+    return false;
+  return true;
+}
+
+// Evaluate PHI nodes symbolically and create an expression result.
+const Expression *
+NewGVN::performSymbolicPHIEvaluation(ArrayRef<ValPair> PHIOps,
+                                     Instruction *I,
+                                     BasicBlock *PHIBlock) const {
+  // True if one of the incoming phi edges is a backedge.
+  bool HasBackedge = false;
+  // OriginalOpsConstant tracks whether all of the *original* phi operands were
+  // constant. This is really shorthand for "this phi cannot cycle due to
+  // forward propagation", as any change in value of the phi is guaranteed not
+  // to later change the value of the phi. IE it can't be v = phi(undef, v+1).
+  bool OriginalOpsConstant = true;
+  auto *E = cast<PHIExpression>(createPHIExpression(
+      PHIOps, I, PHIBlock, HasBackedge, OriginalOpsConstant));
+  // We match the semantics of SimplifyPhiNode from InstructionSimplify here.
+  // See if all arguments are the same.
+  // We track if any were undef because they need special handling.
+  bool HasUndef = false;
+  auto Filtered = make_filter_range(E->operands(), [&](Value *Arg) {
+    if (isa<UndefValue>(Arg)) {
+      HasUndef = true;
+      return false;
+    }
+    return true;
+  });
+  // If we are left with no operands, it's dead.
+  if (Filtered.begin() == Filtered.end()) {
+    // If it has undef at this point, it means there are no non-undef arguments,
+    // and thus, the value of the phi node must be undef.
+    if (HasUndef) {
+      DEBUG(dbgs() << "PHI Node " << *I
+                   << " has no non-undef arguments, valuing it as undef\n");
+      return createConstantExpression(UndefValue::get(I->getType()));
+    }
+
+    DEBUG(dbgs() << "No arguments of PHI node " << *I << " are live\n");
+    deleteExpression(E);
+    return createDeadExpression();
+  }
+  Value *AllSameValue = *(Filtered.begin());
+  ++Filtered.begin();
+  // Can't use std::equal here, sadly, because filter.begin moves.
+  if (llvm::all_of(Filtered, [&](Value *Arg) { return Arg == AllSameValue; })) {
+    // In LLVM's non-standard representation of phi nodes, it's possible to have
+    // phi nodes with cycles (IE dependent on other phis that are .... dependent
+    // on the original phi node), especially in weird CFG's where some arguments
+    // are unreachable, or uninitialized along certain paths. This can cause
+    // infinite loops during evaluation.
We work around this by not trying to + // really evaluate them independently, but instead using a variable + // expression to say if one is equivalent to the other. + // We also special case undef, so that if we have an undef, we can't use the + // common value unless it dominates the phi block. + if (HasUndef) { + // If we have undef and at least one other value, this is really a + // multivalued phi, and we need to know if it's cycle free in order to + // evaluate whether we can ignore the undef. The other parts of this are + // just shortcuts. If there is no backedge, or all operands are + // constants, it also must be cycle free. + if (HasBackedge && !OriginalOpsConstant && + !isa<UndefValue>(AllSameValue) && !isCycleFree(I)) + return E; + + // Only have to check for instructions + if (auto *AllSameInst = dyn_cast<Instruction>(AllSameValue)) + if (!someEquivalentDominates(AllSameInst, I)) + return E; + } + // Can't simplify to something that comes later in the iteration. + // Otherwise, when and if it changes congruence class, we will never catch + // up. We will always be a class behind it. + if (isa<Instruction>(AllSameValue) && + InstrToDFSNum(AllSameValue) > InstrToDFSNum(I)) + return E; + NumGVNPhisAllSame++; + DEBUG(dbgs() << "Simplified PHI node " << *I << " to " << *AllSameValue + << "\n"); + deleteExpression(E); + return createVariableOrConstant(AllSameValue); + } + return E; +} + +const Expression * +NewGVN::performSymbolicAggrValueEvaluation(Instruction *I) const { + if (auto *EI = dyn_cast<ExtractValueInst>(I)) { + auto *II = dyn_cast<IntrinsicInst>(EI->getAggregateOperand()); + if (II && EI->getNumIndices() == 1 && *EI->idx_begin() == 0) { + unsigned Opcode = 0; + // EI might be an extract from one of our recognised intrinsics. If it + // is we'll synthesize a semantically equivalent expression instead on + // an extract value expression. + switch (II->getIntrinsicID()) { + case Intrinsic::sadd_with_overflow: + case Intrinsic::uadd_with_overflow: + Opcode = Instruction::Add; + break; + case Intrinsic::ssub_with_overflow: + case Intrinsic::usub_with_overflow: + Opcode = Instruction::Sub; + break; + case Intrinsic::smul_with_overflow: + case Intrinsic::umul_with_overflow: + Opcode = Instruction::Mul; + break; + default: + break; + } + + if (Opcode != 0) { + // Intrinsic recognized. Grab its args to finish building the + // expression. + assert(II->getNumArgOperands() == 2 && + "Expect two args for recognised intrinsics."); + return createBinaryExpression(Opcode, EI->getType(), + II->getArgOperand(0), + II->getArgOperand(1), I); + } + } + } + + return createAggregateValueExpression(I); +} + +const Expression *NewGVN::performSymbolicCmpEvaluation(Instruction *I) const { + assert(isa<CmpInst>(I) && "Expected a cmp instruction."); + + auto *CI = cast<CmpInst>(I); + // See if our operands are equal to those of a previous predicate, and if so, + // if it implies true or false. + auto Op0 = lookupOperandLeader(CI->getOperand(0)); + auto Op1 = lookupOperandLeader(CI->getOperand(1)); + auto OurPredicate = CI->getPredicate(); + if (shouldSwapOperands(Op0, Op1)) { + std::swap(Op0, Op1); + OurPredicate = CI->getSwappedPredicate(); + } + + // Avoid processing the same info twice. + const PredicateBase *LastPredInfo = nullptr; + // See if we know something about the comparison itself, like it is the target + // of an assume. 
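+  // For example (illustrative IR):
+  //   %cmp = icmp eq i32 %x, %y
+  //   call void @llvm.assume(i1 %cmp)
+  // lets us value %cmp itself as true below.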
+  auto *CmpPI = PredInfo->getPredicateInfoFor(I);
+  if (dyn_cast_or_null<PredicateAssume>(CmpPI))
+    return createConstantExpression(ConstantInt::getTrue(CI->getType()));
+
+  // See if we know something just from the operands themselves.
+  if (Op0 == Op1) {
+    // This condition does not depend on predicates, no need to add users.
+    if (CI->isTrueWhenEqual())
+      return createConstantExpression(ConstantInt::getTrue(CI->getType()));
+    else if (CI->isFalseWhenEqual())
+      return createConstantExpression(ConstantInt::getFalse(CI->getType()));
+  }
+
+  // NOTE: Because we are comparing both operands here and below, and using
+  // previous comparisons, we rely on the fact that predicateinfo knows to mark
+  // comparisons that use renamed operands as users of the earlier comparisons.
+  // It is *not* enough to just mark predicateinfo renamed operands as users of
+  // the earlier comparisons, because the *other* operand may have changed in a
+  // previous iteration.
+  // Example:
+  //   icmp slt %a, %b
+  //   %b.0 = ssa.copy(%b)
+  //   false branch:
+  //   icmp slt %c, %b.0
+  // %c and %a may start out equal, and thus, the code below will say the
+  // second icmp is false. %c may become equal to something else, and in that
+  // case the second icmp *must* be reexamined, but would not be if only the
+  // renamed operands are considered users of the icmp.
+  //
+  // *Currently* we only check one level of comparisons back, and only mark one
+  // level back as touched when changes happen. If you modify this code to look
+  // back farther through comparisons, you *must* mark the appropriate
+  // comparisons as users in PredicateInfo.cpp, or you will cause bugs.
+
+  // See if our operands have predicate info, so that we may be able to derive
+  // something from a previous comparison.
+  for (const auto &Op : CI->operands()) {
+    auto *PI = PredInfo->getPredicateInfoFor(Op);
+    if (const auto *PBranch = dyn_cast_or_null<PredicateBranch>(PI)) {
+      if (PI == LastPredInfo)
+        continue;
+      LastPredInfo = PI;
+      // In phi of ops cases, we may have predicate info that we are evaluating
+      // in a different context.
+      if (!DT->dominates(PBranch->To, getBlockForValue(I)))
+        continue;
+      // TODO: Along the false edge, we may know more things too, like
+      // icmp of same operands is false.
+      // TODO: We only handle actual comparison conditions below, not and/or.
+      auto *BranchCond = dyn_cast<CmpInst>(PBranch->Condition);
+      if (!BranchCond)
+        continue;
+      auto *BranchOp0 = lookupOperandLeader(BranchCond->getOperand(0));
+      auto *BranchOp1 = lookupOperandLeader(BranchCond->getOperand(1));
+      auto BranchPredicate = BranchCond->getPredicate();
+      if (shouldSwapOperands(BranchOp0, BranchOp1)) {
+        std::swap(BranchOp0, BranchOp1);
+        BranchPredicate = BranchCond->getSwappedPredicate();
+      }
+      if (BranchOp0 == Op0 && BranchOp1 == Op1) {
+        if (PBranch->TrueEdge) {
+          // If we know the previous predicate is true and we are in the true
+          // edge then we may be implied true or false.
+          if (CmpInst::isImpliedTrueByMatchingCmp(BranchPredicate,
+                                                  OurPredicate)) {
+            addPredicateUsers(PI, I);
+            return createConstantExpression(
+                ConstantInt::getTrue(CI->getType()));
+          }
+
+          if (CmpInst::isImpliedFalseByMatchingCmp(BranchPredicate,
+                                                   OurPredicate)) {
+            addPredicateUsers(PI, I);
+            return createConstantExpression(
+                ConstantInt::getFalse(CI->getType()));
+          }
+        } else {
+          // Just handle the ne and eq cases, where if we have the same
+          // operands, we may know something.
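+          // For example (illustrative): on the false edge of
+          //   br i1 (icmp eq i32 %a, %b), ...
+          // a later "icmp eq i32 %a, %b" must also be false, and
+          // "icmp ne i32 %a, %b" must be true.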
+ if (BranchPredicate == OurPredicate) { + addPredicateUsers(PI, I); + // Same predicate, same ops,we know it was false, so this is false. + return createConstantExpression( + ConstantInt::getFalse(CI->getType())); + } else if (BranchPredicate == + CmpInst::getInversePredicate(OurPredicate)) { + addPredicateUsers(PI, I); + // Inverse predicate, we know the other was false, so this is true. + return createConstantExpression( + ConstantInt::getTrue(CI->getType())); + } + } + } + } + } + // Create expression will take care of simplifyCmpInst + return createExpression(I); +} + +// Substitute and symbolize the value before value numbering. +const Expression * +NewGVN::performSymbolicEvaluation(Value *V, + SmallPtrSetImpl<Value *> &Visited) const { + const Expression *E = nullptr; + if (auto *C = dyn_cast<Constant>(V)) + E = createConstantExpression(C); + else if (isa<Argument>(V) || isa<GlobalVariable>(V)) { + E = createVariableExpression(V); + } else { + // TODO: memory intrinsics. + // TODO: Some day, we should do the forward propagation and reassociation + // parts of the algorithm. + auto *I = cast<Instruction>(V); + switch (I->getOpcode()) { + case Instruction::ExtractValue: + case Instruction::InsertValue: + E = performSymbolicAggrValueEvaluation(I); + break; + case Instruction::PHI: { + SmallVector<ValPair, 3> Ops; + auto *PN = cast<PHINode>(I); + for (unsigned i = 0; i < PN->getNumOperands(); ++i) + Ops.push_back({PN->getIncomingValue(i), PN->getIncomingBlock(i)}); + // Sort to ensure the invariant createPHIExpression requires is met. + sortPHIOps(Ops); + E = performSymbolicPHIEvaluation(Ops, I, getBlockForValue(I)); + } break; + case Instruction::Call: + E = performSymbolicCallEvaluation(I); + break; + case Instruction::Store: + E = performSymbolicStoreEvaluation(I); + break; + case Instruction::Load: + E = performSymbolicLoadEvaluation(I); + break; + case Instruction::BitCast: + E = createExpression(I); + break; + case Instruction::ICmp: + case Instruction::FCmp: + E = performSymbolicCmpEvaluation(I); + break; + case Instruction::Add: + case Instruction::FAdd: + case Instruction::Sub: + case Instruction::FSub: + case Instruction::Mul: + case Instruction::FMul: + case Instruction::UDiv: + case Instruction::SDiv: + case Instruction::FDiv: + case Instruction::URem: + case Instruction::SRem: + case Instruction::FRem: + case Instruction::Shl: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::And: + case Instruction::Or: + case Instruction::Xor: + case Instruction::Trunc: + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::FPTrunc: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::Select: + case Instruction::ExtractElement: + case Instruction::InsertElement: + case Instruction::ShuffleVector: + case Instruction::GetElementPtr: + E = createExpression(I); + break; + default: + return nullptr; + } + } + return E; +} + +// Look up a container in a map, and then call a function for each thing in the +// found container. 
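+// A hypothetical use (sketch only): for_each_found(ExpressionToPhiOfOps, E,
+// [&](Instruction *I) { TouchedInstructions.set(InstrToDFSNum(I)); });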
+template <typename Map, typename KeyType, typename Func>
+void NewGVN::for_each_found(Map &M, const KeyType &Key, Func F) {
+  const auto Result = M.find_as(Key);
+  if (Result != M.end())
+    for (typename Map::mapped_type::value_type Mapped : Result->second)
+      F(Mapped);
+}
+
+// Look up a container of values/instructions in a map, and touch all the
+// instructions in the container. Then erase the value from the map.
+template <typename Map, typename KeyType>
+void NewGVN::touchAndErase(Map &M, const KeyType &Key) {
+  const auto Result = M.find_as(Key);
+  if (Result != M.end()) {
+    for (const typename Map::mapped_type::value_type Mapped : Result->second)
+      TouchedInstructions.set(InstrToDFSNum(Mapped));
+    M.erase(Result);
+  }
+}
+
+void NewGVN::addAdditionalUsers(Value *To, Value *User) const {
+  assert(User && To != User);
+  if (isa<Instruction>(To))
+    AdditionalUsers[To].insert(User);
+}
+
+void NewGVN::markUsersTouched(Value *V) {
+  // Now mark the users as touched.
+  for (auto *User : V->users()) {
+    assert(isa<Instruction>(User) && "Use of value not within an instruction?");
+    TouchedInstructions.set(InstrToDFSNum(User));
+  }
+  touchAndErase(AdditionalUsers, V);
+}
+
+void NewGVN::addMemoryUsers(const MemoryAccess *To, MemoryAccess *U) const {
+  DEBUG(dbgs() << "Adding memory user " << *U << " to " << *To << "\n");
+  MemoryToUsers[To].insert(U);
+}
+
+void NewGVN::markMemoryDefTouched(const MemoryAccess *MA) {
+  TouchedInstructions.set(MemoryToDFSNum(MA));
+}
+
+void NewGVN::markMemoryUsersTouched(const MemoryAccess *MA) {
+  if (isa<MemoryUse>(MA))
+    return;
+  for (auto U : MA->users())
+    TouchedInstructions.set(MemoryToDFSNum(U));
+  touchAndErase(MemoryToUsers, MA);
+}
+
+// Add I to the set of users of a given predicate.
+void NewGVN::addPredicateUsers(const PredicateBase *PB, Instruction *I) const {
+  // Don't add temporary instructions to the user lists.
+  if (AllTempInstructions.count(I))
+    return;
+
+  if (auto *PBranch = dyn_cast<PredicateBranch>(PB))
+    PredicateToUsers[PBranch->Condition].insert(I);
+  else if (auto *PAssume = dyn_cast<PredicateAssume>(PB))
+    PredicateToUsers[PAssume->Condition].insert(I);
+}
+
+// Touch all the predicates that depend on this instruction.
+void NewGVN::markPredicateUsersTouched(Instruction *I) {
+  touchAndErase(PredicateToUsers, I);
+}
+
+// Mark users affected by a memory leader change.
+void NewGVN::markMemoryLeaderChangeTouched(CongruenceClass *CC) {
+  for (auto M : CC->memory())
+    markMemoryDefTouched(M);
+}
+
+// Touch the instructions that need to be updated after a congruence class has a
+// leader change, and mark changed values.
+void NewGVN::markValueLeaderChangeTouched(CongruenceClass *CC) {
+  for (auto M : *CC) {
+    if (auto *I = dyn_cast<Instruction>(M))
+      TouchedInstructions.set(InstrToDFSNum(I));
+    LeaderChanges.insert(M);
+  }
+}
+
+// Given a range of things that have instruction DFS numbers, return the member
+// of the range with the smallest DFS number.
+template <class T, class Range>
+T *NewGVN::getMinDFSOfRange(const Range &R) const {
+  std::pair<T *, unsigned> MinDFS = {nullptr, ~0U};
+  for (const auto X : R) {
+    auto DFSNum = InstrToDFSNum(X);
+    if (DFSNum < MinDFS.second)
+      MinDFS = {X, DFSNum};
+  }
+  return MinDFS.first;
+}
+
+// This function returns the MemoryAccess that should be the next leader of
+// congruence class CC, under the assumption that the current leader is going to
+// disappear.
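+// (Illustrative summary: if the class still contains stores, the store with
+// the smallest DFS number supplies the new leader; otherwise the surviving
+// MemoryPhi with the smallest DFS number does.)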
+const MemoryAccess *NewGVN::getNextMemoryLeader(CongruenceClass *CC) const {
+  // TODO: If this ends up too slow, we can maintain a next memory leader like
+  // we do for regular leaders.
+  // Make sure there will be a leader to find.
+  assert(!CC->definesNoMemory() && "Can't get next leader if there is none");
+  if (CC->getStoreCount() > 0) {
+    if (auto *NL = dyn_cast_or_null<StoreInst>(CC->getNextLeader().first))
+      return getMemoryAccess(NL);
+    // Find the store with the minimum DFS number.
+    auto *V = getMinDFSOfRange<Value>(make_filter_range(
+        *CC, [&](const Value *V) { return isa<StoreInst>(V); }));
+    return getMemoryAccess(cast<StoreInst>(V));
+  }
+  assert(CC->getStoreCount() == 0);
+
+  // Given our assertion, hitting this part must mean
+  // !OldClass->memory_empty()
+  if (CC->memory_size() == 1)
+    return *CC->memory_begin();
+  return getMinDFSOfRange<const MemoryPhi>(CC->memory());
+}
+
+// This function returns the next value leader of a congruence class, under the
+// assumption that the current leader is going away. This should end up being
+// the next most dominating member.
+Value *NewGVN::getNextValueLeader(CongruenceClass *CC) const {
+  // We don't need to sort members if there is only 1, and we don't care about
+  // sorting the TOP class because everything either gets out of it or is
+  // unreachable.
+
+  if (CC->size() == 1 || CC == TOPClass) {
+    return *(CC->begin());
+  } else if (CC->getNextLeader().first) {
+    ++NumGVNAvoidedSortedLeaderChanges;
+    return CC->getNextLeader().first;
+  } else {
+    ++NumGVNSortedLeaderChanges;
+    // NOTE: If this ends up too slow, we can maintain a dual structure for
+    // member testing/insertion, or keep things mostly sorted, and sort only
+    // here, or use SparseBitVector or ....
+    return getMinDFSOfRange<Value>(*CC);
+  }
+}
+
+// Move a MemoryAccess, currently in OldClass, to NewClass, including updates to
+// the memory members, etc., for the move.
+//
+// The invariants of this function are:
+//
+// - I must be moving to NewClass from OldClass
+// - The StoreCount of OldClass and NewClass is expected to have been updated
+//   for I already if it is a store.
+// - The OldClass memory leader has not been updated yet if I was the leader.
+void NewGVN::moveMemoryToNewCongruenceClass(Instruction *I,
+                                            MemoryAccess *InstMA,
+                                            CongruenceClass *OldClass,
+                                            CongruenceClass *NewClass) {
+  // If the leader is I, and we had a representative MemoryAccess, it should
+  // be the MemoryAccess of OldClass.
+  assert((!InstMA || !OldClass->getMemoryLeader() ||
+          OldClass->getLeader() != I ||
+          MemoryAccessToClass.lookup(OldClass->getMemoryLeader()) ==
+              MemoryAccessToClass.lookup(InstMA)) &&
+         "Representative MemoryAccess mismatch");
+  // First, see what happens to the new class.
+  if (!NewClass->getMemoryLeader()) {
+    // Should be a new class, or a store becoming a leader of a new class.
+ assert(NewClass->size() == 1 || + (isa<StoreInst>(I) && NewClass->getStoreCount() == 1)); + NewClass->setMemoryLeader(InstMA); + // Mark it touched if we didn't just create a singleton + DEBUG(dbgs() << "Memory class leader change for class " << NewClass->getID() + << " due to new memory instruction becoming leader\n"); + markMemoryLeaderChangeTouched(NewClass); + } + setMemoryClass(InstMA, NewClass); + // Now, fixup the old class if necessary + if (OldClass->getMemoryLeader() == InstMA) { + if (!OldClass->definesNoMemory()) { + OldClass->setMemoryLeader(getNextMemoryLeader(OldClass)); + DEBUG(dbgs() << "Memory class leader change for class " + << OldClass->getID() << " to " + << *OldClass->getMemoryLeader() + << " due to removal of old leader " << *InstMA << "\n"); + markMemoryLeaderChangeTouched(OldClass); + } else + OldClass->setMemoryLeader(nullptr); + } +} + +// Move a value, currently in OldClass, to be part of NewClass +// Update OldClass and NewClass for the move (including changing leaders, etc). +void NewGVN::moveValueToNewCongruenceClass(Instruction *I, const Expression *E, + CongruenceClass *OldClass, + CongruenceClass *NewClass) { + if (I == OldClass->getNextLeader().first) + OldClass->resetNextLeader(); + + OldClass->erase(I); + NewClass->insert(I); + + if (NewClass->getLeader() != I) + NewClass->addPossibleNextLeader({I, InstrToDFSNum(I)}); + // Handle our special casing of stores. + if (auto *SI = dyn_cast<StoreInst>(I)) { + OldClass->decStoreCount(); + // Okay, so when do we want to make a store a leader of a class? + // If we have a store defined by an earlier load, we want the earlier load + // to lead the class. + // If we have a store defined by something else, we want the store to lead + // the class so everything else gets the "something else" as a value. + // If we have a store as the single member of the class, we want the store + // as the leader + if (NewClass->getStoreCount() == 0 && !NewClass->getStoredValue()) { + // If it's a store expression we are using, it means we are not equivalent + // to something earlier. + if (auto *SE = dyn_cast<StoreExpression>(E)) { + NewClass->setStoredValue(SE->getStoredValue()); + markValueLeaderChangeTouched(NewClass); + // Shift the new class leader to be the store + DEBUG(dbgs() << "Changing leader of congruence class " + << NewClass->getID() << " from " << *NewClass->getLeader() + << " to " << *SI << " because store joined class\n"); + // If we changed the leader, we have to mark it changed because we don't + // know what it will do to symbolic evaluation. + NewClass->setLeader(SI); + } + // We rely on the code below handling the MemoryAccess change. + } + NewClass->incStoreCount(); + } + // True if there is no memory instructions left in a class that had memory + // instructions before. + + // If it's not a memory use, set the MemoryAccess equivalence + auto *InstMA = dyn_cast_or_null<MemoryDef>(getMemoryAccess(I)); + if (InstMA) + moveMemoryToNewCongruenceClass(I, InstMA, OldClass, NewClass); + ValueToClass[I] = NewClass; + // See if we destroyed the class or need to swap leaders. + if (OldClass->empty() && OldClass != TOPClass) { + if (OldClass->getDefiningExpr()) { + DEBUG(dbgs() << "Erasing expression " << *OldClass->getDefiningExpr() + << " from table\n"); + // We erase it as an exact expression to make sure we don't just erase an + // equivalent one. 
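+      // E.g. (illustrative): two StoreExpressions may compare equal() while
+      // being distinct objects; matching with ExactEqualsExpression below
+      // ensures we do not erase a live equivalent by accident.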
+ auto Iter = ExpressionToClass.find_as( + ExactEqualsExpression(*OldClass->getDefiningExpr())); + if (Iter != ExpressionToClass.end()) + ExpressionToClass.erase(Iter); +#ifdef EXPENSIVE_CHECKS + assert( + (*OldClass->getDefiningExpr() != *E || ExpressionToClass.lookup(E)) && + "We erased the expression we just inserted, which should not happen"); +#endif + } + } else if (OldClass->getLeader() == I) { + // When the leader changes, the value numbering of + // everything may change due to symbolization changes, so we need to + // reprocess. + DEBUG(dbgs() << "Value class leader change for class " << OldClass->getID() + << "\n"); + ++NumGVNLeaderChanges; + // Destroy the stored value if there are no more stores to represent it. + // Note that this is basically clean up for the expression removal that + // happens below. If we remove stores from a class, we may leave it as a + // class of equivalent memory phis. + if (OldClass->getStoreCount() == 0) { + if (OldClass->getStoredValue()) + OldClass->setStoredValue(nullptr); + } + OldClass->setLeader(getNextValueLeader(OldClass)); + OldClass->resetNextLeader(); + markValueLeaderChangeTouched(OldClass); + } +} + +// For a given expression, mark the phi of ops instructions that could have +// changed as a result. +void NewGVN::markPhiOfOpsChanged(const Expression *E) { + touchAndErase(ExpressionToPhiOfOps, E); +} + +// Perform congruence finding on a given value numbering expression. +void NewGVN::performCongruenceFinding(Instruction *I, const Expression *E) { + // This is guaranteed to return something, since it will at least find + // TOP. + + CongruenceClass *IClass = ValueToClass.lookup(I); + assert(IClass && "Should have found a IClass"); + // Dead classes should have been eliminated from the mapping. + assert(!IClass->isDead() && "Found a dead class"); + + CongruenceClass *EClass = nullptr; + if (const auto *VE = dyn_cast<VariableExpression>(E)) { + EClass = ValueToClass.lookup(VE->getVariableValue()); + } else if (isa<DeadExpression>(E)) { + EClass = TOPClass; + } + if (!EClass) { + auto lookupResult = ExpressionToClass.insert({E, nullptr}); + + // If it's not in the value table, create a new congruence class. + if (lookupResult.second) { + CongruenceClass *NewClass = createCongruenceClass(nullptr, E); + auto place = lookupResult.first; + place->second = NewClass; + + // Constants and variables should always be made the leader. + if (const auto *CE = dyn_cast<ConstantExpression>(E)) { + NewClass->setLeader(CE->getConstantValue()); + } else if (const auto *SE = dyn_cast<StoreExpression>(E)) { + StoreInst *SI = SE->getStoreInst(); + NewClass->setLeader(SI); + NewClass->setStoredValue(SE->getStoredValue()); + // The RepMemoryAccess field will be filled in properly by the + // moveValueToNewCongruenceClass call. 
+ } else { + NewClass->setLeader(I); + } + assert(!isa<VariableExpression>(E) && + "VariableExpression should have been handled already"); + + EClass = NewClass; + DEBUG(dbgs() << "Created new congruence class for " << *I + << " using expression " << *E << " at " << NewClass->getID() + << " and leader " << *(NewClass->getLeader())); + if (NewClass->getStoredValue()) + DEBUG(dbgs() << " and stored value " << *(NewClass->getStoredValue())); + DEBUG(dbgs() << "\n"); + } else { + EClass = lookupResult.first->second; + if (isa<ConstantExpression>(E)) + assert((isa<Constant>(EClass->getLeader()) || + (EClass->getStoredValue() && + isa<Constant>(EClass->getStoredValue()))) && + "Any class with a constant expression should have a " + "constant leader"); + + assert(EClass && "Somehow don't have an eclass"); + + assert(!EClass->isDead() && "We accidentally looked up a dead class"); + } + } + bool ClassChanged = IClass != EClass; + bool LeaderChanged = LeaderChanges.erase(I); + if (ClassChanged || LeaderChanged) { + DEBUG(dbgs() << "New class " << EClass->getID() << " for expression " << *E + << "\n"); + if (ClassChanged) { + moveValueToNewCongruenceClass(I, E, IClass, EClass); + markPhiOfOpsChanged(E); + } + + markUsersTouched(I); + if (MemoryAccess *MA = getMemoryAccess(I)) + markMemoryUsersTouched(MA); + if (auto *CI = dyn_cast<CmpInst>(I)) + markPredicateUsersTouched(CI); + } + // If we changed the class of the store, we want to ensure nothing finds the + // old store expression. In particular, loads do not compare against stored + // value, so they will find old store expressions (and associated class + // mappings) if we leave them in the table. + if (ClassChanged && isa<StoreInst>(I)) { + auto *OldE = ValueToExpression.lookup(I); + // It could just be that the old class died. We don't want to erase it if we + // just moved classes. + if (OldE && isa<StoreExpression>(OldE) && *E != *OldE) { + // Erase this as an exact expression to ensure we don't erase expressions + // equivalent to it. + auto Iter = ExpressionToClass.find_as(ExactEqualsExpression(*OldE)); + if (Iter != ExpressionToClass.end()) + ExpressionToClass.erase(Iter); + } + } + ValueToExpression[I] = E; +} + +// Process the fact that Edge (from, to) is reachable, including marking +// any newly reachable blocks and instructions for processing. +void NewGVN::updateReachableEdge(BasicBlock *From, BasicBlock *To) { + // Check if the Edge was reachable before. + if (ReachableEdges.insert({From, To}).second) { + // If this block wasn't reachable before, all instructions are touched. + if (ReachableBlocks.insert(To).second) { + DEBUG(dbgs() << "Block " << getBlockName(To) << " marked reachable\n"); + const auto &InstRange = BlockInstRange.lookup(To); + TouchedInstructions.set(InstRange.first, InstRange.second); + } else { + DEBUG(dbgs() << "Block " << getBlockName(To) + << " was reachable, but new edge {" << getBlockName(From) + << "," << getBlockName(To) << "} to it found\n"); + + // We've made an edge reachable to an existing block, which may + // impact predicates. Otherwise, only mark the phi nodes as touched, as + // they are the only thing that depend on new edges. Anything using their + // values will get propagated to if necessary. + if (MemoryAccess *MemPhi = getMemoryAccess(To)) + TouchedInstructions.set(InstrToDFSNum(MemPhi)); + + // FIXME: We should just add a union op on a Bitvector and + // SparseBitVector. We can do it word by word faster than we are doing it + // here. 
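+      // RevisitOnReachabilityChange holds the DFS numbers of instructions in
+      // To (phis, and candidates for phi of ops) whose value may depend on
+      // which incoming edges are live.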
+ for (auto InstNum : RevisitOnReachabilityChange[To]) + TouchedInstructions.set(InstNum); + } + } +} + +// Given a predicate condition (from a switch, cmp, or whatever) and a block, +// see if we know some constant value for it already. +Value *NewGVN::findConditionEquivalence(Value *Cond) const { + auto Result = lookupOperandLeader(Cond); + return isa<Constant>(Result) ? Result : nullptr; +} + +// Process the outgoing edges of a block for reachability. +void NewGVN::processOutgoingEdges(TerminatorInst *TI, BasicBlock *B) { + // Evaluate reachability of terminator instruction. + BranchInst *BR; + if ((BR = dyn_cast<BranchInst>(TI)) && BR->isConditional()) { + Value *Cond = BR->getCondition(); + Value *CondEvaluated = findConditionEquivalence(Cond); + if (!CondEvaluated) { + if (auto *I = dyn_cast<Instruction>(Cond)) { + const Expression *E = createExpression(I); + if (const auto *CE = dyn_cast<ConstantExpression>(E)) { + CondEvaluated = CE->getConstantValue(); + } + } else if (isa<ConstantInt>(Cond)) { + CondEvaluated = Cond; + } + } + ConstantInt *CI; + BasicBlock *TrueSucc = BR->getSuccessor(0); + BasicBlock *FalseSucc = BR->getSuccessor(1); + if (CondEvaluated && (CI = dyn_cast<ConstantInt>(CondEvaluated))) { + if (CI->isOne()) { + DEBUG(dbgs() << "Condition for Terminator " << *TI + << " evaluated to true\n"); + updateReachableEdge(B, TrueSucc); + } else if (CI->isZero()) { + DEBUG(dbgs() << "Condition for Terminator " << *TI + << " evaluated to false\n"); + updateReachableEdge(B, FalseSucc); + } + } else { + updateReachableEdge(B, TrueSucc); + updateReachableEdge(B, FalseSucc); + } + } else if (auto *SI = dyn_cast<SwitchInst>(TI)) { + // For switches, propagate the case values into the case + // destinations. + + // Remember how many outgoing edges there are to every successor. + SmallDenseMap<BasicBlock *, unsigned, 16> SwitchEdges; + + Value *SwitchCond = SI->getCondition(); + Value *CondEvaluated = findConditionEquivalence(SwitchCond); + // See if we were able to turn this switch statement into a constant. + if (CondEvaluated && isa<ConstantInt>(CondEvaluated)) { + auto *CondVal = cast<ConstantInt>(CondEvaluated); + // We should be able to get case value for this. + auto Case = *SI->findCaseValue(CondVal); + if (Case.getCaseSuccessor() == SI->getDefaultDest()) { + // We proved the value is outside of the range of the case. + // We can't do anything other than mark the default dest as reachable, + // and go home. + updateReachableEdge(B, SI->getDefaultDest()); + return; + } + // Now get where it goes and mark it reachable. + BasicBlock *TargetBlock = Case.getCaseSuccessor(); + updateReachableEdge(B, TargetBlock); + } else { + for (unsigned i = 0, e = SI->getNumSuccessors(); i != e; ++i) { + BasicBlock *TargetBlock = SI->getSuccessor(i); + ++SwitchEdges[TargetBlock]; + updateReachableEdge(B, TargetBlock); + } + } + } else { + // Otherwise this is either unconditional, or a type we have no + // idea about. Just mark successors as reachable. + for (unsigned i = 0, e = TI->getNumSuccessors(); i != e; ++i) { + BasicBlock *TargetBlock = TI->getSuccessor(i); + updateReachableEdge(B, TargetBlock); + } + + // This also may be a memory defining terminator, in which case, set it + // equivalent only to itself. 
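+    // (For example, an invoke of a function that may write memory is both a
+    // block terminator and a MemoryDef.)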
+    auto *MA = getMemoryAccess(TI);
+    if (MA && !isa<MemoryUse>(MA)) {
+      auto *CC = ensureLeaderOfMemoryClass(MA);
+      if (setMemoryClass(MA, CC))
+        markMemoryUsersTouched(MA);
+    }
+  }
+}
+
+// Remove the PHI of Ops PHI for I.
+void NewGVN::removePhiOfOps(Instruction *I, PHINode *PHITemp) {
+  InstrDFS.erase(PHITemp);
+  // It's still a temp instruction. We keep it in the array so it gets erased.
+  // However, it's no longer used by I, or in the block.
+  TempToBlock.erase(PHITemp);
+  RealToTemp.erase(I);
+  // We don't remove the users from the phi node uses. This wastes a little
+  // time, but such is life. We could use two sets to track which were there
+  // at the start of NewGVN, and which were added, but right now the cost of
+  // tracking is more than the cost of checking for more phi of ops.
+}
+
+// Add PHI Op in BB as a PHI of operations version of ExistingValue.
+void NewGVN::addPhiOfOps(PHINode *Op, BasicBlock *BB,
+                         Instruction *ExistingValue) {
+  InstrDFS[Op] = InstrToDFSNum(ExistingValue);
+  AllTempInstructions.insert(Op);
+  TempToBlock[Op] = BB;
+  RealToTemp[ExistingValue] = Op;
+  // Add all users to the phi node use set, as they are now uses of the phi of
+  // ops phis and may themselves be phi of ops.
+  for (auto *U : ExistingValue->users())
+    if (auto *UI = dyn_cast<Instruction>(U))
+      PHINodeUses.insert(UI);
+}
+
+static bool okayForPHIOfOps(const Instruction *I) {
+  if (!EnablePhiOfOps)
+    return false;
+  return isa<BinaryOperator>(I) || isa<SelectInst>(I) || isa<CmpInst>(I) ||
+         isa<LoadInst>(I);
+}
+
+bool NewGVN::OpIsSafeForPHIOfOpsHelper(
+    Value *V, const BasicBlock *PHIBlock,
+    SmallPtrSetImpl<const Value *> &Visited,
+    SmallVectorImpl<Instruction *> &Worklist) {
+
+  if (!isa<Instruction>(V))
+    return true;
+  auto OISIt = OpSafeForPHIOfOps.find(V);
+  if (OISIt != OpSafeForPHIOfOps.end())
+    return OISIt->second;
+
+  // Keep walking until we either dominate the phi block, or hit a phi, or run
+  // out of things to check.
+  if (DT->properlyDominates(getBlockForValue(V), PHIBlock)) {
+    OpSafeForPHIOfOps.insert({V, true});
+    return true;
+  }
+  // PHI in the same block.
+  if (isa<PHINode>(V) && getBlockForValue(V) == PHIBlock) {
+    OpSafeForPHIOfOps.insert({V, false});
+    return false;
+  }
+
+  auto *OrigI = cast<Instruction>(V);
+  for (auto *Op : OrigI->operand_values()) {
+    if (!isa<Instruction>(Op))
+      continue;
+    // Stop now if we find an unsafe operand.
+    auto OISIt = OpSafeForPHIOfOps.find(OrigI);
+    if (OISIt != OpSafeForPHIOfOps.end()) {
+      if (!OISIt->second) {
+        OpSafeForPHIOfOps.insert({V, false});
+        return false;
+      }
+      continue;
+    }
+    if (!Visited.insert(Op).second)
+      continue;
+    Worklist.push_back(cast<Instruction>(Op));
+  }
+  return true;
+}
+
+// Return true if this operand will be safe to use for phi of ops.
+//
+// The reason some operands are unsafe is that we are not trying to recursively
+// translate everything back through phi nodes. We actually expect some lookups
+// of expressions to fail; in particular, a lookup where the expression cannot
+// exist in the predecessor. This is true even if the expression, as shown, can
+// be determined to be constant.
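+// For example (illustrative): phi-translating
+//   %v = load i32, i32* %p
+// into a predecessor may produce a load expression that was never computed
+// there, so no leader can exist for it.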
+bool NewGVN::OpIsSafeForPHIOfOps(Value *V, const BasicBlock *PHIBlock, + SmallPtrSetImpl<const Value *> &Visited) { + SmallVector<Instruction *, 4> Worklist; + if (!OpIsSafeForPHIOfOpsHelper(V, PHIBlock, Visited, Worklist)) + return false; + while (!Worklist.empty()) { + auto *I = Worklist.pop_back_val(); + if (!OpIsSafeForPHIOfOpsHelper(I, PHIBlock, Visited, Worklist)) + return false; + } + OpSafeForPHIOfOps.insert({V, true}); + return true; +} + +// Try to find a leader for instruction TransInst, which is a phi translated +// version of something in our original program. Visited is used to ensure we +// don't infinite loop during translations of cycles. OrigInst is the +// instruction in the original program, and PredBB is the predecessor we +// translated it through. +Value *NewGVN::findLeaderForInst(Instruction *TransInst, + SmallPtrSetImpl<Value *> &Visited, + MemoryAccess *MemAccess, Instruction *OrigInst, + BasicBlock *PredBB) { + unsigned IDFSNum = InstrToDFSNum(OrigInst); + // Make sure it's marked as a temporary instruction. + AllTempInstructions.insert(TransInst); + // and make sure anything that tries to add it's DFS number is + // redirected to the instruction we are making a phi of ops + // for. + TempToBlock.insert({TransInst, PredBB}); + InstrDFS.insert({TransInst, IDFSNum}); + + const Expression *E = performSymbolicEvaluation(TransInst, Visited); + InstrDFS.erase(TransInst); + AllTempInstructions.erase(TransInst); + TempToBlock.erase(TransInst); + if (MemAccess) + TempToMemory.erase(TransInst); + if (!E) + return nullptr; + auto *FoundVal = findPHIOfOpsLeader(E, OrigInst, PredBB); + if (!FoundVal) { + ExpressionToPhiOfOps[E].insert(OrigInst); + DEBUG(dbgs() << "Cannot find phi of ops operand for " << *TransInst + << " in block " << getBlockName(PredBB) << "\n"); + return nullptr; + } + if (auto *SI = dyn_cast<StoreInst>(FoundVal)) + FoundVal = SI->getValueOperand(); + return FoundVal; +} + +// When we see an instruction that is an op of phis, generate the equivalent phi +// of ops form. +const Expression * +NewGVN::makePossiblePHIOfOps(Instruction *I, + SmallPtrSetImpl<Value *> &Visited) { + if (!okayForPHIOfOps(I)) + return nullptr; + + if (!Visited.insert(I).second) + return nullptr; + // For now, we require the instruction be cycle free because we don't + // *always* create a phi of ops for instructions that could be done as phi + // of ops, we only do it if we think it is useful. If we did do it all the + // time, we could remove the cycle free check. + if (!isCycleFree(I)) + return nullptr; + + SmallPtrSet<const Value *, 8> ProcessedPHIs; + // TODO: We don't do phi translation on memory accesses because it's + // complicated. For a load, we'd need to be able to simulate a new memoryuse, + // which we don't have a good way of doing ATM. + auto *MemAccess = getMemoryAccess(I); + // If the memory operation is defined by a memory operation this block that + // isn't a MemoryPhi, transforming the pointer backwards through a scalar phi + // can't help, as it would still be killed by that memory operation. 
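+  // (E.g., illustrative: for
+  //   store i32 0, i32* %p
+  //   %v = load i32, i32* %p
+  // in one block, phi-translating %p backwards cannot change the store that
+  // clobbers the load.)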
+ if (MemAccess && !isa<MemoryPhi>(MemAccess->getDefiningAccess()) && + MemAccess->getDefiningAccess()->getBlock() == I->getParent()) + return nullptr; + + SmallPtrSet<const Value *, 10> VisitedOps; + // Convert op of phis to phi of ops + for (auto *Op : I->operand_values()) { + if (!isa<PHINode>(Op)) { + auto *ValuePHI = RealToTemp.lookup(Op); + if (!ValuePHI) + continue; + DEBUG(dbgs() << "Found possible dependent phi of ops\n"); + Op = ValuePHI; + } + auto *OpPHI = cast<PHINode>(Op); + // No point in doing this for one-operand phis. + if (OpPHI->getNumOperands() == 1) + continue; + if (!DebugCounter::shouldExecute(PHIOfOpsCounter)) + return nullptr; + SmallVector<ValPair, 4> Ops; + SmallPtrSet<Value *, 4> Deps; + auto *PHIBlock = getBlockForValue(OpPHI); + RevisitOnReachabilityChange[PHIBlock].reset(InstrToDFSNum(I)); + for (unsigned PredNum = 0; PredNum < OpPHI->getNumOperands(); ++PredNum) { + auto *PredBB = OpPHI->getIncomingBlock(PredNum); + Value *FoundVal = nullptr; + // We could just skip unreachable edges entirely but it's tricky to do + // with rewriting existing phi nodes. + if (ReachableEdges.count({PredBB, PHIBlock})) { + // Clone the instruction, create an expression from it that is + // translated back into the predecessor, and see if we have a leader. + Instruction *ValueOp = I->clone(); + if (MemAccess) + TempToMemory.insert({ValueOp, MemAccess}); + bool SafeForPHIOfOps = true; + VisitedOps.clear(); + for (auto &Op : ValueOp->operands()) { + auto *OrigOp = &*Op; + // When these operand changes, it could change whether there is a + // leader for us or not, so we have to add additional users. + if (isa<PHINode>(Op)) { + Op = Op->DoPHITranslation(PHIBlock, PredBB); + if (Op != OrigOp && Op != I) + Deps.insert(Op); + } else if (auto *ValuePHI = RealToTemp.lookup(Op)) { + if (getBlockForValue(ValuePHI) == PHIBlock) + Op = ValuePHI->getIncomingValueForBlock(PredBB); + } + // If we phi-translated the op, it must be safe. + SafeForPHIOfOps = + SafeForPHIOfOps && + (Op != OrigOp || OpIsSafeForPHIOfOps(Op, PHIBlock, VisitedOps)); + } + // FIXME: For those things that are not safe we could generate + // expressions all the way down, and see if this comes out to a + // constant. For anything where that is true, and unsafe, we should + // have made a phi-of-ops (or value numbered it equivalent to something) + // for the pieces already. + FoundVal = !SafeForPHIOfOps ? 
nullptr + : findLeaderForInst(ValueOp, Visited, + MemAccess, I, PredBB); + ValueOp->deleteValue(); + if (!FoundVal) + return nullptr; + } else { + DEBUG(dbgs() << "Skipping phi of ops operand for incoming block " + << getBlockName(PredBB) + << " because the block is unreachable\n"); + FoundVal = UndefValue::get(I->getType()); + RevisitOnReachabilityChange[PHIBlock].set(InstrToDFSNum(I)); + } + + Ops.push_back({FoundVal, PredBB}); + DEBUG(dbgs() << "Found phi of ops operand " << *FoundVal << " in " + << getBlockName(PredBB) << "\n"); + } + for (auto Dep : Deps) + addAdditionalUsers(Dep, I); + sortPHIOps(Ops); + auto *E = performSymbolicPHIEvaluation(Ops, I, PHIBlock); + if (isa<ConstantExpression>(E) || isa<VariableExpression>(E)) { + DEBUG(dbgs() + << "Not creating real PHI of ops because it simplified to existing " + "value or constant\n"); + return E; + } + auto *ValuePHI = RealToTemp.lookup(I); + bool NewPHI = false; + if (!ValuePHI) { + ValuePHI = + PHINode::Create(I->getType(), OpPHI->getNumOperands(), "phiofops"); + addPhiOfOps(ValuePHI, PHIBlock, I); + NewPHI = true; + NumGVNPHIOfOpsCreated++; + } + if (NewPHI) { + for (auto PHIOp : Ops) + ValuePHI->addIncoming(PHIOp.first, PHIOp.second); + } else { + unsigned int i = 0; + for (auto PHIOp : Ops) { + ValuePHI->setIncomingValue(i, PHIOp.first); + ValuePHI->setIncomingBlock(i, PHIOp.second); + ++i; + } + } + RevisitOnReachabilityChange[PHIBlock].set(InstrToDFSNum(I)); + DEBUG(dbgs() << "Created phi of ops " << *ValuePHI << " for " << *I + << "\n"); + + return E; + } + return nullptr; +} + +// The algorithm initially places the values of the routine in the TOP +// congruence class. The leader of TOP is the undetermined value `undef`. +// When the algorithm has finished, values still in TOP are unreachable. +void NewGVN::initializeCongruenceClasses(Function &F) { + NextCongruenceNum = 0; + + // Note that even though we use the live on entry def as a representative + // MemoryAccess, it is *not* the same as the actual live on entry def. We + // have no real equivalemnt to undef for MemoryAccesses, and so we really + // should be checking whether the MemoryAccess is top if we want to know if it + // is equivalent to everything. Otherwise, what this really signifies is that + // the access "it reaches all the way back to the beginning of the function" + + // Initialize all other instructions to be in TOP class. + TOPClass = createCongruenceClass(nullptr, nullptr); + TOPClass->setMemoryLeader(MSSA->getLiveOnEntryDef()); + // The live on entry def gets put into it's own class + MemoryAccessToClass[MSSA->getLiveOnEntryDef()] = + createMemoryClass(MSSA->getLiveOnEntryDef()); + + for (auto DTN : nodes(DT)) { + BasicBlock *BB = DTN->getBlock(); + // All MemoryAccesses are equivalent to live on entry to start. They must + // be initialized to something so that initial changes are noticed. For + // the maximal answer, we initialize them all to be the same as + // liveOnEntry. + auto *MemoryBlockDefs = MSSA->getBlockDefs(BB); + if (MemoryBlockDefs) + for (const auto &Def : *MemoryBlockDefs) { + MemoryAccessToClass[&Def] = TOPClass; + auto *MD = dyn_cast<MemoryDef>(&Def); + // Insert the memory phis into the member list. + if (!MD) { + const MemoryPhi *MP = cast<MemoryPhi>(&Def); + TOPClass->memory_insert(MP); + MemoryPhiState.insert({MP, MPS_TOP}); + } + + if (MD && isa<StoreInst>(MD->getMemoryInst())) + TOPClass->incStoreCount(); + } + + // FIXME: This is trying to discover which instructions are uses of phi + // nodes. 
We should move this into one of the myriad of places that walk + // all the operands already. + for (auto &I : *BB) { + if (isa<PHINode>(&I)) + for (auto *U : I.users()) + if (auto *UInst = dyn_cast<Instruction>(U)) + if (InstrToDFSNum(UInst) != 0 && okayForPHIOfOps(UInst)) + PHINodeUses.insert(UInst); + // Don't insert void terminators into the class. We don't value number + // them, and they just end up sitting in TOP. + if (isa<TerminatorInst>(I) && I.getType()->isVoidTy()) + continue; + TOPClass->insert(&I); + ValueToClass[&I] = TOPClass; + } + } + + // Initialize arguments to be in their own unique congruence classes + for (auto &FA : F.args()) + createSingletonCongruenceClass(&FA); +} + +void NewGVN::cleanupTables() { + for (unsigned i = 0, e = CongruenceClasses.size(); i != e; ++i) { + DEBUG(dbgs() << "Congruence class " << CongruenceClasses[i]->getID() + << " has " << CongruenceClasses[i]->size() << " members\n"); + // Make sure we delete the congruence class (probably worth switching to + // a unique_ptr at some point. + delete CongruenceClasses[i]; + CongruenceClasses[i] = nullptr; + } + + // Destroy the value expressions + SmallVector<Instruction *, 8> TempInst(AllTempInstructions.begin(), + AllTempInstructions.end()); + AllTempInstructions.clear(); + + // We have to drop all references for everything first, so there are no uses + // left as we delete them. + for (auto *I : TempInst) { + I->dropAllReferences(); + } + + while (!TempInst.empty()) { + auto *I = TempInst.back(); + TempInst.pop_back(); + I->deleteValue(); + } + + ValueToClass.clear(); + ArgRecycler.clear(ExpressionAllocator); + ExpressionAllocator.Reset(); + CongruenceClasses.clear(); + ExpressionToClass.clear(); + ValueToExpression.clear(); + RealToTemp.clear(); + AdditionalUsers.clear(); + ExpressionToPhiOfOps.clear(); + TempToBlock.clear(); + TempToMemory.clear(); + PHINodeUses.clear(); + OpSafeForPHIOfOps.clear(); + ReachableBlocks.clear(); + ReachableEdges.clear(); +#ifndef NDEBUG + ProcessedCount.clear(); +#endif + InstrDFS.clear(); + InstructionsToErase.clear(); + DFSToInstr.clear(); + BlockInstRange.clear(); + TouchedInstructions.clear(); + MemoryAccessToClass.clear(); + PredicateToUsers.clear(); + MemoryToUsers.clear(); + RevisitOnReachabilityChange.clear(); +} + +// Assign local DFS number mapping to instructions, and leave space for Value +// PHI's. +std::pair<unsigned, unsigned> NewGVN::assignDFSNumbers(BasicBlock *B, + unsigned Start) { + unsigned End = Start; + if (MemoryAccess *MemPhi = getMemoryAccess(B)) { + InstrDFS[MemPhi] = End++; + DFSToInstr.emplace_back(MemPhi); + } + + // Then the real block goes next. + for (auto &I : *B) { + // There's no need to call isInstructionTriviallyDead more than once on + // an instruction. Therefore, once we know that an instruction is dead + // we change its DFS number so that it doesn't get value numbered. + if (isInstructionTriviallyDead(&I, TLI)) { + InstrDFS[&I] = 0; + DEBUG(dbgs() << "Skipping trivially dead instruction " << I << "\n"); + markInstructionForDeletion(&I); + continue; + } + if (isa<PHINode>(&I)) + RevisitOnReachabilityChange[B].set(End); + InstrDFS[&I] = End++; + DFSToInstr.emplace_back(&I); + } + + // All of the range functions taken half-open ranges (open on the end side). + // So we do not subtract one from count, because at this point it is one + // greater than the last instruction. 
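+  // (E.g. if this block's instructions received the numbers 5 through 9, we
+  // return the half-open range [5, 10).)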
+ return std::make_pair(Start, End); +} + +void NewGVN::updateProcessedCount(const Value *V) { +#ifndef NDEBUG + if (ProcessedCount.count(V) == 0) { + ProcessedCount.insert({V, 1}); + } else { + ++ProcessedCount[V]; + assert(ProcessedCount[V] < 100 && + "Seem to have processed the same Value a lot"); + } +#endif +} + +// Evaluate MemoryPhi nodes symbolically, just like PHI nodes +void NewGVN::valueNumberMemoryPhi(MemoryPhi *MP) { + // If all the arguments are the same, the MemoryPhi has the same value as the + // argument. Filter out unreachable blocks and self phis from our operands. + // TODO: We could do cycle-checking on the memory phis to allow valueizing for + // self-phi checking. + const BasicBlock *PHIBlock = MP->getBlock(); + auto Filtered = make_filter_range(MP->operands(), [&](const Use &U) { + return cast<MemoryAccess>(U) != MP && + !isMemoryAccessTOP(cast<MemoryAccess>(U)) && + ReachableEdges.count({MP->getIncomingBlock(U), PHIBlock}); + }); + // If all that is left is nothing, our memoryphi is undef. We keep it as + // InitialClass. Note: The only case this should happen is if we have at + // least one self-argument. + if (Filtered.begin() == Filtered.end()) { + if (setMemoryClass(MP, TOPClass)) + markMemoryUsersTouched(MP); + return; + } + + // Transform the remaining operands into operand leaders. + // FIXME: mapped_iterator should have a range version. + auto LookupFunc = [&](const Use &U) { + return lookupMemoryLeader(cast<MemoryAccess>(U)); + }; + auto MappedBegin = map_iterator(Filtered.begin(), LookupFunc); + auto MappedEnd = map_iterator(Filtered.end(), LookupFunc); + + // and now check if all the elements are equal. + // Sadly, we can't use std::equals since these are random access iterators. + const auto *AllSameValue = *MappedBegin; + ++MappedBegin; + bool AllEqual = std::all_of( + MappedBegin, MappedEnd, + [&AllSameValue](const MemoryAccess *V) { return V == AllSameValue; }); + + if (AllEqual) + DEBUG(dbgs() << "Memory Phi value numbered to " << *AllSameValue << "\n"); + else + DEBUG(dbgs() << "Memory Phi value numbered to itself\n"); + // If it's equal to something, it's in that class. Otherwise, it has to be in + // a class where it is the leader (other things may be equivalent to it, but + // it needs to start off in its own class, which means it must have been the + // leader, and it can't have stopped being the leader because it was never + // removed). + CongruenceClass *CC = + AllEqual ? getMemoryClass(AllSameValue) : ensureLeaderOfMemoryClass(MP); + auto OldState = MemoryPhiState.lookup(MP); + assert(OldState != MPS_Invalid && "Invalid memory phi state"); + auto NewState = AllEqual ? MPS_Equivalent : MPS_Unique; + MemoryPhiState[MP] = NewState; + if (setMemoryClass(MP, CC) || OldState != NewState) + markMemoryUsersTouched(MP); +} + +// Value number a single instruction, symbolically evaluating, performing +// congruence finding, and updating mappings. +void NewGVN::valueNumberInstruction(Instruction *I) { + DEBUG(dbgs() << "Processing instruction " << *I << "\n"); + if (!I->isTerminator()) { + const Expression *Symbolized = nullptr; + SmallPtrSet<Value *, 2> Visited; + if (DebugCounter::shouldExecute(VNCounter)) { + Symbolized = performSymbolicEvaluation(I, Visited); + // Make a phi of ops if necessary + if (Symbolized && !isa<ConstantExpression>(Symbolized) && + !isa<VariableExpression>(Symbolized) && PHINodeUses.count(I)) { + auto *PHIE = makePossiblePHIOfOps(I, Visited); + // If we created a phi of ops, use it. 
+ // If we couldn't create one, make sure we don't leave one lying around + if (PHIE) { + Symbolized = PHIE; + } else if (auto *Op = RealToTemp.lookup(I)) { + removePhiOfOps(I, Op); + } + } + } else { + // Mark the instruction as unused so we don't value number it again. + InstrDFS[I] = 0; + } + // If we couldn't come up with a symbolic expression, use the unknown + // expression + if (Symbolized == nullptr) + Symbolized = createUnknownExpression(I); + performCongruenceFinding(I, Symbolized); + } else { + // Handle terminators that return values. All of them produce values we + // don't currently understand. We don't place non-value producing + // terminators in a class. + if (!I->getType()->isVoidTy()) { + auto *Symbolized = createUnknownExpression(I); + performCongruenceFinding(I, Symbolized); + } + processOutgoingEdges(dyn_cast<TerminatorInst>(I), I->getParent()); + } +} + +// Check if there is a path, using single or equal argument phi nodes, from +// First to Second. +bool NewGVN::singleReachablePHIPath( + SmallPtrSet<const MemoryAccess *, 8> &Visited, const MemoryAccess *First, + const MemoryAccess *Second) const { + if (First == Second) + return true; + if (MSSA->isLiveOnEntryDef(First)) + return false; + + // This is not perfect, but as we're just verifying here, we can live with + // the loss of precision. The real solution would be that of doing strongly + // connected component finding in this routine, and it's probably not worth + // the complexity for the time being. So, we just keep a set of visited + // MemoryAccess and return true when we hit a cycle. + if (Visited.count(First)) + return true; + Visited.insert(First); + + const auto *EndDef = First; + for (auto *ChainDef : optimized_def_chain(First)) { + if (ChainDef == Second) + return true; + if (MSSA->isLiveOnEntryDef(ChainDef)) + return false; + EndDef = ChainDef; + } + auto *MP = cast<MemoryPhi>(EndDef); + auto ReachableOperandPred = [&](const Use &U) { + return ReachableEdges.count({MP->getIncomingBlock(U), MP->getBlock()}); + }; + auto FilteredPhiArgs = + make_filter_range(MP->operands(), ReachableOperandPred); + SmallVector<const Value *, 32> OperandList; + std::copy(FilteredPhiArgs.begin(), FilteredPhiArgs.end(), + std::back_inserter(OperandList)); + bool Okay = OperandList.size() == 1; + if (!Okay) + Okay = + std::equal(OperandList.begin(), OperandList.end(), OperandList.begin()); + if (Okay) + return singleReachablePHIPath(Visited, cast<MemoryAccess>(OperandList[0]), + Second); + return false; +} + +// Verify the that the memory equivalence table makes sense relative to the +// congruence classes. Note that this checking is not perfect, and is currently +// subject to very rare false negatives. It is only useful for +// testing/debugging. 
+void NewGVN::verifyMemoryCongruency() const { +#ifndef NDEBUG + // Verify that the memory table equivalence and memory member set match + for (const auto *CC : CongruenceClasses) { + if (CC == TOPClass || CC->isDead()) + continue; + if (CC->getStoreCount() != 0) { + assert((CC->getStoredValue() || !isa<StoreInst>(CC->getLeader())) && + "Any class with a store as a leader should have a " + "representative stored value"); + assert(CC->getMemoryLeader() && + "Any congruence class with a store should have a " + "representative access"); + } + + if (CC->getMemoryLeader()) + assert(MemoryAccessToClass.lookup(CC->getMemoryLeader()) == CC && + "Representative MemoryAccess does not appear to be reverse " + "mapped properly"); + for (auto M : CC->memory()) + assert(MemoryAccessToClass.lookup(M) == CC && + "Memory member does not appear to be reverse mapped properly"); + } + + // Anything equivalent in the MemoryAccess table should be in the same + // congruence class. + + // Filter out the unreachable and trivially dead entries, because they may + // never have been updated if the instructions were not processed. + auto ReachableAccessPred = + [&](const std::pair<const MemoryAccess *, CongruenceClass *> Pair) { + bool Result = ReachableBlocks.count(Pair.first->getBlock()); + if (!Result || MSSA->isLiveOnEntryDef(Pair.first) || + MemoryToDFSNum(Pair.first) == 0) + return false; + if (auto *MemDef = dyn_cast<MemoryDef>(Pair.first)) + return !isInstructionTriviallyDead(MemDef->getMemoryInst()); + + // We could have phi nodes which operands are all trivially dead, + // so we don't process them. + if (auto *MemPHI = dyn_cast<MemoryPhi>(Pair.first)) { + for (auto &U : MemPHI->incoming_values()) { + if (auto *I = dyn_cast<Instruction>(&*U)) { + if (!isInstructionTriviallyDead(I)) + return true; + } + } + return false; + } + + return true; + }; + + auto Filtered = make_filter_range(MemoryAccessToClass, ReachableAccessPred); + for (auto KV : Filtered) { + if (auto *FirstMUD = dyn_cast<MemoryUseOrDef>(KV.first)) { + auto *SecondMUD = dyn_cast<MemoryUseOrDef>(KV.second->getMemoryLeader()); + if (FirstMUD && SecondMUD) { + SmallPtrSet<const MemoryAccess *, 8> VisitedMAS; + assert((singleReachablePHIPath(VisitedMAS, FirstMUD, SecondMUD) || + ValueToClass.lookup(FirstMUD->getMemoryInst()) == + ValueToClass.lookup(SecondMUD->getMemoryInst())) && + "The instructions for these memory operations should have " + "been in the same congruence class or reachable through" + "a single argument phi"); + } + } else if (auto *FirstMP = dyn_cast<MemoryPhi>(KV.first)) { + // We can only sanely verify that MemoryDefs in the operand list all have + // the same class. 
+      auto ReachableOperandPred = [&](const Use &U) {
+        return ReachableEdges.count(
+                   {FirstMP->getIncomingBlock(U), FirstMP->getBlock()}) &&
+               isa<MemoryDef>(U);
+      };
+      // All arguments should be in the same class, ignoring unreachable
+      // arguments.
+      auto FilteredPhiArgs =
+          make_filter_range(FirstMP->operands(), ReachableOperandPred);
+      SmallVector<const CongruenceClass *, 16> PhiOpClasses;
+      std::transform(FilteredPhiArgs.begin(), FilteredPhiArgs.end(),
+                     std::back_inserter(PhiOpClasses), [&](const Use &U) {
+                       const MemoryDef *MD = cast<MemoryDef>(U);
+                       return ValueToClass.lookup(MD->getMemoryInst());
+                     });
+      assert(std::equal(PhiOpClasses.begin(), PhiOpClasses.end(),
+                        PhiOpClasses.begin()) &&
+             "All MemoryPhi arguments should be in the same class");
+    }
+  }
+#endif
+}
+
+// Verify that the sparse propagation we did actually found the maximal
+// fixpoint. We do this by storing the value to class mapping, touching all
+// instructions, and redoing the iteration to see if anything changed.
+void NewGVN::verifyIterationSettled(Function &F) {
+#ifndef NDEBUG
+  DEBUG(dbgs() << "Beginning iteration verification\n");
+  if (DebugCounter::isCounterSet(VNCounter))
+    DebugCounter::setCounterValue(VNCounter, StartingVNCounter);
+
+  // Note that we have to store the actual classes, as we may change existing
+  // classes during iteration. This is because our memory iteration propagation
+  // is not perfect, and so may waste a little work. But it should generate
+  // exactly the same congruence classes we have now, with different IDs.
+  std::map<const Value *, CongruenceClass> BeforeIteration;
+
+  for (auto &KV : ValueToClass) {
+    if (auto *I = dyn_cast<Instruction>(KV.first))
+      // Skip unused/dead instructions.
+      if (InstrToDFSNum(I) == 0)
+        continue;
+    BeforeIteration.insert({KV.first, *KV.second});
+  }
+
+  TouchedInstructions.set();
+  TouchedInstructions.reset(0);
+  iterateTouchedInstructions();
+  DenseSet<std::pair<const CongruenceClass *, const CongruenceClass *>>
+      EqualClasses;
+  for (const auto &KV : ValueToClass) {
+    if (auto *I = dyn_cast<Instruction>(KV.first))
+      // Skip unused/dead instructions.
+      if (InstrToDFSNum(I) == 0)
+        continue;
+    // We could sink these uses, but I think this adds a bit of clarity here as
+    // to what we are comparing.
+    auto *BeforeCC = &BeforeIteration.find(KV.first)->second;
+    auto *AfterCC = KV.second;
+    // Note that the classes can't change at this point, so we memoize the set
+    // that are equal.
+    if (!EqualClasses.count({BeforeCC, AfterCC})) {
+      assert(BeforeCC->isEquivalentTo(AfterCC) &&
+             "Value number changed after main loop completed!");
+      EqualClasses.insert({BeforeCC, AfterCC});
+    }
+  }
+#endif
+}
+
+// Verify that for each store expression in the expression to class mapping,
+// only the latest one appears; stale duplicates must not.
+// Because loads do not use the stored value when doing equality with stores,
+// if we don't erase the old store expressions from the table, a load can find
+// a no-longer valid StoreExpression.
+void NewGVN::verifyStoreExpressions() const {
+#ifndef NDEBUG
+  // This is the only use of this, and it's not worth defining a complicated
+  // DenseMapInfo hash/equality function for it.
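+  // A std::set only needs operator< on the element type, which the nested
+  // pair/tuple already provides, and this is debug-only code, so the
+  // O(log n) lookups are acceptable.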
+  std::set<
+      std::pair<const Value *,
+                std::tuple<const Value *, const CongruenceClass *, Value *>>>
+      StoreExpressionSet;
+  for (const auto &KV : ExpressionToClass) {
+    if (auto *SE = dyn_cast<StoreExpression>(KV.first)) {
+      // Make sure a version that will conflict with loads is not already there
+      auto Res = StoreExpressionSet.insert(
+          {SE->getOperand(0), std::make_tuple(SE->getMemoryLeader(), KV.second,
+                                              SE->getStoredValue())});
+      bool Okay = Res.second;
+      // It's okay to have the same expression already in there if it is
+      // identical in nature.
+      // This can happen when the leader of the stored value changes over time.
+      if (!Okay)
+        Okay = (std::get<1>(Res.first->second) == KV.second) &&
+               (lookupOperandLeader(std::get<2>(Res.first->second)) ==
+                lookupOperandLeader(SE->getStoredValue()));
+      assert(Okay && "Stored expression conflict exists in expression table");
+      auto *ValueExpr = ValueToExpression.lookup(SE->getStoreInst());
+      assert(ValueExpr && ValueExpr->equals(*SE) &&
+             "StoreExpression in ExpressionToClass is not latest "
+             "StoreExpression for value");
+    }
+  }
+#endif
+}
+
+// This is the main value numbering loop; it iterates over the initial touched
+// instruction set, propagating value numbers, marking things touched, etc.,
+// until the set of touched instructions is completely empty.
+void NewGVN::iterateTouchedInstructions() {
+  unsigned int Iterations = 0;
+  // Figure out where TouchedInstructions starts.
+  int FirstInstr = TouchedInstructions.find_first();
+  // Nothing set, nothing to iterate, just return.
+  if (FirstInstr == -1)
+    return;
+  const BasicBlock *LastBlock = getBlockForValue(InstrFromDFSNum(FirstInstr));
+  while (TouchedInstructions.any()) {
+    ++Iterations;
+    // Walk through all the instructions in all the blocks in RPO.
+    // TODO: As we hit a new block, we should push and pop equalities into a
+    // table that lookupOperandLeader can use, to catch things PredicateInfo
+    // might miss, like edge-only equivalences.
+    for (unsigned InstrNum : TouchedInstructions.set_bits()) {
+
+      // This instruction was found to be dead. We don't bother looking
+      // at it again.
+      if (InstrNum == 0) {
+        TouchedInstructions.reset(InstrNum);
+        continue;
+      }
+
+      Value *V = InstrFromDFSNum(InstrNum);
+      const BasicBlock *CurrBlock = getBlockForValue(V);
+
+      // If we hit a new block, do reachability processing.
+      if (CurrBlock != LastBlock) {
+        LastBlock = CurrBlock;
+        bool BlockReachable = ReachableBlocks.count(CurrBlock);
+        const auto &CurrInstRange = BlockInstRange.lookup(CurrBlock);
+
+        // If it's not reachable, erase any touched instructions and move on.
+        if (!BlockReachable) {
+          TouchedInstructions.reset(CurrInstRange.first, CurrInstRange.second);
+          DEBUG(dbgs() << "Skipping instructions in block "
+                       << getBlockName(CurrBlock)
+                       << " because it is unreachable\n");
+          continue;
+        }
+        updateProcessedCount(CurrBlock);
+      }
+      // Reset after processing (because we may mark ourselves as touched when
+      // we propagate equalities).
+      TouchedInstructions.reset(InstrNum);
+
+      if (auto *MP = dyn_cast<MemoryPhi>(V)) {
+        DEBUG(dbgs() << "Processing MemoryPhi " << *MP << "\n");
+        valueNumberMemoryPhi(MP);
+      } else if (auto *I = dyn_cast<Instruction>(V)) {
+        valueNumberInstruction(I);
+      } else {
+        llvm_unreachable("Should have been a MemoryPhi or Instruction");
+      }
+      updateProcessedCount(V);
+    }
+  }
+  NumGVNMaxIterations = std::max(NumGVNMaxIterations.getValue(), Iterations);
+}
+
+// This is the main transformation entry point.
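+// It proceeds in phases: assign RPO-based DFS numbers to blocks and
+// instructions, run congruence finding over the touched-instruction set to a
+// fixpoint, verify the result (debug builds only), and finally eliminate
+// redundant instructions and unreachable blocks.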
+bool NewGVN::runGVN() {
+  if (DebugCounter::isCounterSet(VNCounter))
+    StartingVNCounter = DebugCounter::getCounterValue(VNCounter);
+  bool Changed = false;
+  NumFuncArgs = F.arg_size();
+  MSSAWalker = MSSA->getWalker();
+  SingletonDeadExpression = new (ExpressionAllocator) DeadExpression();
+
+  // Count number of instructions for sizing of hash tables, and come
+  // up with a global dfs numbering for instructions.
+  unsigned ICount = 1;
+  // Add an empty instruction to account for the fact that we start at 1.
+  DFSToInstr.emplace_back(nullptr);
+  // Note: We want ideal RPO traversal of the blocks, which is not quite the
+  // same as dominator tree order, particularly with regard to whether
+  // backedges get visited first or second, given a block with multiple
+  // successors.
+  // If we visit in the wrong order, we will end up performing N times as many
+  // iterations.
+  // The dominator tree does guarantee that, for a given dom tree node, its
+  // parent must occur before it in the RPO ordering. Thus, we only need to
+  // sort the siblings.
+  ReversePostOrderTraversal<Function *> RPOT(&F);
+  unsigned Counter = 0;
+  for (auto &B : RPOT) {
+    auto *Node = DT->getNode(B);
+    assert(Node && "RPO and Dominator tree should have same reachability");
+    RPOOrdering[Node] = ++Counter;
+  }
+  // Sort dominator tree children arrays into RPO.
+  for (auto &B : RPOT) {
+    auto *Node = DT->getNode(B);
+    if (Node->getChildren().size() > 1)
+      std::sort(Node->begin(), Node->end(),
+                [&](const DomTreeNode *A, const DomTreeNode *B) {
+                  return RPOOrdering[A] < RPOOrdering[B];
+                });
+  }
+
+  // Now a standard depth first ordering of the domtree is equivalent to RPO.
+  for (auto DTN : depth_first(DT->getRootNode())) {
+    BasicBlock *B = DTN->getBlock();
+    const auto &BlockRange = assignDFSNumbers(B, ICount);
+    BlockInstRange.insert({B, BlockRange});
+    ICount += BlockRange.second - BlockRange.first;
+  }
+  initializeCongruenceClasses(F);
+
+  TouchedInstructions.resize(ICount);
+  // Ensure we don't end up resizing the expressionToClass map, as
+  // that can be quite expensive. At most, we have one expression per
+  // instruction.
+  ExpressionToClass.reserve(ICount);
+
+  // Initialize the touched instructions to include the entry block.
+  const auto &InstRange = BlockInstRange.lookup(&F.getEntryBlock());
+  TouchedInstructions.set(InstRange.first, InstRange.second);
+  DEBUG(dbgs() << "Block " << getBlockName(&F.getEntryBlock())
+               << " marked reachable\n");
+  ReachableBlocks.insert(&F.getEntryBlock());
+
+  iterateTouchedInstructions();
+  verifyMemoryCongruency();
+  verifyIterationSettled(F);
+  verifyStoreExpressions();
+
+  Changed |= eliminateInstructions(F);
+
+  // Delete all instructions marked for deletion.
+  for (Instruction *ToErase : InstructionsToErase) {
+    if (!ToErase->use_empty())
+      ToErase->replaceAllUsesWith(UndefValue::get(ToErase->getType()));
+
+    if (ToErase->getParent())
+      ToErase->eraseFromParent();
+  }
+
+  // Delete all unreachable blocks.
+  auto UnreachableBlockPred = [&](const BasicBlock &BB) {
+    return !ReachableBlocks.count(&BB);
+  };
+
+  for (auto &BB : make_filter_range(F, UnreachableBlockPred)) {
+    DEBUG(dbgs() << "We believe block " << getBlockName(&BB)
+                 << " is unreachable\n");
+    deleteInstructionsInBlock(&BB);
+    Changed = true;
+  }
+
+  cleanupTables();
+  return Changed;
+}
+
+struct NewGVN::ValueDFS {
+  int DFSIn = 0;
+  int DFSOut = 0;
+  int LocalNum = 0;
+
+  // Only one of Def and U will be set.
+  // The bool in the Def tells us whether the Def is the stored value of a
+  // store.
+  PointerIntPair<Value *, 1, bool> Def;
+  Use *U = nullptr;
+
+  bool operator<(const ValueDFS &Other) const {
+    // It's not enough that any given field be less than - we have sets
+    // of fields that need to be evaluated together to give a proper ordering.
+    // For example, if you have:
+    // DFS (1, 3)
+    // Val 0
+    // DFS (1, 2)
+    // Val 50
+    // We want the second to be less than the first, but if we just go field
+    // by field, we will get to Val 0 < Val 50 and say the first is less than
+    // the second. We only want it to be less than if the DFS orders are equal.
+    //
+    // Each LLVM instruction only produces one value, and thus the lowest-level
+    // differentiator that really matters for the stack (and what we use as a
+    // replacement) is the local dfs number.
+    // Everything else in the structure is instruction level, and only affects
+    // the order in which we will replace operands of a given instruction.
+    //
+    // For a given instruction (i.e., things with equal dfsin, dfsout,
+    // localnum), the order of replacement of uses does not matter.
+    // For example, given
+    // a = 5
+    // b = a + a
+    // when you hit b, you will have two ValueDFS entries with the same dfsin,
+    // dfsout, and localnum.
+    // The .val will be the same as well.
+    // The .u's will be different.
+    // You will replace both, and it does not matter what order you replace
+    // them in (i.e., whether you replace operand 2, then operand 1, or
+    // operand 1, then operand 2).
+    // Similarly for the case of same dfsin, dfsout, localnum, but different
+    // .val's:
+    // a = 5
+    // b = 6
+    // c = a + b
+    // In c, we will have a ValueDFS entry for a, and one for b, with
+    // everything the same but .val and .u.
+    // It does not matter what order we replace these operands in.
+    // You will always end up with the same IR, and this is guaranteed.
+    return std::tie(DFSIn, DFSOut, LocalNum, Def, U) <
+           std::tie(Other.DFSIn, Other.DFSOut, Other.LocalNum, Other.Def,
+                    Other.U);
+  }
+};
+
+// This function converts the set of members for a congruence class from
+// values to sets of defs and uses with associated DFS info. The total number
+// of reachable uses for each value is stored in UseCounts, and instructions
+// that seem dead (have no non-dead uses) are stored in ProbablyDead.
+void NewGVN::convertClassToDFSOrdered(
+    const CongruenceClass &Dense, SmallVectorImpl<ValueDFS> &DFSOrderedSet,
+    DenseMap<const Value *, unsigned int> &UseCounts,
+    SmallPtrSetImpl<Instruction *> &ProbablyDead) const {
+  for (auto D : Dense) {
+    // First add the value.
+    BasicBlock *BB = getBlockForValue(D);
+    // Constants are handled prior to ever calling this function, so
+    // we should only be left with instructions as members.
+    assert(BB && "Should have figured out a basic block for value");
+    ValueDFS VDDef;
+    DomTreeNode *DomNode = DT->getNode(BB);
+    VDDef.DFSIn = DomNode->getDFSNumIn();
+    VDDef.DFSOut = DomNode->getDFSNumOut();
+    // If it's a store, use the leader of the value operand, if it's always
+    // available, or the value operand. TODO: We could do dominance checks to
+    // find a dominating leader, but it's not worth it ATM.
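+    // (alwaysAvailable holds for constants and function arguments, which
+    // dominate every use, so no dominance check is needed for them.)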
+    if (auto *SI = dyn_cast<StoreInst>(D)) {
+      auto Leader = lookupOperandLeader(SI->getValueOperand());
+      if (alwaysAvailable(Leader)) {
+        VDDef.Def.setPointer(Leader);
+      } else {
+        VDDef.Def.setPointer(SI->getValueOperand());
+        VDDef.Def.setInt(true);
+      }
+    } else {
+      VDDef.Def.setPointer(D);
+    }
+    assert(isa<Instruction>(D) &&
+           "The dense set member should always be an instruction");
+    Instruction *Def = cast<Instruction>(D);
+    VDDef.LocalNum = InstrToDFSNum(D);
+    DFSOrderedSet.push_back(VDDef);
+    // If there is a phi node equivalent, add it.
+    if (auto *PN = RealToTemp.lookup(Def)) {
+      auto *PHIE =
+          dyn_cast_or_null<PHIExpression>(ValueToExpression.lookup(Def));
+      if (PHIE) {
+        VDDef.Def.setInt(false);
+        VDDef.Def.setPointer(PN);
+        VDDef.LocalNum = 0;
+        DFSOrderedSet.push_back(VDDef);
+      }
+    }
+
+    unsigned int UseCount = 0;
+    // Now add the uses.
+    for (auto &U : Def->uses()) {
+      if (auto *I = dyn_cast<Instruction>(U.getUser())) {
+        // Don't try to replace into dead uses.
+        if (InstructionsToErase.count(I))
+          continue;
+        ValueDFS VDUse;
+        // Put the phi node uses in the incoming block.
+        BasicBlock *IBlock;
+        if (auto *P = dyn_cast<PHINode>(I)) {
+          IBlock = P->getIncomingBlock(U);
+          // Make phi node users appear last in the incoming block
+          // they are from.
+          VDUse.LocalNum = InstrDFS.size() + 1;
+        } else {
+          IBlock = getBlockForValue(I);
+          VDUse.LocalNum = InstrToDFSNum(I);
+        }
+
+        // Skip uses in unreachable blocks, as we're going
+        // to delete them.
+        if (ReachableBlocks.count(IBlock) == 0)
+          continue;
+
+        DomTreeNode *DomNode = DT->getNode(IBlock);
+        VDUse.DFSIn = DomNode->getDFSNumIn();
+        VDUse.DFSOut = DomNode->getDFSNumOut();
+        VDUse.U = &U;
+        ++UseCount;
+        DFSOrderedSet.emplace_back(VDUse);
+      }
+    }
+
+    // If there are no uses, it's probably dead (but it may have side-effects,
+    // so it is not definitely dead). Otherwise, store the number of uses so
+    // we can track whether it becomes dead later.
+    if (UseCount == 0)
+      ProbablyDead.insert(Def);
+    else
+      UseCounts[Def] = UseCount;
+  }
+}
+
+// This function converts the set of members for a congruence class from
+// values to the set of defs for loads and stores, with associated DFS info.
+void NewGVN::convertClassToLoadsAndStores(
+    const CongruenceClass &Dense,
+    SmallVectorImpl<ValueDFS> &LoadsAndStores) const {
+  for (auto D : Dense) {
+    if (!isa<LoadInst>(D) && !isa<StoreInst>(D))
+      continue;
+
+    BasicBlock *BB = getBlockForValue(D);
+    ValueDFS VD;
+    DomTreeNode *DomNode = DT->getNode(BB);
+    VD.DFSIn = DomNode->getDFSNumIn();
+    VD.DFSOut = DomNode->getDFSNumOut();
+    VD.Def.setPointer(D);
+
+    // If it's an instruction, use the real local dfs number.
+    if (auto *I = dyn_cast<Instruction>(D))
+      VD.LocalNum = InstrToDFSNum(I);
+    else
+      llvm_unreachable("Should have been an instruction");
+
+    LoadsAndStores.emplace_back(VD);
+  }
+}
+
+static void patchReplacementInstruction(Instruction *I, Value *Repl) {
+  auto *ReplInst = dyn_cast<Instruction>(Repl);
+  if (!ReplInst)
+    return;
+
+  // Patch the replacement so that it is not more restrictive than the value
+  // being replaced.
+  // Note that if 'I' is a load being replaced by some operation,
+  // for example, by an arithmetic operation, then andIRFlags()
+  // would just erase all math flags from the original arithmetic
+  // operation, which is clearly not wanted and not needed.
+ if (!isa<LoadInst>(I)) + ReplInst->andIRFlags(I); + + // FIXME: If both the original and replacement value are part of the + // same control-flow region (meaning that the execution of one + // guarantees the execution of the other), then we can combine the + // noalias scopes here and do better than the general conservative + // answer used in combineMetadata(). + + // In general, GVN unifies expressions over different control-flow + // regions, and so we need a conservative combination of the noalias + // scopes. + static const unsigned KnownIDs[] = { + LLVMContext::MD_tbaa, LLVMContext::MD_alias_scope, + LLVMContext::MD_noalias, LLVMContext::MD_range, + LLVMContext::MD_fpmath, LLVMContext::MD_invariant_load, + LLVMContext::MD_invariant_group}; + combineMetadata(ReplInst, I, KnownIDs); +} + +static void patchAndReplaceAllUsesWith(Instruction *I, Value *Repl) { + patchReplacementInstruction(I, Repl); + I->replaceAllUsesWith(Repl); +} + +void NewGVN::deleteInstructionsInBlock(BasicBlock *BB) { + DEBUG(dbgs() << " BasicBlock Dead:" << *BB); + ++NumGVNBlocksDeleted; + + // Delete the instructions backwards, as it has a reduced likelihood of having + // to update as many def-use and use-def chains. Start after the terminator. + auto StartPoint = BB->rbegin(); + ++StartPoint; + // Note that we explicitly recalculate BB->rend() on each iteration, + // as it may change when we remove the first instruction. + for (BasicBlock::reverse_iterator I(StartPoint); I != BB->rend();) { + Instruction &Inst = *I++; + if (!Inst.use_empty()) + Inst.replaceAllUsesWith(UndefValue::get(Inst.getType())); + if (isa<LandingPadInst>(Inst)) + continue; + + Inst.eraseFromParent(); + ++NumGVNInstrDeleted; + } + // Now insert something that simplifycfg will turn into an unreachable. + Type *Int8Ty = Type::getInt8Ty(BB->getContext()); + new StoreInst(UndefValue::get(Int8Ty), + Constant::getNullValue(Int8Ty->getPointerTo()), + BB->getTerminator()); +} + +void NewGVN::markInstructionForDeletion(Instruction *I) { + DEBUG(dbgs() << "Marking " << *I << " for deletion\n"); + InstructionsToErase.insert(I); +} + +void NewGVN::replaceInstruction(Instruction *I, Value *V) { + DEBUG(dbgs() << "Replacing " << *I << " with " << *V << "\n"); + patchAndReplaceAllUsesWith(I, V); + // We save the actual erasing to avoid invalidating memory + // dependencies until we are done with everything. + markInstructionForDeletion(I); +} + +namespace { + +// This is a stack that contains both the value and dfs info of where +// that value is valid. +class ValueDFSStack { +public: + Value *back() const { return ValueStack.back(); } + std::pair<int, int> dfs_back() const { return DFSStack.back(); } + + void push_back(Value *V, int DFSIn, int DFSOut) { + ValueStack.emplace_back(V); + DFSStack.emplace_back(DFSIn, DFSOut); + } + + bool empty() const { return DFSStack.empty(); } + + bool isInScope(int DFSIn, int DFSOut) const { + if (empty()) + return false; + return DFSIn >= DFSStack.back().first && DFSOut <= DFSStack.back().second; + } + + void popUntilDFSScope(int DFSIn, int DFSOut) { + + // These two should always be in sync at this point. 
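+    // With dominator-tree DFS numbers, scope (In2, Out2) nests inside
+    // (In1, Out1) exactly when In1 <= In2 and Out2 <= Out1; for example, a
+    // value pushed at scope (1, 8) is still in scope at (2, 5), but a query
+    // at (9, 12) pops it.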
+    assert(ValueStack.size() == DFSStack.size() &&
+           "Mismatch between ValueStack and DFSStack");
+    while (
+        !DFSStack.empty() &&
+        !(DFSIn >= DFSStack.back().first && DFSOut <= DFSStack.back().second)) {
+      DFSStack.pop_back();
+      ValueStack.pop_back();
+    }
+  }
+
+private:
+  SmallVector<Value *, 8> ValueStack;
+  SmallVector<std::pair<int, int>, 8> DFSStack;
+};
+
+} // end anonymous namespace
+
+// Given an expression, get the congruence class for it.
+CongruenceClass *NewGVN::getClassForExpression(const Expression *E) const {
+  if (auto *VE = dyn_cast<VariableExpression>(E))
+    return ValueToClass.lookup(VE->getVariableValue());
+  else if (isa<DeadExpression>(E))
+    return TOPClass;
+  return ExpressionToClass.lookup(E);
+}
+
+// Given a value and a basic block we are trying to see if it is available in,
+// see if the value has a leader available in that block.
+Value *NewGVN::findPHIOfOpsLeader(const Expression *E,
+                                  const Instruction *OrigInst,
+                                  const BasicBlock *BB) const {
+  // It would already be constant if we could make it constant.
+  if (auto *CE = dyn_cast<ConstantExpression>(E))
+    return CE->getConstantValue();
+  if (auto *VE = dyn_cast<VariableExpression>(E)) {
+    auto *V = VE->getVariableValue();
+    if (alwaysAvailable(V) || DT->dominates(getBlockForValue(V), BB))
+      return VE->getVariableValue();
+  }
+
+  auto *CC = getClassForExpression(E);
+  if (!CC)
+    return nullptr;
+  if (alwaysAvailable(CC->getLeader()))
+    return CC->getLeader();
+
+  for (auto Member : *CC) {
+    auto *MemberInst = dyn_cast<Instruction>(Member);
+    if (MemberInst == OrigInst)
+      continue;
+    // Anything that isn't an instruction is always available.
+    if (!MemberInst)
+      return Member;
+    if (DT->dominates(getBlockForValue(MemberInst), BB))
+      return Member;
+  }
+  return nullptr;
+}
+
+bool NewGVN::eliminateInstructions(Function &F) {
+  // This is a non-standard eliminator. The normal way to eliminate is
+  // to walk the dominator tree in order, keeping track of available
+  // values, and eliminating them. However, this is mildly
+  // pointless. It requires doing lookups on every instruction,
+  // regardless of whether we will ever eliminate it. For instructions
+  // in singleton congruence classes (most of them), we know we will
+  // never eliminate them.
+
+  // Instead, this eliminator looks at the congruence classes directly, sorts
+  // them into a DFS ordering of the dominator tree, and then we just
+  // perform elimination straight on the sets by walking the congruence
+  // class member uses in order, and eliminating the ones dominated by the
+  // last member. This is worst case O(E log E) where E = number of
+  // instructions in a single congruence class. In theory, this is all
+  // instructions. In practice, it is much faster, as most instructions are
+  // either in singleton congruence classes or can't possibly be eliminated
+  // anyway (if there are no overlapping DFS ranges in the class).
+  // When we find something not dominated, it becomes the new leader
+  // for elimination purposes.
+  // TODO: If we wanted to be faster, we could remove any members with no
+  // overlapping ranges while sorting, as we will never eliminate anything
+  // with those members, as they don't dominate anything else in our set.
+
+  bool AnythingReplaced = false;
+
+  // Since we are going to walk the domtree anyway, and we can't guarantee the
+  // DFS numbers are updated, we compute some ourselves.
+  DT->updateDFSNumbers();
+
+  // Go through all of our phi nodes, and kill the arguments associated with
+  // unreachable edges.
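+  // For example, given %p = phi i32 [ %a, %bb1 ], [ %b, %bb2 ] where the
+  // edge from %bb2 was proven unreachable, the %b argument is replaced with
+  // undef below.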
+  auto ReplaceUnreachablePHIArgs = [&](PHINode *PHI, BasicBlock *BB) {
+    for (auto &Operand : PHI->incoming_values())
+      if (!ReachableEdges.count({PHI->getIncomingBlock(Operand), BB})) {
+        DEBUG(dbgs() << "Replacing incoming value of " << *PHI << " for block "
+                     << getBlockName(PHI->getIncomingBlock(Operand))
+                     << " with undef due to it being unreachable\n");
+        Operand.set(UndefValue::get(PHI->getType()));
+      }
+  };
+  // Replace unreachable phi arguments.
+  // At this point, RevisitOnReachabilityChange only contains:
+  //
+  // 1. PHIs
+  // 2. Temporaries that will convert to PHIs
+  // 3. Operations that are affected by an unreachable edge but do not fit into
+  //    1 or 2 (rare).
+  // So it is a slight overshoot of what we want. We could make it exact by
+  // using two SparseBitVectors per block.
+  DenseMap<const BasicBlock *, unsigned> ReachablePredCount;
+  for (auto &KV : ReachableEdges)
+    ReachablePredCount[KV.getEnd()]++;
+  for (auto &BBPair : RevisitOnReachabilityChange) {
+    for (auto InstNum : BBPair.second) {
+      auto *Inst = InstrFromDFSNum(InstNum);
+      auto *PHI = dyn_cast<PHINode>(Inst);
+      PHI = PHI ? PHI : dyn_cast_or_null<PHINode>(RealToTemp.lookup(Inst));
+      if (!PHI)
+        continue;
+      auto *BB = BBPair.first;
+      if (ReachablePredCount.lookup(BB) != PHI->getNumIncomingValues())
+        ReplaceUnreachablePHIArgs(PHI, BB);
+    }
+  }
+
+  // Map to store the use counts.
+  DenseMap<const Value *, unsigned int> UseCounts;
+  for (auto *CC : reverse(CongruenceClasses)) {
+    DEBUG(dbgs() << "Eliminating in congruence class " << CC->getID() << "\n");
+    // Track the equivalent store info so we can decide whether to try
+    // dead store elimination.
+    SmallVector<ValueDFS, 8> PossibleDeadStores;
+    SmallPtrSet<Instruction *, 8> ProbablyDead;
+    if (CC->isDead() || CC->empty())
+      continue;
+    // Everything still in the TOP class is unreachable or dead.
+    if (CC == TOPClass) {
+      for (auto M : *CC) {
+        auto *VTE = ValueToExpression.lookup(M);
+        if (VTE && isa<DeadExpression>(VTE))
+          markInstructionForDeletion(cast<Instruction>(M));
+        assert((!ReachableBlocks.count(cast<Instruction>(M)->getParent()) ||
+                InstructionsToErase.count(cast<Instruction>(M))) &&
+               "Everything in TOP should be unreachable or dead at this "
+               "point");
+      }
+      continue;
+    }
+
+    assert(CC->getLeader() && "We should have had a leader");
+    // If this is a leader that is always available, and it's a
+    // constant or has no equivalences, just replace everything with
+    // it. We then update the congruence class with whatever members
+    // are left.
+    Value *Leader =
+        CC->getStoredValue() ? CC->getStoredValue() : CC->getLeader();
+    if (alwaysAvailable(Leader)) {
+      CongruenceClass::MemberSet MembersLeft;
+      for (auto M : *CC) {
+        Value *Member = M;
+        // Void things have no uses we can replace.
+        if (Member == Leader || !isa<Instruction>(Member) ||
+            Member->getType()->isVoidTy()) {
+          MembersLeft.insert(Member);
+          continue;
+        }
+        DEBUG(dbgs() << "Found replacement " << *(Leader) << " for " << *Member
+                     << "\n");
+        auto *I = cast<Instruction>(Member);
+        assert(Leader != I && "About to accidentally remove our leader");
+        replaceInstruction(I, Leader);
+        AnythingReplaced = true;
+      }
+      CC->swap(MembersLeft);
+    } else {
+      // If this is a singleton, we can skip it.
+      if (CC->size() != 1 || RealToTemp.count(Leader)) {
+        // This is a stack because equality replacement/etc may place
+        // constants in the middle of the member list, and we want to use
+        // those constant values in preference to the current leader, over
+        // the scope of those constants.
+        ValueDFSStack EliminationStack;
+
+        // Convert the members to DFS ordered sets and then merge them.
+        SmallVector<ValueDFS, 8> DFSOrderedSet;
+        convertClassToDFSOrdered(*CC, DFSOrderedSet, UseCounts, ProbablyDead);
+
+        // Sort the whole thing.
+        std::sort(DFSOrderedSet.begin(), DFSOrderedSet.end());
+        for (auto &VD : DFSOrderedSet) {
+          int MemberDFSIn = VD.DFSIn;
+          int MemberDFSOut = VD.DFSOut;
+          Value *Def = VD.Def.getPointer();
+          bool FromStore = VD.Def.getInt();
+          Use *U = VD.U;
+          // We ignore void things because we can't get a value from them.
+          if (Def && Def->getType()->isVoidTy())
+            continue;
+          auto *DefInst = dyn_cast_or_null<Instruction>(Def);
+          if (DefInst && AllTempInstructions.count(DefInst)) {
+            auto *PN = cast<PHINode>(DefInst);
+
+            // If this is a value phi and that's the expression we used,
+            // insert it into the program and remove it from the temp
+            // instruction list.
+            AllTempInstructions.erase(PN);
+            auto *DefBlock = getBlockForValue(Def);
+            DEBUG(dbgs() << "Inserting fully real phi of ops " << *Def
+                         << " into block "
+                         << getBlockName(getBlockForValue(Def)) << "\n");
+            PN->insertBefore(&DefBlock->front());
+            Def = PN;
+            NumGVNPHIOfOpsEliminations++;
+          }
+
+          if (EliminationStack.empty()) {
+            DEBUG(dbgs() << "Elimination Stack is empty\n");
+          } else {
+            DEBUG(dbgs() << "Elimination Stack Top DFS numbers are ("
+                         << EliminationStack.dfs_back().first << ","
+                         << EliminationStack.dfs_back().second << ")\n");
+          }
+
+          DEBUG(dbgs() << "Current DFS numbers are (" << MemberDFSIn << ","
+                       << MemberDFSOut << ")\n");
+          // First, we see if we are out of scope or empty. If so, and there
+          // are equivalences, we try to replace the top of stack with
+          // equivalences (if it's on the stack, it must not have been
+          // eliminated yet).
+          // Then we synchronize to our current scope, by popping until we are
+          // back within a DFS scope that dominates the current member.
+          // Then, what happens depends on a few factors.
+          // If the stack is now empty, we need to push.
+          // If we have a constant or a local equivalence we want to
+          // start using, we also push.
+          // Otherwise, we walk along, processing members who are
+          // dominated by this scope, and eliminate them.
+          bool ShouldPush = Def && EliminationStack.empty();
+          bool OutOfScope =
+              !EliminationStack.isInScope(MemberDFSIn, MemberDFSOut);
+
+          if (OutOfScope || ShouldPush) {
+            // Sync to our current scope.
+            EliminationStack.popUntilDFSScope(MemberDFSIn, MemberDFSOut);
+            bool ShouldPush = Def && EliminationStack.empty();
+            if (ShouldPush) {
+              EliminationStack.push_back(Def, MemberDFSIn, MemberDFSOut);
+            }
+          }
+
+          // Skip the Defs; we only want to eliminate on their uses. But mark
+          // dominated defs as dead.
+          if (Def) {
+            // For anything in this case, what and how we value number
+            // guarantees that any side-effects that would have occurred
+            // (i.e., throwing, etc.) can be proven to either still occur
+            // (because they are dominated by something that has the same
+            // side-effects), or never occur. Otherwise, we would not have
+            // been able to prove it value equivalent to something else. For
+            // these things, we can just mark them all dead. Note that this is
+            // different from the "ProbablyDead" set, which may not be
+            // dominated by anything, and thus is only easy to prove dead if
+            // it is also side-effect free. Note that because stores are put
+            // in terms of the stored value, we skip stored values here. If
+            // the stored value is really dead, it will still be marked for
+            // deletion when we process it in its own class.
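+            // For example, if %a = add i32 %x, %y dominates a congruent
+            // %b = add i32 %x, %y, every use of %b is rewritten to %a, so %b
+            // can be marked dead here directly.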
+            if (!EliminationStack.empty() && Def != EliminationStack.back() &&
+                isa<Instruction>(Def) && !FromStore)
+              markInstructionForDeletion(cast<Instruction>(Def));
+            continue;
+          }
+          // At this point, we know it is a Use we are trying to possibly
+          // replace.
+
+          assert(isa<Instruction>(U->get()) &&
+                 "Current def should have been an instruction");
+          assert(isa<Instruction>(U->getUser()) &&
+                 "Current user should have been an instruction");
+
+          // If the thing we are replacing into is already marked to be dead,
+          // this use is dead. Note that this is true regardless of whether
+          // we have anything dominating the use or not. We do this here
+          // because we are already walking all the uses anyway.
+          Instruction *InstUse = cast<Instruction>(U->getUser());
+          if (InstructionsToErase.count(InstUse)) {
+            auto &UseCount = UseCounts[U->get()];
+            if (--UseCount == 0) {
+              ProbablyDead.insert(cast<Instruction>(U->get()));
+            }
+          }
+
+          // If we get to this point and the stack is empty, we must have a
+          // use with nothing we can use to eliminate this use, so just skip
+          // it.
+          if (EliminationStack.empty())
+            continue;
+
+          Value *DominatingLeader = EliminationStack.back();
+
+          auto *II = dyn_cast<IntrinsicInst>(DominatingLeader);
+          if (II && II->getIntrinsicID() == Intrinsic::ssa_copy)
+            DominatingLeader = II->getOperand(0);
+
+          // Don't replace our existing users with ourselves.
+          if (U->get() == DominatingLeader)
+            continue;
+          DEBUG(dbgs() << "Found replacement " << *DominatingLeader << " for "
+                       << *U->get() << " in " << *(U->getUser()) << "\n");
+
+          // If we replaced something in an instruction, handle the patching
+          // of metadata. Skip this if we are replacing predicateinfo with its
+          // original operand, as we already know we can just drop it.
+          auto *ReplacedInst = cast<Instruction>(U->get());
+          auto *PI = PredInfo->getPredicateInfoFor(ReplacedInst);
+          if (!PI || DominatingLeader != PI->OriginalOp)
+            patchReplacementInstruction(ReplacedInst, DominatingLeader);
+          U->set(DominatingLeader);
+          // This is now a use of the dominating leader, which means if the
+          // dominating leader was dead, it's now live!
+          auto &LeaderUseCount = UseCounts[DominatingLeader];
+          // It's about to be alive again.
+          if (LeaderUseCount == 0 && isa<Instruction>(DominatingLeader))
+            ProbablyDead.erase(cast<Instruction>(DominatingLeader));
+          if (LeaderUseCount == 0 && II)
+            ProbablyDead.insert(II);
+          ++LeaderUseCount;
+          AnythingReplaced = true;
+        }
+      }
+    }
+
+    // At this point, anything still in the ProbablyDead set is actually dead
+    // if it would be trivially dead.
+    for (auto *I : ProbablyDead)
+      if (wouldInstructionBeTriviallyDead(I))
+        markInstructionForDeletion(I);
+
+    // Cleanup the congruence class.
+    CongruenceClass::MemberSet MembersLeft;
+    for (auto *Member : *CC)
+      if (!isa<Instruction>(Member) ||
+          !InstructionsToErase.count(cast<Instruction>(Member)))
+        MembersLeft.insert(Member);
+    CC->swap(MembersLeft);
+
+    // If we have possible dead stores to look at, try to eliminate them.
+    if (CC->getStoreCount() > 0) {
+      convertClassToLoadsAndStores(*CC, PossibleDeadStores);
+      std::sort(PossibleDeadStores.begin(), PossibleDeadStores.end());
+      ValueDFSStack EliminationStack;
+      for (auto &VD : PossibleDeadStores) {
+        int MemberDFSIn = VD.DFSIn;
+        int MemberDFSOut = VD.DFSOut;
+        Instruction *Member = cast<Instruction>(VD.Def.getPointer());
+        if (EliminationStack.empty() ||
+            !EliminationStack.isInScope(MemberDFSIn, MemberDFSOut)) {
+          // Sync to our current scope.
+          EliminationStack.popUntilDFSScope(MemberDFSIn, MemberDFSOut);
+          if (EliminationStack.empty()) {
+            EliminationStack.push_back(Member, MemberDFSIn, MemberDFSOut);
+            continue;
+          }
+        }
+        // We already did load elimination, so nothing to do here.
+        if (isa<LoadInst>(Member))
+          continue;
+        assert(!EliminationStack.empty());
+        Instruction *Leader = cast<Instruction>(EliminationStack.back());
+        (void)Leader;
+        assert(DT->dominates(Leader->getParent(), Member->getParent()));
+        // Member is dominated by Leader, and thus dead.
+        DEBUG(dbgs() << "Marking dead store " << *Member
+                     << " that is dominated by " << *Leader << "\n");
+        markInstructionForDeletion(Member);
+        CC->erase(Member);
+        ++NumGVNDeadStores;
+      }
+    }
+  }
+  return AnythingReplaced;
+}
+
+// This function provides global ranking of operations so that we can place
+// them in a canonical order. Note that rank alone is not necessarily enough
+// for a complete ordering, as constants all have the same rank. However,
+// generally, we will simplify an operation with all constants so that it
+// doesn't matter what order they appear in.
+unsigned int NewGVN::getRank(const Value *V) const {
+  // Prefer constants to undef to anything else.
+  // Undef is a constant; we have to check it first.
+  // Prefer smaller constants to constantexprs.
+  if (isa<ConstantExpr>(V))
+    return 2;
+  if (isa<UndefValue>(V))
+    return 1;
+  if (isa<Constant>(V))
+    return 0;
+  else if (auto *A = dyn_cast<Argument>(V))
+    return 3 + A->getArgNo();
+
+  // Need to shift the instruction DFS number past the constant ranks (0-2)
+  // and the argument ranks (3 .. 2 + NumFuncArgs) above.
+  unsigned Result = InstrToDFSNum(V);
+  if (Result > 0)
+    return 4 + NumFuncArgs + Result;
+  // Unreachable or something else, just return a really large number.
+  return ~0;
+}
+
+// This is a function that says whether two commutative operations should
+// have their order swapped when canonicalizing.
+bool NewGVN::shouldSwapOperands(const Value *A, const Value *B) const {
+  // Because we only care about a total ordering, and don't rewrite expressions
+  // in this order, we order by rank, which will give a strict weak ordering to
+  // everything but constants, and then we order by pointer address.
+  return std::make_pair(getRank(A), A) > std::make_pair(getRank(B), B);
+}
+
+namespace {
+
+class NewGVNLegacyPass : public FunctionPass {
+public:
+  // Pass identification, replacement for typeid.
+  static char ID;
+
+  NewGVNLegacyPass() : FunctionPass(ID) {
+    initializeNewGVNLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override;
+
+private:
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<AssumptionCacheTracker>();
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<MemorySSAWrapperPass>();
+    AU.addRequired<AAResultsWrapperPass>();
+    AU.addPreserved<DominatorTreeWrapperPass>();
+    AU.addPreserved<GlobalsAAWrapperPass>();
+  }
+};
+
+} // end anonymous namespace
+
+bool NewGVNLegacyPass::runOnFunction(Function &F) {
+  if (skipFunction(F))
+    return false;
+  return NewGVN(F, &getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
+                &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F),
+                &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(),
+                &getAnalysis<AAResultsWrapperPass>().getAAResults(),
+                &getAnalysis<MemorySSAWrapperPass>().getMSSA(),
+                F.getParent()->getDataLayout())
+      .runGVN();
+}
+
+char NewGVNLegacyPass::ID = 0;
+
+INITIALIZE_PASS_BEGIN(NewGVNLegacyPass, "newgvn", "Global Value Numbering",
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
+INITIALIZE_PASS_DEPENDENCY(MemorySSAWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
+INITIALIZE_PASS_END(NewGVNLegacyPass, "newgvn", "Global Value Numbering", false,
+                    false)
+
+// createNewGVNPass - The public interface to this file.
+FunctionPass *llvm::createNewGVNPass() { return new NewGVNLegacyPass(); }
+
+PreservedAnalyses NewGVNPass::run(Function &F, AnalysisManager<Function> &AM) {
+  // Apparently the order in which we get these results matters for
+  // the old GVN (see Chandler's comment in GVN.cpp). I'll keep
+  // the same order here, just in case.
+  auto &AC = AM.getResult<AssumptionAnalysis>(F);
+  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+  auto &AA = AM.getResult<AAManager>(F);
+  auto &MSSA = AM.getResult<MemorySSAAnalysis>(F).getMSSA();
+  bool Changed =
+      NewGVN(F, &DT, &AC, &TLI, &AA, &MSSA, F.getParent()->getDataLayout())
+          .runGVN();
+  if (!Changed)
+    return PreservedAnalyses::all();
+  PreservedAnalyses PA;
+  PA.preserve<DominatorTreeAnalysis>();
+  PA.preserve<GlobalsAA>();
+  return PA;
+}
diff --git a/contrib/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp b/contrib/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
new file mode 100644
index 000000000000..1748815c5941
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/PartiallyInlineLibCalls.cpp
@@ -0,0 +1,180 @@
+//===--- PartiallyInlineLibCalls.cpp - Partially inline libcalls ----------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass tries to partially inline the fast path of well-known library
+// functions, such as using square-root instructions for cases where sqrt()
+// does not need to set errno.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/PartiallyInlineLibCalls.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "partially-inline-libcalls"
+
+
+static bool optimizeSQRT(CallInst *Call, Function *CalledFunc,
+                         BasicBlock &CurrBB, Function::iterator &BB,
+                         const TargetTransformInfo *TTI) {
+  // There is no need to change the IR, since the backend will emit a sqrt
+  // instruction if the call has already been marked read-only.
+  if (Call->onlyReadsMemory())
+    return false;
+
+  // Do the following transformation:
+  //
+  // (before)
+  // dst = sqrt(src)
+  //
+  // (after)
+  // v0 = sqrt_noreadmem(src) # native sqrt instruction.
+  // [if (v0 is a NaN) || if (src < 0)]
+  //   v1 = sqrt(src)         # library call.
+  // dst = phi(v0, v1)
+  //
+
+  // Move all instructions following Call to newly created block JoinBB.
+  // Create phi and replace all uses.
+  BasicBlock *JoinBB = llvm::SplitBlock(&CurrBB, Call->getNextNode());
+  IRBuilder<> Builder(JoinBB, JoinBB->begin());
+  Type *Ty = Call->getType();
+  PHINode *Phi = Builder.CreatePHI(Ty, 2);
+  Call->replaceAllUsesWith(Phi);
+
+  // Create basic block LibCallBB and insert a call to library function sqrt.
+  BasicBlock *LibCallBB = BasicBlock::Create(CurrBB.getContext(), "call.sqrt",
+                                             CurrBB.getParent(), JoinBB);
+  Builder.SetInsertPoint(LibCallBB);
+  Instruction *LibCall = Call->clone();
+  Builder.Insert(LibCall);
+  Builder.CreateBr(JoinBB);
+
+  // Add attribute "readnone" so that the backend can use a native sqrt
+  // instruction for this call. Insert a FP compare instruction and a
+  // conditional branch at the end of CurrBB.
+  Call->addAttribute(AttributeList::FunctionIndex, Attribute::ReadNone);
+  CurrBB.getTerminator()->eraseFromParent();
+  Builder.SetInsertPoint(&CurrBB);
+  Value *FCmp = TTI->isFCmpOrdCheaperThanFCmpZero(Ty)
+                    ? Builder.CreateFCmpORD(Call, Call)
+                    : Builder.CreateFCmpOGE(Call->getOperand(0),
+                                            ConstantFP::get(Ty, 0.0));
+  Builder.CreateCondBr(FCmp, JoinBB, LibCallBB);
+
+  // Add phi operands.
+  Phi->addIncoming(Call, &CurrBB);
+  Phi->addIncoming(LibCall, LibCallBB);
+
+  BB = JoinBB->getIterator();
+  return true;
+}
+
+static bool runPartiallyInlineLibCalls(Function &F, TargetLibraryInfo *TLI,
+                                       const TargetTransformInfo *TTI) {
+  bool Changed = false;
+
+  Function::iterator CurrBB;
+  for (Function::iterator BB = F.begin(), BE = F.end(); BB != BE;) {
+    CurrBB = BB++;
+
+    for (BasicBlock::iterator II = CurrBB->begin(), IE = CurrBB->end();
+         II != IE; ++II) {
+      CallInst *Call = dyn_cast<CallInst>(&*II);
+      Function *CalledFunc;
+
+      if (!Call || !(CalledFunc = Call->getCalledFunction()))
+        continue;
+
+      if (Call->isNoBuiltin())
+        continue;
+
+      // Skip if function either has local linkage or is not a known library
+      // function.
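+      // For example, a static helper that happens to be named "sqrt" must
+      // not be treated as the libm sqrt here.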
+      LibFunc LF;
+      if (CalledFunc->hasLocalLinkage() ||
+          !TLI->getLibFunc(*CalledFunc, LF) || !TLI->has(LF))
+        continue;
+
+      switch (LF) {
+      case LibFunc_sqrtf:
+      case LibFunc_sqrt:
+        if (TTI->haveFastSqrt(Call->getType()) &&
+            optimizeSQRT(Call, CalledFunc, *CurrBB, BB, TTI))
+          break;
+        continue;
+      default:
+        continue;
+      }
+
+      Changed = true;
+      break;
+    }
+  }
+
+  return Changed;
+}
+
+PreservedAnalyses
+PartiallyInlineLibCallsPass::run(Function &F, FunctionAnalysisManager &AM) {
+  auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
+  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+  if (!runPartiallyInlineLibCalls(F, &TLI, &TTI))
+    return PreservedAnalyses::all();
+  return PreservedAnalyses::none();
+}
+
+namespace {
+class PartiallyInlineLibCallsLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  PartiallyInlineLibCallsLegacyPass() : FunctionPass(ID) {
+    initializePartiallyInlineLibCallsLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    FunctionPass::getAnalysisUsage(AU);
+  }
+
+  bool runOnFunction(Function &F) override {
+    if (skipFunction(F))
+      return false;
+
+    TargetLibraryInfo *TLI =
+        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+    const TargetTransformInfo *TTI =
+        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+    return runPartiallyInlineLibCalls(F, TLI, TTI);
+  }
+};
+}
+
+char PartiallyInlineLibCallsLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(PartiallyInlineLibCallsLegacyPass,
+                      "partially-inline-libcalls",
+                      "Partially inline calls to library functions", false,
+                      false)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(PartiallyInlineLibCallsLegacyPass,
+                    "partially-inline-libcalls",
+                    "Partially inline calls to library functions", false, false)
+
+FunctionPass *llvm::createPartiallyInlineLibCallsPass() {
+  return new PartiallyInlineLibCallsLegacyPass();
+}
diff --git a/contrib/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp b/contrib/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp
new file mode 100644
index 000000000000..2d0cb6fbf211
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/PlaceSafepoints.cpp
@@ -0,0 +1,691 @@
+//===- PlaceSafepoints.cpp - Place GC Safepoints --------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// Place garbage collection safepoints at appropriate locations in the IR. This
+// does not make relocation semantics or variable liveness explicit. That's
+// done by RewriteStatepointsForGC.
+//
+// Terminology:
+// - A call is said to be "parseable" if there is a stack map generated for the
+// return PC of the call. A runtime can determine where values listed in the
+// deopt arguments and (after RewriteStatepointsForGC) gc arguments are located
+// on the stack when the code is suspended inside such a call. Every parse
+// point is represented by a call wrapped in a gc.statepoint intrinsic.
+// - A "poll" is an explicit check in the generated code to determine if the
+// runtime needs the generated code to cooperate by calling a helper routine
+// and thus suspending its execution at a known state. The call to the helper
+// routine will be parseable. The (gc & runtime specific) logic of a poll is
+// assumed to be provided in a function of the name "gc.safepoint_poll".
+//
+// We aim to insert polls such that running code can quickly be brought to a
+// well-defined state for inspection by the collector. In the current
+// implementation, this is done via the insertion of poll sites at method entry
+// and the backedge of most loops. We try to avoid inserting more polls than
+// are necessary to ensure a finite period between poll sites. This is not
+// because the poll itself is expensive in the generated code; it's not. Polls
+// do tend to impact the optimizer itself in negative ways; we'd like to avoid
+// perturbing the optimization of the method as much as we can.
+//
+// We also need to make most call sites parseable. The callee might execute a
+// poll (or otherwise be inspected by the GC). If so, the entire stack
+// (including the suspended frame of the current method) must be parseable.
+//
+// This pass will insert:
+// - Call parse points ("call safepoints") for any call which may need to
+// reach a safepoint during the execution of the callee function.
+// - Backedge safepoint polls and entry safepoint polls to ensure that
+// executing code reaches a safepoint poll in a finite amount of time.
+//
+// We do not currently support return statepoints, but adding them would not
+// be hard. They are not required for correctness - entry safepoints are an
+// alternative - but some GCs may prefer them. Patches welcome.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Pass.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/Statepoint.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+#include "llvm/Transforms/Utils/Local.h"
+
+#define DEBUG_TYPE "safepoint-placement"
+
+STATISTIC(NumEntrySafepoints, "Number of entry safepoints inserted");
+STATISTIC(NumBackedgeSafepoints, "Number of backedge safepoints inserted");
+
+STATISTIC(CallInLoop,
+          "Number of loops without safepoints due to calls in loop");
+STATISTIC(FiniteExecution,
+          "Number of loops without safepoints due to finite execution");
+
+using namespace llvm;
+
+// Ignore opportunities to avoid placing safepoints on backedges, useful for
+// validation.
+static cl::opt<bool> AllBackedges("spp-all-backedges", cl::Hidden,
+                                  cl::init(false));
+
+/// How narrow does the trip count of a loop have to be for the loop to be
+/// considered "counted"? Counted loops do not get safepoints at backedges.
+static cl::opt<int> CountedLoopTripWidth("spp-counted-loop-trip-width",
+                                         cl::Hidden, cl::init(32));
+
+// If true, split the backedge of a loop when placing the safepoint, otherwise
+// split the latch block itself. Both are useful to support for
+// experimentation, but in practice, it looks like splitting the backedge
+// optimizes better.
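+// (Splitting the backedge places the poll in a dedicated block on the
+// latch -> header edge, while splitting the latch keeps the poll in the
+// latch block itself, ahead of its terminator.)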
+static cl::opt<bool> SplitBackedge("spp-split-backedge", cl::Hidden,
+                                   cl::init(false));
+
+namespace {
+
+/// An analysis pass whose purpose is to identify each of the backedges in
+/// the function which require a safepoint poll to be inserted.
+struct PlaceBackedgeSafepointsImpl : public FunctionPass {
+  static char ID;
+
+  /// The output of the pass - a list of the backedges (described by
+  /// pointing at the branch) which need a poll inserted.
+  std::vector<TerminatorInst *> PollLocations;
+
+  /// True unless we're running spp-no-call, in which case we need to disable
+  /// the call-dependent placement opts.
+  bool CallSafepointsEnabled;
+
+  ScalarEvolution *SE = nullptr;
+  DominatorTree *DT = nullptr;
+  LoopInfo *LI = nullptr;
+  TargetLibraryInfo *TLI = nullptr;
+
+  PlaceBackedgeSafepointsImpl(bool CallSafepoints = false)
+      : FunctionPass(ID), CallSafepointsEnabled(CallSafepoints) {
+    initializePlaceBackedgeSafepointsImplPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnLoop(Loop *);
+  void runOnLoopAndSubLoops(Loop *L) {
+    // Visit all the subloops.
+    for (Loop *I : *L)
+      runOnLoopAndSubLoops(I);
+    runOnLoop(L);
+  }
+
+  bool runOnFunction(Function &F) override {
+    SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
+    DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+    TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+    for (Loop *I : *LI) {
+      runOnLoopAndSubLoops(I);
+    }
+    return false;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<ScalarEvolutionWrapperPass>();
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+    // We no longer modify the IR at all in this pass. Thus all
+    // analyses are preserved.
+    AU.setPreservesAll();
+  }
+};
+}
+
+static cl::opt<bool> NoEntry("spp-no-entry", cl::Hidden, cl::init(false));
+static cl::opt<bool> NoCall("spp-no-call", cl::Hidden, cl::init(false));
+static cl::opt<bool> NoBackedge("spp-no-backedge", cl::Hidden, cl::init(false));
+
+namespace {
+struct PlaceSafepoints : public FunctionPass {
+  static char ID; // Pass identification, replacement for typeid
+
+  PlaceSafepoints() : FunctionPass(ID) {
+    initializePlaceSafepointsPass(*PassRegistry::getPassRegistry());
+  }
+  bool runOnFunction(Function &F) override;
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    // We modify the graph wholesale (inlining, block insertion, etc). We
+    // preserve nothing at the moment. We could potentially preserve the dom
+    // tree if that was worth doing.
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+  }
+};
+}
+
+// Insert a safepoint poll immediately before the given instruction. Does
+// not handle the parsability of state at the runtime call; that's the
+// caller's job.
+static void
+InsertSafepointPoll(Instruction *InsertBefore,
+                    std::vector<CallSite> &ParsePointsNeeded /*rval*/,
+                    const TargetLibraryInfo &TLI);
+
+static bool needsStatepoint(const CallSite &CS, const TargetLibraryInfo &TLI) {
+  if (callsGCLeafFunction(CS, TLI))
+    return false;
+  if (CS.isCall()) {
+    CallInst *call = cast<CallInst>(CS.getInstruction());
+    if (call->isInlineAsm())
+      return false;
+  }
+
+  return !(isStatepoint(CS) || isGCRelocate(CS) || isGCResult(CS));
+}
+
+/// Returns true if this loop is known to contain a call safepoint which
+/// must unconditionally execute on any iteration of the loop which returns
+/// to the loop header via an edge from Pred. Returns a conservative correct
+/// answer; i.e. false is always valid.
+static bool containsUnconditionalCallSafepoint(Loop *L, BasicBlock *Header,
+                                               BasicBlock *Pred,
+                                               DominatorTree &DT,
+                                               const TargetLibraryInfo &TLI) {
+  // In general, we're looking for any cut of the graph which ensures
+  // there's a call safepoint along every edge between Header and Pred.
+  // For the moment, we look only for the 'cuts' that consist of a single call
+  // instruction in a block which is dominated by the Header and dominates the
+  // loop latch (Pred) block. Somewhat surprisingly, walking the entire chain
+  // of such dominating blocks gets substantially more occurrences than just
+  // checking the Pred and Header blocks themselves. This may be due to the
+  // density of loop exit conditions caused by range and null checks.
+  // TODO: structure this as an analysis pass, cache the result for subloops,
+  // avoid dom tree recalculations.
+  assert(DT.dominates(Header, Pred) && "loop latch not dominated by header?");
+
+  BasicBlock *Current = Pred;
+  while (true) {
+    for (Instruction &I : *Current) {
+      if (auto CS = CallSite(&I))
+        // Note: Technically, needing a safepoint isn't quite the right
+        // condition here. We should instead be checking if the target method
+        // has an unconditional poll. In practice, this is only a theoretical
+        // concern since we don't have any methods with conditional-only
+        // safepoint polls.
+        if (needsStatepoint(CS, TLI))
+          return true;
+    }
+
+    if (Current == Header)
+      break;
+    Current = DT.getNode(Current)->getIDom()->getBlock();
+  }
+
+  return false;
+}
+
+/// Returns true if this loop is known to terminate in a finite number of
+/// iterations. Note that this function may return false for a loop which
+/// does actually terminate in a finite constant number of iterations due to
+/// conservatism in the analysis.
+static bool mustBeFiniteCountedLoop(Loop *L, ScalarEvolution *SE,
+                                    BasicBlock *Pred) {
+  // A conservative bound on the loop as a whole.
+  const SCEV *MaxTrips = SE->getMaxBackedgeTakenCount(L);
+  if (MaxTrips != SE->getCouldNotCompute() &&
+      SE->getUnsignedRange(MaxTrips).getUnsignedMax().isIntN(
+          CountedLoopTripWidth))
+    return true;
+
+  // If this is a conditional branch to the header with the alternate path
+  // being outside the loop, we can ask questions about the execution frequency
+  // of the exit block.
+  if (L->isLoopExiting(Pred)) {
+    // This returns an exact expression only. TODO: We really only need an
+    // upper bound here, but SE doesn't expose that.
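+    // For example, for "for (i32 i = 0; i != n; ++i)" ScalarEvolution can
+    // usually prove that the exit count fits in 32 bits, so with the default
+    // trip width the loop is treated as counted and needs no backedge poll.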
+    const SCEV *MaxExec = SE->getExitCount(L, Pred);
+    if (MaxExec != SE->getCouldNotCompute() &&
+        SE->getUnsignedRange(MaxExec).getUnsignedMax().isIntN(
+            CountedLoopTripWidth))
+      return true;
+  }
+
+  return /* not finite */ false;
+}
+
+static void scanOneBB(Instruction *Start, Instruction *End,
+                      std::vector<CallInst *> &Calls,
+                      DenseSet<BasicBlock *> &Seen,
+                      std::vector<BasicBlock *> &Worklist) {
+  for (BasicBlock::iterator BBI(Start), BBE0 = Start->getParent()->end(),
+                            BBE1 = BasicBlock::iterator(End);
+       BBI != BBE0 && BBI != BBE1; BBI++) {
+    if (CallInst *CI = dyn_cast<CallInst>(&*BBI))
+      Calls.push_back(CI);
+
+    // FIXME: This code does not handle invokes.
+    assert(!isa<InvokeInst>(&*BBI) &&
+           "support for invokes in poll code needed");
+
+    // Only add the successor blocks if we reach the terminator instruction
+    // without encountering End first.
+    if (BBI->isTerminator()) {
+      BasicBlock *BB = BBI->getParent();
+      for (BasicBlock *Succ : successors(BB)) {
+        if (Seen.insert(Succ).second) {
+          Worklist.push_back(Succ);
+        }
+      }
+    }
+  }
+}
+
+static void scanInlinedCode(Instruction *Start, Instruction *End,
+                            std::vector<CallInst *> &Calls,
+                            DenseSet<BasicBlock *> &Seen) {
+  Calls.clear();
+  std::vector<BasicBlock *> Worklist;
+  Seen.insert(Start->getParent());
+  scanOneBB(Start, End, Calls, Seen, Worklist);
+  while (!Worklist.empty()) {
+    BasicBlock *BB = Worklist.back();
+    Worklist.pop_back();
+    scanOneBB(&*BB->begin(), End, Calls, Seen, Worklist);
+  }
+}
+
+bool PlaceBackedgeSafepointsImpl::runOnLoop(Loop *L) {
+  // Loop through all loop latches (branches controlling backedges). We need
+  // to place a safepoint on every backedge (potentially).
+  // Note: In common usage, there will be only one edge due to LoopSimplify
+  // having run sometime earlier in the pipeline, but this code must be correct
+  // w.r.t. loops with multiple backedges.
+  BasicBlock *Header = L->getHeader();
+  SmallVector<BasicBlock *, 16> LoopLatches;
+  L->getLoopLatches(LoopLatches);
+  for (BasicBlock *Pred : LoopLatches) {
+    assert(L->contains(Pred));
+
+    // Make a policy decision about whether this loop needs a safepoint or
+    // not. Note that this is about unburdening the optimizer in loops, not
+    // avoiding the runtime cost of the actual safepoint.
+    if (!AllBackedges) {
+      if (mustBeFiniteCountedLoop(L, SE, Pred)) {
+        DEBUG(dbgs() << "skipping safepoint placement in finite loop\n");
+        FiniteExecution++;
+        continue;
+      }
+      if (CallSafepointsEnabled &&
+          containsUnconditionalCallSafepoint(L, Header, Pred, *DT, *TLI)) {
+        // Note: This is only semantically legal since we won't do any further
+        // IPO or inlining before the actual call insertion. If we did, we
+        // might later lose this call safepoint.
+        DEBUG(dbgs()
+              << "skipping safepoint placement due to unconditional call\n");
+        CallInLoop++;
+        continue;
+      }
+    }
+
+    // TODO: We can create an inner loop which runs a finite number of
+    // iterations with an outer loop which contains a safepoint. This would
+    // not help runtime performance that much, but it might help our ability to
+    // optimize the inner loop.
+
+    // Safepoint insertion would involve creating a new basic block (as the
+    // target of the current backedge) which does the safepoint (of all live
+    // variables) and branches to the true header
+    TerminatorInst *Term = Pred->getTerminator();
+
+    DEBUG(dbgs() << "[LSP] terminator instruction: " << *Term);
+
+    PollLocations.push_back(Term);
+  }
+
+  return false;
+}
+
+/// Returns true if an entry safepoint is not required before this callsite in
+/// the caller function.
+static bool doesNotRequireEntrySafepointBefore(const CallSite &CS) {
+  Instruction *Inst = CS.getInstruction();
+  if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst)) {
+    switch (II->getIntrinsicID()) {
+    case Intrinsic::experimental_gc_statepoint:
+    case Intrinsic::experimental_patchpoint_void:
+    case Intrinsic::experimental_patchpoint_i64:
+      // These can wrap an actual call which may grow the stack by an
+      // unbounded amount or run forever.
+      return false;
+    default:
+      // Most LLVM intrinsics are things which do not expand to actual calls,
+      // or at least if they do, are leaf functions that cause only finite
+      // stack growth. In particular, the optimizer likes to form things like
+      // memsets out of stores in the original IR. Another important example
+      // is llvm.localescape which must occur in the entry block. Inserting a
+      // safepoint before it is not legal since it could push the localescape
+      // out of the entry block.
+      return true;
+    }
+  }
+  return false;
+}
+
+static Instruction *findLocationForEntrySafepoint(Function &F,
+                                                  DominatorTree &DT) {
+
+  // Conceptually, this poll needs to be on method entry, but in
+  // practice, we place it as late in the entry block as possible. We
+  // can place it as late as we want as long as it dominates all calls
+  // that can grow the stack. This, combined with backedge polls,
+  // gives us all the progress guarantees we need.
+
+  // HasNextInstruction and NextInstruction are used to iterate
+  // through a "straight line" execution sequence.
+
+  auto HasNextInstruction = [](Instruction *I) {
+    if (!I->isTerminator())
+      return true;
+
+    BasicBlock *nextBB = I->getParent()->getUniqueSuccessor();
+    return nextBB && (nextBB->getUniquePredecessor() != nullptr);
+  };
+
+  auto NextInstruction = [&](Instruction *I) {
+    assert(HasNextInstruction(I) &&
+           "first check if there is a next instruction!");
+
+    if (I->isTerminator())
+      return &I->getParent()->getUniqueSuccessor()->front();
+    return &*++I->getIterator();
+  };
+
+  Instruction *Cursor = nullptr;
+  for (Cursor = &F.getEntryBlock().front(); HasNextInstruction(Cursor);
+       Cursor = NextInstruction(Cursor)) {
+
+    // We need to ensure a safepoint poll occurs before any 'real' call. The
+    // easiest way to ensure finite execution between safepoints in the face of
+    // recursive and mutually recursive functions is to enforce that each take
+    // a safepoint. Additionally, we need to ensure a poll before any call
+    // which can grow the stack by an unbounded amount. This isn't required
+    // for GC semantics per se, but is a common requirement for languages
+    // which detect stack overflow via guard pages and then throw exceptions.
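+    // For example (illustrative IR): given an entry block
+    //   %x = add i32 %a, 1
+    //   call void @foo()
+    // the cursor advances past the add and stops at the call to @foo, so the
+    // poll inserted before the cursor dominates every call that could grow
+    // the stack.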
+    if (auto CS = CallSite(Cursor)) {
+      if (doesNotRequireEntrySafepointBefore(CS))
+        continue;
+      break;
+    }
+  }
+
+  assert((HasNextInstruction(Cursor) || Cursor->isTerminator()) &&
+         "either we stopped because of a call, or because of terminator");
+
+  return Cursor;
+}
+
+static const char *const GCSafepointPollName = "gc.safepoint_poll";
+
+static bool isGCSafepointPoll(Function &F) {
+  return F.getName().equals(GCSafepointPollName);
+}
+
+/// Returns true if this function should be rewritten to include safepoint
+/// polls and parseable call sites. The main point of this function is to be
+/// an extension point for custom logic.
+static bool shouldRewriteFunction(Function &F) {
+  // TODO: This should check the GCStrategy
+  if (F.hasGC()) {
+    const auto &FunctionGCName = F.getGC();
+    const StringRef StatepointExampleName("statepoint-example");
+    const StringRef CoreCLRName("coreclr");
+    return (StatepointExampleName == FunctionGCName) ||
+           (CoreCLRName == FunctionGCName);
+  } else
+    return false;
+}
+
+// TODO: These should become properties of the GCStrategy, possibly with
+// command line overrides.
+static bool enableEntrySafepoints(Function &F) { return !NoEntry; }
+static bool enableBackedgeSafepoints(Function &F) { return !NoBackedge; }
+static bool enableCallSafepoints(Function &F) { return !NoCall; }
+
+bool PlaceSafepoints::runOnFunction(Function &F) {
+  if (F.isDeclaration() || F.empty()) {
+    // Nothing to do for declarations or empty functions. Must exit early to
+    // avoid a crash in dom tree calculation.
+    return false;
+  }
+
+  if (isGCSafepointPoll(F)) {
+    // Given we're inlining this inside of safepoint poll insertion, this
+    // doesn't make any sense. Note that the calls contained in the poll are
+    // made parseable at each site where the poll is inlined.
+    return false;
+  }
+
+  if (!shouldRewriteFunction(F))
+    return false;
+
+  const TargetLibraryInfo &TLI =
+      getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
+
+  bool Modified = false;
+
+  // In various bits below, we rely on the fact that uses are reachable from
+  // defs. When there are basic blocks unreachable from the entry, dominance
+  // and reachability queries return nonsensical results. Thus, we preprocess
+  // the function to ensure these properties hold.
+  Modified |= removeUnreachableBlocks(F);
+
+  // STEP 1 - Insert the safepoint polling locations. We do not need to
+  // actually insert parse points yet. That will be done for all polls and
+  // calls in a single pass.
+
+  DominatorTree DT;
+  DT.recalculate(F);
+
+  SmallVector<Instruction *, 16> PollsNeeded;
+  std::vector<CallSite> ParsePointNeeded;
+
+  if (enableBackedgeSafepoints(F)) {
+    // Construct a pass manager to run the LoopPass backedge logic. We
+    // need the pass manager to handle scheduling all the loop passes
+    // appropriately. Doing this by hand is painful and just not worth messing
+    // with for the moment.
+    legacy::FunctionPassManager FPM(F.getParent());
+    bool CanAssumeCallSafepoints = enableCallSafepoints(F);
+    auto *PBS = new PlaceBackedgeSafepointsImpl(CanAssumeCallSafepoints);
+    FPM.add(PBS);
+    FPM.run(F);
+
+    // We preserve dominance information when inserting the poll; otherwise
+    // we'd have to recalculate it on every insert.
+    DT.recalculate(F);
+
+    auto &PollLocations = PBS->PollLocations;
+
+    auto OrderByBBName = [](Instruction *a, Instruction *b) {
+      return a->getParent()->getName() < b->getParent()->getName();
+    };
+    // We need the order of the list to be stable so that naming ends up
+    // stable when we split edges. This makes test cases much easier to write.
+    std::sort(PollLocations.begin(), PollLocations.end(), OrderByBBName);
+
+    // We can sometimes end up with duplicate poll locations. This happens if
+    // a single loop is visited more than once. The fact this happens seems
+    // wrong, but it does happen for the split-backedge.ll test case.
+    PollLocations.erase(std::unique(PollLocations.begin(),
+                                    PollLocations.end()),
+                        PollLocations.end());
+
+    // Insert a poll at each point the analysis pass identified. The poll
+    // location must be the terminator of a loop latch block.
+    for (TerminatorInst *Term : PollLocations) {
+      // We are inserting a poll; the function is modified.
+      Modified = true;
+
+      if (SplitBackedge) {
+        // Split the backedge of the loop and insert the poll within that new
+        // basic block. This creates a loop with two latches per original
+        // latch (which is non-ideal), but this appears to be easier to
+        // optimize in practice than inserting the poll immediately before the
+        // latch test.
+
+        // Since this is a latch, at least one of the successors must dominate
+        // it. It's possible that we have a) duplicate edges to the same header
+        // and b) edges to distinct loop headers. We need to insert polls on
+        // each.
+        SetVector<BasicBlock *> Headers;
+        for (unsigned i = 0; i < Term->getNumSuccessors(); i++) {
+          BasicBlock *Succ = Term->getSuccessor(i);
+          if (DT.dominates(Succ, Term->getParent())) {
+            Headers.insert(Succ);
+          }
+        }
+        assert(!Headers.empty() && "poll location is not a loop latch?");
+
+        // The split loop structure here is so that we only need to recalculate
+        // the dominator tree once. Alternatively, we could just keep it up to
+        // date and use a more natural merged loop.
+        SetVector<BasicBlock *> SplitBackedges;
+        for (BasicBlock *Header : Headers) {
+          BasicBlock *NewBB = SplitEdge(Term->getParent(), Header, &DT);
+          PollsNeeded.push_back(NewBB->getTerminator());
+          NumBackedgeSafepoints++;
+        }
+      } else {
+        // Split the latch block itself, right before the terminator.
+        PollsNeeded.push_back(Term);
+        NumBackedgeSafepoints++;
+      }
+    }
+  }
+
+  if (enableEntrySafepoints(F)) {
+    if (Instruction *Location = findLocationForEntrySafepoint(F, DT)) {
+      PollsNeeded.push_back(Location);
+      Modified = true;
+      NumEntrySafepoints++;
+    }
+    // TODO: else we should assert that there was, in fact, a policy choice to
+    // not insert an entry safepoint poll.
+  }
+
+  // Now that we've identified all the needed safepoint poll locations, insert
+  // safepoint polls themselves.
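+  // For reference, a minimal illustrative gc.safepoint_poll implementation a
+  // frontend might provide (names here are hypothetical); the slow-path call
+  // inside it is what becomes a parse point below:
+  //   define void @gc.safepoint_poll() {
+  //     %flag = load volatile i8, i8* @poll_flag
+  //     %do.poll = icmp ne i8 %flag, 0
+  //     br i1 %do.poll, label %slow, label %done
+  //   slow:
+  //     call void @do_safepoint()
+  //     br label %done
+  //   done:
+  //     ret void
+  //   }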
+  for (Instruction *PollLocation : PollsNeeded) {
+    std::vector<CallSite> RuntimeCalls;
+    InsertSafepointPoll(PollLocation, RuntimeCalls, TLI);
+    ParsePointNeeded.insert(ParsePointNeeded.end(), RuntimeCalls.begin(),
+                            RuntimeCalls.end());
+  }
+
+  return Modified;
+}
+
+char PlaceBackedgeSafepointsImpl::ID = 0;
+char PlaceSafepoints::ID = 0;
+
+FunctionPass *llvm::createPlaceSafepointsPass() {
+  return new PlaceSafepoints();
+}
+
+INITIALIZE_PASS_BEGIN(PlaceBackedgeSafepointsImpl,
+                      "place-backedge-safepoints-impl",
+                      "Place Backedge Safepoints", false, false)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_END(PlaceBackedgeSafepointsImpl,
+                    "place-backedge-safepoints-impl",
+                    "Place Backedge Safepoints", false, false)
+
+INITIALIZE_PASS_BEGIN(PlaceSafepoints, "place-safepoints", "Place Safepoints",
+                      false, false)
+INITIALIZE_PASS_END(PlaceSafepoints, "place-safepoints", "Place Safepoints",
+                    false, false)
+
+static void
+InsertSafepointPoll(Instruction *InsertBefore,
+                    std::vector<CallSite> &ParsePointsNeeded /*rval*/,
+                    const TargetLibraryInfo &TLI) {
+  BasicBlock *OrigBB = InsertBefore->getParent();
+  Module *M = InsertBefore->getModule();
+  assert(M && "must be part of a module");
+
+  // Inline the safepoint poll implementation - this will get all the branch,
+  // control flow, etc. Most importantly, it will introduce the actual slow
+  // path call - where we need to insert a safepoint (parsepoint).
+
+  auto *F = M->getFunction(GCSafepointPollName);
+  assert(F && "gc.safepoint_poll function is missing");
+  assert(F->getValueType() ==
+             FunctionType::get(Type::getVoidTy(M->getContext()), false) &&
+         "gc.safepoint_poll declared with wrong type");
+  assert(!F->empty() && "gc.safepoint_poll must be a non-empty function");
+  CallInst *PollCall = CallInst::Create(F, "", InsertBefore);
+
+  // Record some information about the call site we're replacing
+  BasicBlock::iterator Before(PollCall), After(PollCall);
+  bool IsBegin = false;
+  if (Before == OrigBB->begin())
+    IsBegin = true;
+  else
+    Before--;
+
+  After++;
+  assert(After != OrigBB->end() && "must have successor");
+
+  // Do the actual inlining
+  InlineFunctionInfo IFI;
+  bool InlineStatus = InlineFunction(PollCall, IFI);
+  assert(InlineStatus && "inline must succeed");
+  (void)InlineStatus; // suppress warning in release-asserts
+
+  // Check post-conditions
+  assert(IFI.StaticAllocas.empty() && "can't have allocs");
+
+  std::vector<CallInst *> Calls; // new calls
+  DenseSet<BasicBlock *> BBs;    // new BBs + insertee
+
+  // Include only the newly inserted instructions. Note: Begin may not be
+  // valid if we inserted at the beginning of the basic block.
+  BasicBlock::iterator Start = IsBegin ? OrigBB->begin() : std::next(Before);
+
+  // If your poll function includes an unreachable at the end, that's not
+  // valid. Bugpoint likes to create this, so check for it.
+  assert(isPotentiallyReachable(&*Start, &*After) &&
+         "malformed poll function");
+
+  scanInlinedCode(&*Start, &*After, Calls, BBs);
+  assert(!Calls.empty() && "slow path not found for safepoint poll");
+
+  // Record the fact we need a parsable state at the runtime call contained in
+  // the poll function. This is required so that the runtime knows how to
+  // parse the last frame when we actually take the safepoint (i.e.
execute + // the slow path) + assert(ParsePointsNeeded.empty()); + for (auto *CI : Calls) { + // No safepoint needed or wanted + if (!needsStatepoint(CI, TLI)) + continue; + + // These are likely runtime calls. Should we assert that via calling + // convention or something? + ParsePointsNeeded.push_back(CallSite(CI)); + } + assert(ParsePointsNeeded.size() <= Calls.size()); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/Reassociate.cpp b/contrib/llvm/lib/Transforms/Scalar/Reassociate.cpp new file mode 100644 index 000000000000..88dcaf0f8a36 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/Reassociate.cpp @@ -0,0 +1,2401 @@ +//===- Reassociate.cpp - Reassociate binary expressions -------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass reassociates commutative expressions in an order that is designed +// to promote better constant propagation, GCSE, LICM, PRE, etc. +// +// For example: 4 + (x + 5) -> x + (4 + 5) +// +// In the implementation of this algorithm, constants are assigned rank = 0, +// function arguments are rank = 1, and other values are assigned ranks +// corresponding to the reverse post order traversal of current function +// (starting at 2), which effectively gives values in deep loops higher rank +// than values not in loops. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/Reassociate.h" +#include "llvm/ADT/APFloat.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include <algorithm> +#include <cassert> +#include <utility> + +using namespace llvm; +using namespace reassociate; + +#define DEBUG_TYPE "reassociate" + +STATISTIC(NumChanged, "Number of insts reassociated"); +STATISTIC(NumAnnihil, "Number of expr tree annihilated"); +STATISTIC(NumFactor , "Number of multiplies factored"); + +#ifndef NDEBUG +/// Print out the expression identified in the Ops list. 
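+/// For example (illustrative): prints "mul i32  [ %x, #3] [ %y, #4] " for a
+/// two-operand multiply whose operands have ranks 3 and 4.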
+static void PrintOps(Instruction *I, const SmallVectorImpl<ValueEntry> &Ops) { + Module *M = I->getModule(); + dbgs() << Instruction::getOpcodeName(I->getOpcode()) << " " + << *Ops[0].Op->getType() << '\t'; + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + dbgs() << "[ "; + Ops[i].Op->printAsOperand(dbgs(), false, M); + dbgs() << ", #" << Ops[i].Rank << "] "; + } +} +#endif + +/// Utility class representing a non-constant Xor-operand. We classify +/// non-constant Xor-Operands into two categories: +/// C1) The operand is in the form "X & C", where C is a constant and C != ~0 +/// C2) +/// C2.1) The operand is in the form of "X | C", where C is a non-zero +/// constant. +/// C2.2) Any operand E which doesn't fall into C1 and C2.1, we view this +/// operand as "E | 0" +class llvm::reassociate::XorOpnd { +public: + XorOpnd(Value *V); + + bool isInvalid() const { return SymbolicPart == nullptr; } + bool isOrExpr() const { return isOr; } + Value *getValue() const { return OrigVal; } + Value *getSymbolicPart() const { return SymbolicPart; } + unsigned getSymbolicRank() const { return SymbolicRank; } + const APInt &getConstPart() const { return ConstPart; } + + void Invalidate() { SymbolicPart = OrigVal = nullptr; } + void setSymbolicRank(unsigned R) { SymbolicRank = R; } + +private: + Value *OrigVal; + Value *SymbolicPart; + APInt ConstPart; + unsigned SymbolicRank; + bool isOr; +}; + +XorOpnd::XorOpnd(Value *V) { + assert(!isa<ConstantInt>(V) && "No ConstantInt"); + OrigVal = V; + Instruction *I = dyn_cast<Instruction>(V); + SymbolicRank = 0; + + if (I && (I->getOpcode() == Instruction::Or || + I->getOpcode() == Instruction::And)) { + Value *V0 = I->getOperand(0); + Value *V1 = I->getOperand(1); + const APInt *C; + if (match(V0, PatternMatch::m_APInt(C))) + std::swap(V0, V1); + + if (match(V1, PatternMatch::m_APInt(C))) { + ConstPart = *C; + SymbolicPart = V0; + isOr = (I->getOpcode() == Instruction::Or); + return; + } + } + + // view the operand as "V | 0" + SymbolicPart = V; + ConstPart = APInt::getNullValue(V->getType()->getScalarSizeInBits()); + isOr = true; +} + +/// Return true if V is an instruction of the specified opcode and if it +/// only has one use. +static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode) { + auto *I = dyn_cast<Instruction>(V); + if (I && I->hasOneUse() && I->getOpcode() == Opcode) + if (!isa<FPMathOperator>(I) || I->isFast()) + return cast<BinaryOperator>(I); + return nullptr; +} + +static BinaryOperator *isReassociableOp(Value *V, unsigned Opcode1, + unsigned Opcode2) { + auto *I = dyn_cast<Instruction>(V); + if (I && I->hasOneUse() && + (I->getOpcode() == Opcode1 || I->getOpcode() == Opcode2)) + if (!isa<FPMathOperator>(I) || I->isFast()) + return cast<BinaryOperator>(I); + return nullptr; +} + +void ReassociatePass::BuildRankMap(Function &F, + ReversePostOrderTraversal<Function*> &RPOT) { + unsigned Rank = 2; + + // Assign distinct ranks to function arguments. + for (auto &Arg : F.args()) { + ValueRankMap[&Arg] = ++Rank; + DEBUG(dbgs() << "Calculated Rank[" << Arg.getName() << "] = " << Rank + << "\n"); + } + + // Traverse basic blocks in ReversePostOrder + for (BasicBlock *BB : RPOT) { + unsigned BBRank = RankMap[BB] = ++Rank << 16; + + // Walk the basic block, adding precomputed ranks for any instructions that + // we cannot move. This ensures that the ranks for these instructions are + // all different in the block. 
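+  // Illustrative numbering: with one function argument and two blocks visited
+  // in RPO, the argument gets rank 3, the first block base rank 4<<16, the
+  // second 5<<16, and unmovable instructions get 4<<16 + 1, 4<<16 + 2, ...
+  // within their block, so values in later blocks always outrank earlier ones.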
+ for (Instruction &I : *BB) + if (mayBeMemoryDependent(I)) + ValueRankMap[&I] = ++BBRank; + } +} + +unsigned ReassociatePass::getRank(Value *V) { + Instruction *I = dyn_cast<Instruction>(V); + if (!I) { + if (isa<Argument>(V)) return ValueRankMap[V]; // Function argument. + return 0; // Otherwise it's a global or constant, rank 0. + } + + if (unsigned Rank = ValueRankMap[I]) + return Rank; // Rank already known? + + // If this is an expression, return the 1+MAX(rank(LHS), rank(RHS)) so that + // we can reassociate expressions for code motion! Since we do not recurse + // for PHI nodes, we cannot have infinite recursion here, because there + // cannot be loops in the value graph that do not go through PHI nodes. + unsigned Rank = 0, MaxRank = RankMap[I->getParent()]; + for (unsigned i = 0, e = I->getNumOperands(); + i != e && Rank != MaxRank; ++i) + Rank = std::max(Rank, getRank(I->getOperand(i))); + + // If this is a not or neg instruction, do not count it for rank. This + // assures us that X and ~X will have the same rank. + if (!BinaryOperator::isNot(I) && !BinaryOperator::isNeg(I) && + !BinaryOperator::isFNeg(I)) + ++Rank; + + DEBUG(dbgs() << "Calculated Rank[" << V->getName() << "] = " << Rank << "\n"); + + return ValueRankMap[I] = Rank; +} + +// Canonicalize constants to RHS. Otherwise, sort the operands by rank. +void ReassociatePass::canonicalizeOperands(Instruction *I) { + assert(isa<BinaryOperator>(I) && "Expected binary operator."); + assert(I->isCommutative() && "Expected commutative operator."); + + Value *LHS = I->getOperand(0); + Value *RHS = I->getOperand(1); + if (LHS == RHS || isa<Constant>(RHS)) + return; + if (isa<Constant>(LHS) || getRank(RHS) < getRank(LHS)) + cast<BinaryOperator>(I)->swapOperands(); +} + +static BinaryOperator *CreateAdd(Value *S1, Value *S2, const Twine &Name, + Instruction *InsertBefore, Value *FlagsOp) { + if (S1->getType()->isIntOrIntVectorTy()) + return BinaryOperator::CreateAdd(S1, S2, Name, InsertBefore); + else { + BinaryOperator *Res = + BinaryOperator::CreateFAdd(S1, S2, Name, InsertBefore); + Res->setFastMathFlags(cast<FPMathOperator>(FlagsOp)->getFastMathFlags()); + return Res; + } +} + +static BinaryOperator *CreateMul(Value *S1, Value *S2, const Twine &Name, + Instruction *InsertBefore, Value *FlagsOp) { + if (S1->getType()->isIntOrIntVectorTy()) + return BinaryOperator::CreateMul(S1, S2, Name, InsertBefore); + else { + BinaryOperator *Res = + BinaryOperator::CreateFMul(S1, S2, Name, InsertBefore); + Res->setFastMathFlags(cast<FPMathOperator>(FlagsOp)->getFastMathFlags()); + return Res; + } +} + +static BinaryOperator *CreateNeg(Value *S1, const Twine &Name, + Instruction *InsertBefore, Value *FlagsOp) { + if (S1->getType()->isIntOrIntVectorTy()) + return BinaryOperator::CreateNeg(S1, Name, InsertBefore); + else { + BinaryOperator *Res = BinaryOperator::CreateFNeg(S1, Name, InsertBefore); + Res->setFastMathFlags(cast<FPMathOperator>(FlagsOp)->getFastMathFlags()); + return Res; + } +} + +/// Replace 0-X with X*-1. +static BinaryOperator *LowerNegateToMultiply(Instruction *Neg) { + Type *Ty = Neg->getType(); + Constant *NegOne = Ty->isIntOrIntVectorTy() ? + ConstantInt::getAllOnesValue(Ty) : ConstantFP::get(Ty, -1.0); + + BinaryOperator *Res = CreateMul(Neg->getOperand(1), NegOne, "", Neg, Neg); + Neg->setOperand(1, Constant::getNullValue(Ty)); // Drop use of op. 
+ Res->takeName(Neg); + Neg->replaceAllUsesWith(Res); + Res->setDebugLoc(Neg->getDebugLoc()); + return Res; +} + +/// Returns k such that lambda(2^Bitwidth) = 2^k, where lambda is the Carmichael +/// function. This means that x^(2^k) === 1 mod 2^Bitwidth for +/// every odd x, i.e. x^(2^k) = 1 for every odd x in Bitwidth-bit arithmetic. +/// Note that 0 <= k < Bitwidth, and if Bitwidth > 3 then x^(2^k) = 0 for every +/// even x in Bitwidth-bit arithmetic. +static unsigned CarmichaelShift(unsigned Bitwidth) { + if (Bitwidth < 3) + return Bitwidth - 1; + return Bitwidth - 2; +} + +/// Add the extra weight 'RHS' to the existing weight 'LHS', +/// reducing the combined weight using any special properties of the operation. +/// The existing weight LHS represents the computation X op X op ... op X where +/// X occurs LHS times. The combined weight represents X op X op ... op X with +/// X occurring LHS + RHS times. If op is "Xor" for example then the combined +/// operation is equivalent to X if LHS + RHS is odd, or 0 if LHS + RHS is even; +/// the routine returns 1 in LHS in the first case, and 0 in LHS in the second. +static void IncorporateWeight(APInt &LHS, const APInt &RHS, unsigned Opcode) { + // If we were working with infinite precision arithmetic then the combined + // weight would be LHS + RHS. But we are using finite precision arithmetic, + // and the APInt sum LHS + RHS may not be correct if it wraps (it is correct + // for nilpotent operations and addition, but not for idempotent operations + // and multiplication), so it is important to correctly reduce the combined + // weight back into range if wrapping would be wrong. + + // If RHS is zero then the weight didn't change. + if (RHS.isMinValue()) + return; + // If LHS is zero then the combined weight is RHS. + if (LHS.isMinValue()) { + LHS = RHS; + return; + } + // From this point on we know that neither LHS nor RHS is zero. + + if (Instruction::isIdempotent(Opcode)) { + // Idempotent means X op X === X, so any non-zero weight is equivalent to a + // weight of 1. Keeping weights at zero or one also means that wrapping is + // not a problem. + assert(LHS == 1 && RHS == 1 && "Weights not reduced!"); + return; // Return a weight of 1. + } + if (Instruction::isNilpotent(Opcode)) { + // Nilpotent means X op X === 0, so reduce weights modulo 2. + assert(LHS == 1 && RHS == 1 && "Weights not reduced!"); + LHS = 0; // 1 + 1 === 0 modulo 2. + return; + } + if (Opcode == Instruction::Add || Opcode == Instruction::FAdd) { + // TODO: Reduce the weight by exploiting nsw/nuw? + LHS += RHS; + return; + } + + assert((Opcode == Instruction::Mul || Opcode == Instruction::FMul) && + "Unknown associative operation!"); + unsigned Bitwidth = LHS.getBitWidth(); + // If CM is the Carmichael number then a weight W satisfying W >= CM+Bitwidth + // can be replaced with W-CM. That's because x^W=x^(W-CM) for every Bitwidth + // bit number x, since either x is odd in which case x^CM = 1, or x is even in + // which case both x^W and x^(W - CM) are zero. By subtracting off multiples + // of CM like this weights can always be reduced to the range [0, CM+Bitwidth) + // which by a happy accident means that they can always be represented using + // Bitwidth bits. + // TODO: Reduce the weight by exploiting nsw/nuw? (Could do much better than + // the Carmichael number). + if (Bitwidth > 3) { + /// CM - The value of Carmichael's lambda function. 
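+    // Worked example (illustrative): for Bitwidth == 32 we get CM == 2^30 and
+    // Threshold == 2^30 + 32, so a weight of 2^30 + 40 reduces to 40; this is
+    // sound because x^(2^30) == 1 (mod 2^32) for every odd x, and both sides
+    // are 0 for even x.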
+ APInt CM = APInt::getOneBitSet(Bitwidth, CarmichaelShift(Bitwidth)); + // Any weight W >= Threshold can be replaced with W - CM. + APInt Threshold = CM + Bitwidth; + assert(LHS.ult(Threshold) && RHS.ult(Threshold) && "Weights not reduced!"); + // For Bitwidth 4 or more the following sum does not overflow. + LHS += RHS; + while (LHS.uge(Threshold)) + LHS -= CM; + } else { + // To avoid problems with overflow do everything the same as above but using + // a larger type. + unsigned CM = 1U << CarmichaelShift(Bitwidth); + unsigned Threshold = CM + Bitwidth; + assert(LHS.getZExtValue() < Threshold && RHS.getZExtValue() < Threshold && + "Weights not reduced!"); + unsigned Total = LHS.getZExtValue() + RHS.getZExtValue(); + while (Total >= Threshold) + Total -= CM; + LHS = Total; + } +} + +using RepeatedValue = std::pair<Value*, APInt>; + +/// Given an associative binary expression, return the leaf +/// nodes in Ops along with their weights (how many times the leaf occurs). The +/// original expression is the same as +/// (Ops[0].first op Ops[0].first op ... Ops[0].first) <- Ops[0].second times +/// op +/// (Ops[1].first op Ops[1].first op ... Ops[1].first) <- Ops[1].second times +/// op +/// ... +/// op +/// (Ops[N].first op Ops[N].first op ... Ops[N].first) <- Ops[N].second times +/// +/// Note that the values Ops[0].first, ..., Ops[N].first are all distinct. +/// +/// This routine may modify the function, in which case it returns 'true'. The +/// changes it makes may well be destructive, changing the value computed by 'I' +/// to something completely different. Thus if the routine returns 'true' then +/// you MUST either replace I with a new expression computed from the Ops array, +/// or use RewriteExprTree to put the values back in. +/// +/// A leaf node is either not a binary operation of the same kind as the root +/// node 'I' (i.e. is not a binary operator at all, or is, but with a different +/// opcode), or is the same kind of binary operator but has a use which either +/// does not belong to the expression, or does belong to the expression but is +/// a leaf node. Every leaf node has at least one use that is a non-leaf node +/// of the expression, while for non-leaf nodes (except for the root 'I') every +/// use is a non-leaf node of the expression. +/// +/// For example: +/// expression graph node names +/// +/// + | I +/// / \ | +/// + + | A, B +/// / \ / \ | +/// * + * | C, D, E +/// / \ / \ / \ | +/// + * | F, G +/// +/// The leaf nodes are C, E, F and G. The Ops array will contain (maybe not in +/// that order) (C, 1), (E, 1), (F, 2), (G, 2). +/// +/// The expression is maximal: if some instruction is a binary operator of the +/// same kind as 'I', and all of its uses are non-leaf nodes of the expression, +/// then the instruction also belongs to the expression, is not a leaf node of +/// it, and its operands also belong to the expression (but may be leaf nodes). +/// +/// NOTE: This routine will set operands of non-leaf non-root nodes to undef in +/// order to ensure that every non-root node in the expression has *exactly one* +/// use by a non-leaf node of the expression. This destruction means that the +/// caller MUST either replace 'I' with a new expression or use something like +/// RewriteExprTree to put the values back in if the routine indicates that it +/// made a change by returning 'true'. +/// +/// In the above example either the right operand of A or the left operand of B +/// will be replaced by undef. 
If it is B's operand then this gives: +/// +/// + | I +/// / \ | +/// + + | A, B - operand of B replaced with undef +/// / \ \ | +/// * + * | C, D, E +/// / \ / \ / \ | +/// + * | F, G +/// +/// Note that such undef operands can only be reached by passing through 'I'. +/// For example, if you visit operands recursively starting from a leaf node +/// then you will never see such an undef operand unless you get back to 'I', +/// which requires passing through a phi node. +/// +/// Note that this routine may also mutate binary operators of the wrong type +/// that have all uses inside the expression (i.e. only used by non-leaf nodes +/// of the expression) if it can turn them into binary operators of the right +/// type and thus make the expression bigger. +static bool LinearizeExprTree(BinaryOperator *I, + SmallVectorImpl<RepeatedValue> &Ops) { + DEBUG(dbgs() << "LINEARIZE: " << *I << '\n'); + unsigned Bitwidth = I->getType()->getScalarType()->getPrimitiveSizeInBits(); + unsigned Opcode = I->getOpcode(); + assert(I->isAssociative() && I->isCommutative() && + "Expected an associative and commutative operation!"); + + // Visit all operands of the expression, keeping track of their weight (the + // number of paths from the expression root to the operand, or if you like + // the number of times that operand occurs in the linearized expression). + // For example, if I = X + A, where X = A + B, then I, X and B have weight 1 + // while A has weight two. + + // Worklist of non-leaf nodes (their operands are in the expression too) along + // with their weights, representing a certain number of paths to the operator. + // If an operator occurs in the worklist multiple times then we found multiple + // ways to get to it. + SmallVector<std::pair<BinaryOperator*, APInt>, 8> Worklist; // (Op, Weight) + Worklist.push_back(std::make_pair(I, APInt(Bitwidth, 1))); + bool Changed = false; + + // Leaves of the expression are values that either aren't the right kind of + // operation (eg: a constant, or a multiply in an add tree), or are, but have + // some uses that are not inside the expression. For example, in I = X + X, + // X = A + B, the value X has two uses (by I) that are in the expression. If + // X has any other uses, for example in a return instruction, then we consider + // X to be a leaf, and won't analyze it further. When we first visit a value, + // if it has more than one use then at first we conservatively consider it to + // be a leaf. Later, as the expression is explored, we may discover some more + // uses of the value from inside the expression. If all uses turn out to be + // from within the expression (and the value is a binary operator of the right + // kind) then the value is no longer considered to be a leaf, and its operands + // are explored. + + // Leaves - Keeps track of the set of putative leaves as well as the number of + // paths to each leaf seen so far. + using LeafMap = DenseMap<Value *, APInt>; + LeafMap Leaves; // Leaf -> Total weight so far. + SmallVector<Value *, 8> LeafOrder; // Ensure deterministic leaf output order. + +#ifndef NDEBUG + SmallPtrSet<Value *, 8> Visited; // For sanity checking the iteration scheme. +#endif + while (!Worklist.empty()) { + std::pair<BinaryOperator*, APInt> P = Worklist.pop_back_val(); + I = P.first; // We examine the operands of this binary operator. + + for (unsigned OpIdx = 0; OpIdx < 2; ++OpIdx) { // Visit operands. + Value *Op = I->getOperand(OpIdx); + APInt Weight = P.second; // Number of paths to this operand. 
+ DEBUG(dbgs() << "OPERAND: " << *Op << " (" << Weight << ")\n"); + assert(!Op->use_empty() && "No uses, so how did we get to it?!"); + + // If this is a binary operation of the right kind with only one use then + // add its operands to the expression. + if (BinaryOperator *BO = isReassociableOp(Op, Opcode)) { + assert(Visited.insert(Op).second && "Not first visit!"); + DEBUG(dbgs() << "DIRECT ADD: " << *Op << " (" << Weight << ")\n"); + Worklist.push_back(std::make_pair(BO, Weight)); + continue; + } + + // Appears to be a leaf. Is the operand already in the set of leaves? + LeafMap::iterator It = Leaves.find(Op); + if (It == Leaves.end()) { + // Not in the leaf map. Must be the first time we saw this operand. + assert(Visited.insert(Op).second && "Not first visit!"); + if (!Op->hasOneUse()) { + // This value has uses not accounted for by the expression, so it is + // not safe to modify. Mark it as being a leaf. + DEBUG(dbgs() << "ADD USES LEAF: " << *Op << " (" << Weight << ")\n"); + LeafOrder.push_back(Op); + Leaves[Op] = Weight; + continue; + } + // No uses outside the expression, try morphing it. + } else { + // Already in the leaf map. + assert(It != Leaves.end() && Visited.count(Op) && + "In leaf map but not visited!"); + + // Update the number of paths to the leaf. + IncorporateWeight(It->second, Weight, Opcode); + +#if 0 // TODO: Re-enable once PR13021 is fixed. + // The leaf already has one use from inside the expression. As we want + // exactly one such use, drop this new use of the leaf. + assert(!Op->hasOneUse() && "Only one use, but we got here twice!"); + I->setOperand(OpIdx, UndefValue::get(I->getType())); + Changed = true; + + // If the leaf is a binary operation of the right kind and we now see + // that its multiple original uses were in fact all by nodes belonging + // to the expression, then no longer consider it to be a leaf and add + // its operands to the expression. + if (BinaryOperator *BO = isReassociableOp(Op, Opcode)) { + DEBUG(dbgs() << "UNLEAF: " << *Op << " (" << It->second << ")\n"); + Worklist.push_back(std::make_pair(BO, It->second)); + Leaves.erase(It); + continue; + } +#endif + + // If we still have uses that are not accounted for by the expression + // then it is not safe to modify the value. + if (!Op->hasOneUse()) + continue; + + // No uses outside the expression, try morphing it. + Weight = It->second; + Leaves.erase(It); // Since the value may be morphed below. + } + + // At this point we have a value which, first of all, is not a binary + // expression of the right kind, and secondly, is only used inside the + // expression. This means that it can safely be modified. See if we + // can usefully morph it into an expression of the right kind. + assert((!isa<Instruction>(Op) || + cast<Instruction>(Op)->getOpcode() != Opcode + || (isa<FPMathOperator>(Op) && + !cast<Instruction>(Op)->isFast())) && + "Should have been handled above!"); + assert(Op->hasOneUse() && "Has uses outside the expression tree!"); + + // If this is a multiply expression, turn any internal negations into + // multiplies by -1 so they can be reassociated. 
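+      // For example (illustrative): in 'x * (0 - y) * z' the single-use leaf
+      // '0 - y' is rewritten to 'y * -1', so the -1 joins the multiply tree
+      // as an ordinary reassociable operand.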
+      if (BinaryOperator *BO = dyn_cast<BinaryOperator>(Op))
+        if ((Opcode == Instruction::Mul && BinaryOperator::isNeg(BO)) ||
+            (Opcode == Instruction::FMul && BinaryOperator::isFNeg(BO))) {
+          DEBUG(dbgs() << "MORPH LEAF: " << *Op << " (" << Weight << ") TO ");
+          BO = LowerNegateToMultiply(BO);
+          DEBUG(dbgs() << *BO << '\n');
+          Worklist.push_back(std::make_pair(BO, Weight));
+          Changed = true;
+          continue;
+        }
+
+      // Failed to morph into an expression of the right type. This really is
+      // a leaf.
+      DEBUG(dbgs() << "ADD LEAF: " << *Op << " (" << Weight << ")\n");
+      assert(!isReassociableOp(Op, Opcode) && "Value was morphed?");
+      LeafOrder.push_back(Op);
+      Leaves[Op] = Weight;
+    }
+  }
+
+  // The leaves, repeated according to their weights, represent the linearized
+  // form of the expression.
+  for (unsigned i = 0, e = LeafOrder.size(); i != e; ++i) {
+    Value *V = LeafOrder[i];
+    LeafMap::iterator It = Leaves.find(V);
+    if (It == Leaves.end())
+      // Node initially thought to be a leaf wasn't.
+      continue;
+    assert(!isReassociableOp(V, Opcode) && "Shouldn't be a leaf!");
+    APInt Weight = It->second;
+    if (Weight.isMinValue())
+      // Leaf already output or weight reduction eliminated it.
+      continue;
+    // Ensure the leaf is only output once.
+    It->second = 0;
+    Ops.push_back(std::make_pair(V, Weight));
+  }
+
+  // For nilpotent operations or addition there may be no operands, for example
+  // because the expression was "X xor X" or consisted of 2^Bitwidth additions:
+  // in both cases the weight reduces to 0 causing the value to be skipped.
+  if (Ops.empty()) {
+    Constant *Identity = ConstantExpr::getBinOpIdentity(Opcode, I->getType());
+    assert(Identity && "Associative operation without identity!");
+    Ops.emplace_back(Identity, APInt(Bitwidth, 1));
+  }
+
+  return Changed;
+}
+
+/// Now that the operands for this expression tree are
+/// linearized and optimized, emit them in-order.
+void ReassociatePass::RewriteExprTree(BinaryOperator *I,
+                                      SmallVectorImpl<ValueEntry> &Ops) {
+  assert(Ops.size() > 1 && "Single values should be used directly!");
+
+  // Since our optimizations should never increase the number of operations, the
+  // new expression can usually be written reusing the existing binary operators
+  // from the original expression tree, without creating any new instructions,
+  // though the rewritten expression may have a completely different topology.
+  // We take care to not change anything if the new expression will be the same
+  // as the original. If more than trivial changes (like commuting operands)
+  // were made then we are obliged to clear out any optional subclass data like
+  // nsw flags.
+
+  /// NodesToRewrite - Nodes from the original expression available for writing
+  /// the new expression into.
+  SmallVector<BinaryOperator*, 8> NodesToRewrite;
+  unsigned Opcode = I->getOpcode();
+  BinaryOperator *Op = I;
+
+  /// NotRewritable - The operands being written will be the leaves of the new
+  /// expression and must not be used as inner nodes (via NodesToRewrite) by
+  /// mistake. Inner nodes are always reassociable, and usually leaves are not
+  /// (if they were they would have been incorporated into the expression and so
+  /// would not be leaves), so most of the time there is no danger of this. But
+  /// in rare cases a leaf may become reassociable if an optimization kills uses
+  /// of it, or it may momentarily become reassociable during rewriting (below)
+  /// due to it being removed as an operand of one of its uses.
Ensure that misuse
+  /// of leaf nodes as inner nodes cannot occur by remembering all of the future
+  /// leaves and refusing to reuse any of them as inner nodes.
+  SmallPtrSet<Value*, 8> NotRewritable;
+  for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+    NotRewritable.insert(Ops[i].Op);
+
+  // ExpressionChanged - Non-null if the rewritten expression differs from the
+  // original in some non-trivial way, requiring the clearing of optional flags.
+  // Flags are cleared from the operator in ExpressionChanged up to I inclusive.
+  BinaryOperator *ExpressionChanged = nullptr;
+  for (unsigned i = 0; ; ++i) {
+    // The last operation (which comes earliest in the IR) is special as both
+    // operands will come from Ops, rather than just one with the other being
+    // a subexpression.
+    if (i+2 == Ops.size()) {
+      Value *NewLHS = Ops[i].Op;
+      Value *NewRHS = Ops[i+1].Op;
+      Value *OldLHS = Op->getOperand(0);
+      Value *OldRHS = Op->getOperand(1);
+
+      if (NewLHS == OldLHS && NewRHS == OldRHS)
+        // Nothing changed, leave it alone.
+        break;
+
+      if (NewLHS == OldRHS && NewRHS == OldLHS) {
+        // The order of the operands was reversed. Swap them.
+        DEBUG(dbgs() << "RA: " << *Op << '\n');
+        Op->swapOperands();
+        DEBUG(dbgs() << "TO: " << *Op << '\n');
+        MadeChange = true;
+        ++NumChanged;
+        break;
+      }
+
+      // The new operation differs non-trivially from the original. Overwrite
+      // the old operands with the new ones.
+      DEBUG(dbgs() << "RA: " << *Op << '\n');
+      if (NewLHS != OldLHS) {
+        BinaryOperator *BO = isReassociableOp(OldLHS, Opcode);
+        if (BO && !NotRewritable.count(BO))
+          NodesToRewrite.push_back(BO);
+        Op->setOperand(0, NewLHS);
+      }
+      if (NewRHS != OldRHS) {
+        BinaryOperator *BO = isReassociableOp(OldRHS, Opcode);
+        if (BO && !NotRewritable.count(BO))
+          NodesToRewrite.push_back(BO);
+        Op->setOperand(1, NewRHS);
+      }
+      DEBUG(dbgs() << "TO: " << *Op << '\n');
+
+      ExpressionChanged = Op;
+      MadeChange = true;
+      ++NumChanged;
+
+      break;
+    }
+
+    // Not the last operation. The left-hand side will be a sub-expression
+    // while the right-hand side will be the current element of Ops.
+    Value *NewRHS = Ops[i].Op;
+    if (NewRHS != Op->getOperand(1)) {
+      DEBUG(dbgs() << "RA: " << *Op << '\n');
+      if (NewRHS == Op->getOperand(0)) {
+        // The new right-hand side was already present as the left operand. If
+        // we are lucky then swapping the operands will sort out both of them.
+        Op->swapOperands();
+      } else {
+        // Overwrite with the new right-hand side.
+        BinaryOperator *BO = isReassociableOp(Op->getOperand(1), Opcode);
+        if (BO && !NotRewritable.count(BO))
+          NodesToRewrite.push_back(BO);
+        Op->setOperand(1, NewRHS);
+        ExpressionChanged = Op;
+      }
+      DEBUG(dbgs() << "TO: " << *Op << '\n');
+      MadeChange = true;
+      ++NumChanged;
+    }
+
+    // Now deal with the left-hand side. If this is already an operation node
+    // from the original expression then just rewrite the rest of the expression
+    // into it.
+    BinaryOperator *BO = isReassociableOp(Op->getOperand(0), Opcode);
+    if (BO && !NotRewritable.count(BO)) {
+      Op = BO;
+      continue;
+    }
+
+    // Otherwise, grab a spare node from the original expression and use that as
+    // the left-hand side. If there are no nodes left then the optimizers made
+    // an expression with more nodes than the original! This usually means that
+    // they did something stupid but it might mean that the problem was just too
+    // hard (finding the minimal number of multiplications needed to realize a
+    // multiplication expression is NP-complete).
Whatever the reason, smart or + // stupid, create a new node if there are none left. + BinaryOperator *NewOp; + if (NodesToRewrite.empty()) { + Constant *Undef = UndefValue::get(I->getType()); + NewOp = BinaryOperator::Create(Instruction::BinaryOps(Opcode), + Undef, Undef, "", I); + if (NewOp->getType()->isFPOrFPVectorTy()) + NewOp->setFastMathFlags(I->getFastMathFlags()); + } else { + NewOp = NodesToRewrite.pop_back_val(); + } + + DEBUG(dbgs() << "RA: " << *Op << '\n'); + Op->setOperand(0, NewOp); + DEBUG(dbgs() << "TO: " << *Op << '\n'); + ExpressionChanged = Op; + MadeChange = true; + ++NumChanged; + Op = NewOp; + } + + // If the expression changed non-trivially then clear out all subclass data + // starting from the operator specified in ExpressionChanged, and compactify + // the operators to just before the expression root to guarantee that the + // expression tree is dominated by all of Ops. + if (ExpressionChanged) + do { + // Preserve FastMathFlags. + if (isa<FPMathOperator>(I)) { + FastMathFlags Flags = I->getFastMathFlags(); + ExpressionChanged->clearSubclassOptionalData(); + ExpressionChanged->setFastMathFlags(Flags); + } else + ExpressionChanged->clearSubclassOptionalData(); + + if (ExpressionChanged == I) + break; + ExpressionChanged->moveBefore(I); + ExpressionChanged = cast<BinaryOperator>(*ExpressionChanged->user_begin()); + } while (true); + + // Throw away any left over nodes from the original expression. + for (unsigned i = 0, e = NodesToRewrite.size(); i != e; ++i) + RedoInsts.insert(NodesToRewrite[i]); +} + +/// Insert instructions before the instruction pointed to by BI, +/// that computes the negative version of the value specified. The negative +/// version of the value is returned, and BI is left pointing at the instruction +/// that should be processed next by the reassociation pass. +/// Also add intermediate instructions to the redo list that are modified while +/// pushing the negates through adds. These will be revisited to see if +/// additional opportunities have been exposed. +static Value *NegateValue(Value *V, Instruction *BI, + SetVector<AssertingVH<Instruction>> &ToRedo) { + if (auto *C = dyn_cast<Constant>(V)) + return C->getType()->isFPOrFPVectorTy() ? ConstantExpr::getFNeg(C) : + ConstantExpr::getNeg(C); + + // We are trying to expose opportunity for reassociation. One of the things + // that we want to do to achieve this is to push a negation as deep into an + // expression chain as possible, to expose the add instructions. In practice, + // this means that we turn this: + // X = -(A+12+C+D) into X = -A + -12 + -C + -D = -12 + -A + -C + -D + // so that later, a: Y = 12+X could get reassociated with the -12 to eliminate + // the constants. We assume that instcombine will clean up the mess later if + // we introduce tons of unnecessary negation instructions. + // + if (BinaryOperator *I = + isReassociableOp(V, Instruction::Add, Instruction::FAdd)) { + // Push the negates through the add. + I->setOperand(0, NegateValue(I->getOperand(0), BI, ToRedo)); + I->setOperand(1, NegateValue(I->getOperand(1), BI, ToRedo)); + if (I->getOpcode() == Instruction::Add) { + I->setHasNoUnsignedWrap(false); + I->setHasNoSignedWrap(false); + } + + // We must move the add instruction here, because the neg instructions do + // not dominate the old add instruction in general. By moving it, we are + // assured that the neg instructions we just inserted dominate the + // instruction we are about to insert after them. 
+
+    //
+    I->moveBefore(BI);
+    I->setName(I->getName()+".neg");
+
+    // Add the intermediate negates to the redo list as processing them later
+    // could expose more reassociating opportunities.
+    ToRedo.insert(I);
+    return I;
+  }
+
+  // Okay, we need to materialize a negated version of V with an instruction.
+  // Scan the use lists of V to see if we have one already.
+  for (User *U : V->users()) {
+    if (!BinaryOperator::isNeg(U) && !BinaryOperator::isFNeg(U))
+      continue;
+
+    // We found one! Now we have to make sure that the definition dominates
+    // this use. We do this by moving it to the entry block (if it is a
+    // non-instruction value) or right after the definition. These negates will
+    // be zapped by reassociate later, so we don't need much finesse here.
+    BinaryOperator *TheNeg = cast<BinaryOperator>(U);
+
+    // Verify that the negate is in this function; V might be a constant expr.
+    if (TheNeg->getParent()->getParent() != BI->getParent()->getParent())
+      continue;
+
+    BasicBlock::iterator InsertPt;
+    if (Instruction *InstInput = dyn_cast<Instruction>(V)) {
+      if (InvokeInst *II = dyn_cast<InvokeInst>(InstInput)) {
+        InsertPt = II->getNormalDest()->begin();
+      } else {
+        InsertPt = ++InstInput->getIterator();
+      }
+      while (isa<PHINode>(InsertPt)) ++InsertPt;
+    } else {
+      InsertPt = TheNeg->getParent()->getParent()->getEntryBlock().begin();
+    }
+    TheNeg->moveBefore(&*InsertPt);
+    if (TheNeg->getOpcode() == Instruction::Sub) {
+      TheNeg->setHasNoUnsignedWrap(false);
+      TheNeg->setHasNoSignedWrap(false);
+    } else {
+      TheNeg->andIRFlags(BI);
+    }
+    ToRedo.insert(TheNeg);
+    return TheNeg;
+  }
+
+  // Insert a 'neg' instruction that subtracts the value from zero to get the
+  // negation.
+  BinaryOperator *NewNeg = CreateNeg(V, V->getName() + ".neg", BI, BI);
+  ToRedo.insert(NewNeg);
+  return NewNeg;
+}
+
+/// Return true if we should break up this subtract of X-Y into (X + -Y).
+static bool ShouldBreakUpSubtract(Instruction *Sub) {
+  // If this is a negation, we can't split it up!
+  if (BinaryOperator::isNeg(Sub) || BinaryOperator::isFNeg(Sub))
+    return false;
+
+  // Don't break up X - undef.
+  if (isa<UndefValue>(Sub->getOperand(1)))
+    return false;
+
+  // Don't bother to break this up unless either operand is a reassociable add
+  // or subtract, or the instruction's only user is a reassociable add or
+  // subtract.
+  Value *V0 = Sub->getOperand(0);
+  if (isReassociableOp(V0, Instruction::Add, Instruction::FAdd) ||
+      isReassociableOp(V0, Instruction::Sub, Instruction::FSub))
+    return true;
+  Value *V1 = Sub->getOperand(1);
+  if (isReassociableOp(V1, Instruction::Add, Instruction::FAdd) ||
+      isReassociableOp(V1, Instruction::Sub, Instruction::FSub))
+    return true;
+  Value *VB = Sub->user_back();
+  if (Sub->hasOneUse() &&
+      (isReassociableOp(VB, Instruction::Add, Instruction::FAdd) ||
+       isReassociableOp(VB, Instruction::Sub, Instruction::FSub)))
+    return true;
+
+  return false;
+}
+
+/// If we have (X-Y), and if either X is an add, or if this is only used by an
+/// add, transform this into (X+(0-Y)) to promote better reassociation.
+static BinaryOperator *
+BreakUpSubtract(Instruction *Sub, SetVector<AssertingVH<Instruction>> &ToRedo) {
+  // Convert a subtract into an add and a neg instruction. This allows sub
+  // instructions to be commuted with other add instructions.
+  //
+  // Calculate the negative value of Operand 1 of the sub instruction,
+  // and use it as the RHS of the add instruction we create below.
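+  // For example (illustrative): '%s = sub i32 %x, %y' becomes
+  //   %y.neg = sub i32 0, %y
+  //   %s     = add i32 %x, %y.neg
+  // (modulo exact naming); the add can then commute with other adds.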
+ Value *NegVal = NegateValue(Sub->getOperand(1), Sub, ToRedo); + BinaryOperator *New = CreateAdd(Sub->getOperand(0), NegVal, "", Sub, Sub); + Sub->setOperand(0, Constant::getNullValue(Sub->getType())); // Drop use of op. + Sub->setOperand(1, Constant::getNullValue(Sub->getType())); // Drop use of op. + New->takeName(Sub); + + // Everyone now refers to the add instruction. + Sub->replaceAllUsesWith(New); + New->setDebugLoc(Sub->getDebugLoc()); + + DEBUG(dbgs() << "Negated: " << *New << '\n'); + return New; +} + +/// If this is a shift of a reassociable multiply or is used by one, change +/// this into a multiply by a constant to assist with further reassociation. +static BinaryOperator *ConvertShiftToMul(Instruction *Shl) { + Constant *MulCst = ConstantInt::get(Shl->getType(), 1); + MulCst = ConstantExpr::getShl(MulCst, cast<Constant>(Shl->getOperand(1))); + + BinaryOperator *Mul = + BinaryOperator::CreateMul(Shl->getOperand(0), MulCst, "", Shl); + Shl->setOperand(0, UndefValue::get(Shl->getType())); // Drop use of op. + Mul->takeName(Shl); + + // Everyone now refers to the mul instruction. + Shl->replaceAllUsesWith(Mul); + Mul->setDebugLoc(Shl->getDebugLoc()); + + // We can safely preserve the nuw flag in all cases. It's also safe to turn a + // nuw nsw shl into a nuw nsw mul. However, nsw in isolation requires special + // handling. + bool NSW = cast<BinaryOperator>(Shl)->hasNoSignedWrap(); + bool NUW = cast<BinaryOperator>(Shl)->hasNoUnsignedWrap(); + if (NSW && NUW) + Mul->setHasNoSignedWrap(true); + Mul->setHasNoUnsignedWrap(NUW); + return Mul; +} + +/// Scan backwards and forwards among values with the same rank as element i +/// to see if X exists. If X does not exist, return i. This is useful when +/// scanning for 'x' when we see '-x' because they both get the same rank. +static unsigned FindInOperandList(const SmallVectorImpl<ValueEntry> &Ops, + unsigned i, Value *X) { + unsigned XRank = Ops[i].Rank; + unsigned e = Ops.size(); + for (unsigned j = i+1; j != e && Ops[j].Rank == XRank; ++j) { + if (Ops[j].Op == X) + return j; + if (Instruction *I1 = dyn_cast<Instruction>(Ops[j].Op)) + if (Instruction *I2 = dyn_cast<Instruction>(X)) + if (I1->isIdenticalTo(I2)) + return j; + } + // Scan backwards. + for (unsigned j = i-1; j != ~0U && Ops[j].Rank == XRank; --j) { + if (Ops[j].Op == X) + return j; + if (Instruction *I1 = dyn_cast<Instruction>(Ops[j].Op)) + if (Instruction *I2 = dyn_cast<Instruction>(X)) + if (I1->isIdenticalTo(I2)) + return j; + } + return i; +} + +/// Emit a tree of add instructions, summing Ops together +/// and returning the result. Insert the tree before I. +static Value *EmitAddTreeOfValues(Instruction *I, + SmallVectorImpl<WeakTrackingVH> &Ops) { + if (Ops.size() == 1) return Ops.back(); + + Value *V1 = Ops.back(); + Ops.pop_back(); + Value *V2 = EmitAddTreeOfValues(I, Ops); + return CreateAdd(V2, V1, "reass.add", I, I); +} + +/// If V is an expression tree that is a multiplication sequence, +/// and if this sequence contains a multiply by Factor, +/// remove Factor from the tree and return the new tree. 
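+/// For example (illustrative): removing the factor 'y' from 'x * y * z'
+/// yields 'x * z', and removing '-2' from 'x * 2' yields '0 - x', since only
+/// the negated constant was present and the result must be negated.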
+Value *ReassociatePass::RemoveFactorFromExpression(Value *V, Value *Factor) { + BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul); + if (!BO) + return nullptr; + + SmallVector<RepeatedValue, 8> Tree; + MadeChange |= LinearizeExprTree(BO, Tree); + SmallVector<ValueEntry, 8> Factors; + Factors.reserve(Tree.size()); + for (unsigned i = 0, e = Tree.size(); i != e; ++i) { + RepeatedValue E = Tree[i]; + Factors.append(E.second.getZExtValue(), + ValueEntry(getRank(E.first), E.first)); + } + + bool FoundFactor = false; + bool NeedsNegate = false; + for (unsigned i = 0, e = Factors.size(); i != e; ++i) { + if (Factors[i].Op == Factor) { + FoundFactor = true; + Factors.erase(Factors.begin()+i); + break; + } + + // If this is a negative version of this factor, remove it. + if (ConstantInt *FC1 = dyn_cast<ConstantInt>(Factor)) { + if (ConstantInt *FC2 = dyn_cast<ConstantInt>(Factors[i].Op)) + if (FC1->getValue() == -FC2->getValue()) { + FoundFactor = NeedsNegate = true; + Factors.erase(Factors.begin()+i); + break; + } + } else if (ConstantFP *FC1 = dyn_cast<ConstantFP>(Factor)) { + if (ConstantFP *FC2 = dyn_cast<ConstantFP>(Factors[i].Op)) { + const APFloat &F1 = FC1->getValueAPF(); + APFloat F2(FC2->getValueAPF()); + F2.changeSign(); + if (F1.compare(F2) == APFloat::cmpEqual) { + FoundFactor = NeedsNegate = true; + Factors.erase(Factors.begin() + i); + break; + } + } + } + } + + if (!FoundFactor) { + // Make sure to restore the operands to the expression tree. + RewriteExprTree(BO, Factors); + return nullptr; + } + + BasicBlock::iterator InsertPt = ++BO->getIterator(); + + // If this was just a single multiply, remove the multiply and return the only + // remaining operand. + if (Factors.size() == 1) { + RedoInsts.insert(BO); + V = Factors[0].Op; + } else { + RewriteExprTree(BO, Factors); + V = BO; + } + + if (NeedsNegate) + V = CreateNeg(V, "neg", &*InsertPt, BO); + + return V; +} + +/// If V is a single-use multiply, recursively add its operands as factors, +/// otherwise add V to the list of factors. +/// +/// Ops is the top-level list of add operands we're trying to factor. +static void FindSingleUseMultiplyFactors(Value *V, + SmallVectorImpl<Value*> &Factors) { + BinaryOperator *BO = isReassociableOp(V, Instruction::Mul, Instruction::FMul); + if (!BO) { + Factors.push_back(V); + return; + } + + // Otherwise, add the LHS and RHS to the list of factors. + FindSingleUseMultiplyFactors(BO->getOperand(1), Factors); + FindSingleUseMultiplyFactors(BO->getOperand(0), Factors); +} + +/// Optimize a series of operands to an 'and', 'or', or 'xor' instruction. +/// This optimizes based on identities. If it can be reduced to a single Value, +/// it is returned, otherwise the Ops list is mutated as necessary. +static Value *OptimizeAndOrXor(unsigned Opcode, + SmallVectorImpl<ValueEntry> &Ops) { + // Scan the operand lists looking for X and ~X pairs, along with X,X pairs. + // If we find any, we can simplify the expression. X&~X == 0, X|~X == -1. + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + // First, check for X and ~X in the operand list. + assert(i < Ops.size()); + if (BinaryOperator::isNot(Ops[i].Op)) { // Cannot occur for ^. 
+      Value *X = BinaryOperator::getNotArgument(Ops[i].Op);
+      unsigned FoundX = FindInOperandList(Ops, i, X);
+      if (FoundX != i) {
+        if (Opcode == Instruction::And)   // ...&X&~X = 0
+          return Constant::getNullValue(X->getType());
+
+        if (Opcode == Instruction::Or)    // ...|X|~X = -1
+          return Constant::getAllOnesValue(X->getType());
+      }
+    }
+
+    // Next, check for duplicate pairs of values, which we assume are next to
+    // each other, due to our sorting criteria.
+    assert(i < Ops.size());
+    if (i+1 != Ops.size() && Ops[i+1].Op == Ops[i].Op) {
+      if (Opcode == Instruction::And || Opcode == Instruction::Or) {
+        // Drop duplicate values for And and Or.
+        Ops.erase(Ops.begin()+i);
+        --i; --e;
+        ++NumAnnihil;
+        continue;
+      }
+
+      // Drop pairs of values for Xor.
+      assert(Opcode == Instruction::Xor);
+      if (e == 2)
+        return Constant::getNullValue(Ops[0].Op->getType());
+
+      // Y ^ X^X -> Y
+      Ops.erase(Ops.begin()+i, Ops.begin()+i+2);
+      i -= 1; e -= 2;
+      ++NumAnnihil;
+    }
+  }
+  return nullptr;
+}
+
+/// Helper function of CombineXorOpnd(). It creates a bitwise-and
+/// instruction with the given two operands, and returns the resulting
+/// instruction. There are two special cases: 1) if the constant operand is 0,
+/// it will return NULL. 2) if the constant is ~0, the symbolic operand will
+/// be returned.
+static Value *createAndInstr(Instruction *InsertBefore, Value *Opnd,
+                             const APInt &ConstOpnd) {
+  if (ConstOpnd.isNullValue())
+    return nullptr;
+
+  if (ConstOpnd.isAllOnesValue())
+    return Opnd;
+
+  Instruction *I = BinaryOperator::CreateAnd(
+      Opnd, ConstantInt::get(Opnd->getType(), ConstOpnd), "and.ra",
+      InsertBefore);
+  I->setDebugLoc(InsertBefore->getDebugLoc());
+  return I;
+}
+
+// Helper function of OptimizeXor(). It tries to simplify "Opnd1 ^ ConstOpnd"
+// into "R ^ C", where ideally C is 0, and R is a symbolic value.
+//
+// If it was successful, true is returned, and "R" and "C" are returned
+// via "Res" and "ConstOpnd", respectively; otherwise, false is returned,
+// and both "Res" and "ConstOpnd" remain unchanged.
+bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1,
+                                     APInt &ConstOpnd, Value *&Res) {
+  // Xor-Rule 1: (x | c1) ^ c2 = (x | c1) ^ (c1 ^ c1) ^ c2
+  //                           = ((x | c1) ^ c1) ^ (c1 ^ c2)
+  //                           = (x & ~c1) ^ (c1 ^ c2)
+  // It is useful only when c1 == c2.
+  if (!Opnd1->isOrExpr() || Opnd1->getConstPart().isNullValue())
+    return false;
+
+  if (!Opnd1->getValue()->hasOneUse())
+    return false;
+
+  const APInt &C1 = Opnd1->getConstPart();
+  if (C1 != ConstOpnd)
+    return false;
+
+  Value *X = Opnd1->getSymbolicPart();
+  Res = createAndInstr(I, X, ~C1);
+  // ConstOpnd was C2, now C1 ^ C2.
+  ConstOpnd ^= C1;
+
+  if (Instruction *T = dyn_cast<Instruction>(Opnd1->getValue()))
+    RedoInsts.insert(T);
+  return true;
+}
+
+// Helper function of OptimizeXor(). It tries to simplify
+// "Opnd1 ^ Opnd2 ^ ConstOpnd" into "R ^ C", where ideally C is 0, and R is a
+// symbolic value.
+//
+// If it was successful, true is returned, and "R" and "C" are returned
+// via "Res" and "ConstOpnd", respectively (if the entire expression
+// evaluates to a constant, Res is set to NULL); otherwise, false is
+// returned, and both "Res" and "ConstOpnd" remain unchanged.
+bool ReassociatePass::CombineXorOpnd(Instruction *I, XorOpnd *Opnd1,
+                                     XorOpnd *Opnd2, APInt &ConstOpnd,
+                                     Value *&Res) {
+  Value *X = Opnd1->getSymbolicPart();
+  if (X != Opnd2->getSymbolicPart())
+    return false;
+
+  // This many instructions become dead. (At least "Opnd1 ^ Opnd2" will die.)
+ int DeadInstNum = 1; + if (Opnd1->getValue()->hasOneUse()) + DeadInstNum++; + if (Opnd2->getValue()->hasOneUse()) + DeadInstNum++; + + // Xor-Rule 2: + // (x | c1) ^ (x & c2) + // = (x|c1) ^ (x&c2) ^ (c1 ^ c1) = ((x|c1) ^ c1) ^ (x & c2) ^ c1 + // = (x & ~c1) ^ (x & c2) ^ c1 // Xor-Rule 1 + // = (x & c3) ^ c1, where c3 = ~c1 ^ c2 // Xor-rule 3 + // + if (Opnd1->isOrExpr() != Opnd2->isOrExpr()) { + if (Opnd2->isOrExpr()) + std::swap(Opnd1, Opnd2); + + const APInt &C1 = Opnd1->getConstPart(); + const APInt &C2 = Opnd2->getConstPart(); + APInt C3((~C1) ^ C2); + + // Do not increase code size! + if (!C3.isNullValue() && !C3.isAllOnesValue()) { + int NewInstNum = ConstOpnd.getBoolValue() ? 1 : 2; + if (NewInstNum > DeadInstNum) + return false; + } + + Res = createAndInstr(I, X, C3); + ConstOpnd ^= C1; + } else if (Opnd1->isOrExpr()) { + // Xor-Rule 3: (x | c1) ^ (x | c2) = (x & c3) ^ c3 where c3 = c1 ^ c2 + // + const APInt &C1 = Opnd1->getConstPart(); + const APInt &C2 = Opnd2->getConstPart(); + APInt C3 = C1 ^ C2; + + // Do not increase code size + if (!C3.isNullValue() && !C3.isAllOnesValue()) { + int NewInstNum = ConstOpnd.getBoolValue() ? 1 : 2; + if (NewInstNum > DeadInstNum) + return false; + } + + Res = createAndInstr(I, X, C3); + ConstOpnd ^= C3; + } else { + // Xor-Rule 4: (x & c1) ^ (x & c2) = (x & (c1^c2)) + // + const APInt &C1 = Opnd1->getConstPart(); + const APInt &C2 = Opnd2->getConstPart(); + APInt C3 = C1 ^ C2; + Res = createAndInstr(I, X, C3); + } + + // Put the original operands in the Redo list; hope they will be deleted + // as dead code. + if (Instruction *T = dyn_cast<Instruction>(Opnd1->getValue())) + RedoInsts.insert(T); + if (Instruction *T = dyn_cast<Instruction>(Opnd2->getValue())) + RedoInsts.insert(T); + + return true; +} + +/// Optimize a series of operands to an 'xor' instruction. If it can be reduced +/// to a single Value, it is returned, otherwise the Ops list is mutated as +/// necessary. +Value *ReassociatePass::OptimizeXor(Instruction *I, + SmallVectorImpl<ValueEntry> &Ops) { + if (Value *V = OptimizeAndOrXor(Instruction::Xor, Ops)) + return V; + + if (Ops.size() == 1) + return nullptr; + + SmallVector<XorOpnd, 8> Opnds; + SmallVector<XorOpnd*, 8> OpndPtrs; + Type *Ty = Ops[0].Op->getType(); + APInt ConstOpnd(Ty->getScalarSizeInBits(), 0); + + // Step 1: Convert ValueEntry to XorOpnd + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + Value *V = Ops[i].Op; + const APInt *C; + // TODO: Support non-splat vectors. + if (match(V, PatternMatch::m_APInt(C))) { + ConstOpnd ^= *C; + } else { + XorOpnd O(V); + O.setSymbolicRank(getRank(O.getSymbolicPart())); + Opnds.push_back(O); + } + } + + // NOTE: From this point on, do *NOT* add/delete element to/from "Opnds". + // It would otherwise invalidate the "Opnds"'s iterator, and hence invalidate + // the "OpndPtrs" as well. For the similar reason, do not fuse this loop + // with the previous loop --- the iterator of the "Opnds" may be invalidated + // when new elements are added to the vector. + for (unsigned i = 0, e = Opnds.size(); i != e; ++i) + OpndPtrs.push_back(&Opnds[i]); + + // Step 2: Sort the Xor-Operands in a way such that the operands containing + // the same symbolic value cluster together. For instance, the input operand + // sequence ("x | 123", "y & 456", "x & 789") will be sorted into: + // ("x | 123", "x & 789", "y & 456"). + // + // The purpose is twofold: + // 1) Cluster together the operands sharing the same symbolic-value. 
+  // 2) Operand having smaller symbolic-value-rank is permuted earlier, which
+  //    could potentially shorten the critical path and expose more
+  //    loop-invariants. Note that values' ranks are basically defined in
+  //    RPO order (FIXME). So, if Rank(X) < Rank(Y) < Rank(Z), X is defined
+  //    earlier than Y, which is defined earlier than Z. Permuting "x | 1",
+  //    "y & 2" and "z" into the order X-Y-Z is better than any other order.
+  std::stable_sort(OpndPtrs.begin(), OpndPtrs.end(),
+                   [](XorOpnd *LHS, XorOpnd *RHS) {
+                     return LHS->getSymbolicRank() < RHS->getSymbolicRank();
+                   });
+
+  // Step 3: Combine adjacent operands
+  XorOpnd *PrevOpnd = nullptr;
+  bool Changed = false;
+  for (unsigned i = 0, e = Opnds.size(); i < e; i++) {
+    XorOpnd *CurrOpnd = OpndPtrs[i];
+    // The combined value
+    Value *CV;
+
+    // Step 3.1: Try simplifying "CurrOpnd ^ ConstOpnd"
+    if (!ConstOpnd.isNullValue() &&
+        CombineXorOpnd(I, CurrOpnd, ConstOpnd, CV)) {
+      Changed = true;
+      if (CV)
+        *CurrOpnd = XorOpnd(CV);
+      else {
+        CurrOpnd->Invalidate();
+        continue;
+      }
+    }
+
+    if (!PrevOpnd ||
+        CurrOpnd->getSymbolicPart() != PrevOpnd->getSymbolicPart()) {
+      PrevOpnd = CurrOpnd;
+      continue;
+    }
+
+    // Step 3.2: When previous and current operands share the same symbolic
+    // value, try to simplify "PrevOpnd ^ CurrOpnd ^ ConstOpnd"
+    if (CombineXorOpnd(I, CurrOpnd, PrevOpnd, ConstOpnd, CV)) {
+      // Remove previous operand
+      PrevOpnd->Invalidate();
+      if (CV) {
+        *CurrOpnd = XorOpnd(CV);
+        PrevOpnd = CurrOpnd;
+      } else {
+        CurrOpnd->Invalidate();
+        PrevOpnd = nullptr;
+      }
+      Changed = true;
+    }
+  }
+
+  // Step 4: Reassemble the Ops
+  if (Changed) {
+    Ops.clear();
+    for (unsigned int i = 0, e = Opnds.size(); i < e; i++) {
+      XorOpnd &O = Opnds[i];
+      if (O.isInvalid())
+        continue;
+      ValueEntry VE(getRank(O.getValue()), O.getValue());
+      Ops.push_back(VE);
+    }
+    if (!ConstOpnd.isNullValue()) {
+      Value *C = ConstantInt::get(Ty, ConstOpnd);
+      ValueEntry VE(getRank(C), C);
+      Ops.push_back(VE);
+    }
+    unsigned Sz = Ops.size();
+    if (Sz == 1)
+      return Ops.back().Op;
+    if (Sz == 0) {
+      assert(ConstOpnd.isNullValue());
+      return ConstantInt::get(Ty, ConstOpnd);
+    }
+  }
+
+  return nullptr;
+}
+
+/// Optimize a series of operands to an 'add' instruction. This
+/// optimizes based on identities. If it can be reduced to a single Value, it
+/// is returned, otherwise the Ops list is mutated as necessary.
+Value *ReassociatePass::OptimizeAdd(Instruction *I,
+                                    SmallVectorImpl<ValueEntry> &Ops) {
+  // Scan the operand lists looking for X and -X pairs. If we find any, we
+  // can simplify expressions like X+-X == 0 and X+~X == -1. While we're at
+  // it, scan for any duplicates. We want to canonicalize Y+Y+Y+Z -> 3*Y+Z.
+
+  for (unsigned i = 0, e = Ops.size(); i != e; ++i) {
+    Value *TheOp = Ops[i].Op;
+    // Check to see if we've seen this operand before. If so, we factor all
+    // instances of the operand together. Due to our sorting criteria, we know
+    // that these need to be next to each other in the vector.
+    if (i+1 != Ops.size() && Ops[i+1].Op == TheOp) {
+      // Rescan the list, remove all instances of this operand from the expr.
+      unsigned NumFound = 0;
+      do {
+        Ops.erase(Ops.begin()+i);
+        ++NumFound;
+      } while (i != Ops.size() && Ops[i].Op == TheOp);
+
+      DEBUG(dbgs() << "\nFACTORING [" << NumFound << "]: " << *TheOp << '\n');
+      ++NumFactor;
+
+      // Insert a new multiply.
+      Type *Ty = TheOp->getType();
+      Constant *C = Ty->isIntOrIntVectorTy() ?
+ ConstantInt::get(Ty, NumFound) : ConstantFP::get(Ty, NumFound); + Instruction *Mul = CreateMul(TheOp, C, "factor", I, I); + + // Now that we have inserted a multiply, optimize it. This allows us to + // handle cases that require multiple factoring steps, such as this: + // (X*2) + (X*2) + (X*2) -> (X*2)*3 -> X*6 + RedoInsts.insert(Mul); + + // If every add operand was a duplicate, return the multiply. + if (Ops.empty()) + return Mul; + + // Otherwise, we had some input that didn't have the dupe, such as + // "A + A + B" -> "A*2 + B". Add the new multiply to the list of + // things being added by this operation. + Ops.insert(Ops.begin(), ValueEntry(getRank(Mul), Mul)); + + --i; + e = Ops.size(); + continue; + } + + // Check for X and -X or X and ~X in the operand list. + if (!BinaryOperator::isNeg(TheOp) && !BinaryOperator::isFNeg(TheOp) && + !BinaryOperator::isNot(TheOp)) + continue; + + Value *X = nullptr; + if (BinaryOperator::isNeg(TheOp) || BinaryOperator::isFNeg(TheOp)) + X = BinaryOperator::getNegArgument(TheOp); + else if (BinaryOperator::isNot(TheOp)) + X = BinaryOperator::getNotArgument(TheOp); + + unsigned FoundX = FindInOperandList(Ops, i, X); + if (FoundX == i) + continue; + + // Remove X and -X from the operand list. + if (Ops.size() == 2 && + (BinaryOperator::isNeg(TheOp) || BinaryOperator::isFNeg(TheOp))) + return Constant::getNullValue(X->getType()); + + // Remove X and ~X from the operand list. + if (Ops.size() == 2 && BinaryOperator::isNot(TheOp)) + return Constant::getAllOnesValue(X->getType()); + + Ops.erase(Ops.begin()+i); + if (i < FoundX) + --FoundX; + else + --i; // Need to back up an extra one. + Ops.erase(Ops.begin()+FoundX); + ++NumAnnihil; + --i; // Revisit element. + e -= 2; // Removed two elements. + + // if X and ~X we append -1 to the operand list. + if (BinaryOperator::isNot(TheOp)) { + Value *V = Constant::getAllOnesValue(X->getType()); + Ops.insert(Ops.end(), ValueEntry(getRank(V), V)); + e += 1; + } + } + + // Scan the operand list, checking to see if there are any common factors + // between operands. Consider something like A*A+A*B*C+D. We would like to + // reassociate this to A*(A+B*C)+D, which reduces the number of multiplies. + // To efficiently find this, we count the number of times a factor occurs + // for any ADD operands that are MULs. + DenseMap<Value*, unsigned> FactorOccurrences; + + // Keep track of each multiply we see, to avoid triggering on (X*4)+(X*4) + // where they are actually the same multiply. + unsigned MaxOcc = 0; + Value *MaxOccVal = nullptr; + for (unsigned i = 0, e = Ops.size(); i != e; ++i) { + BinaryOperator *BOp = + isReassociableOp(Ops[i].Op, Instruction::Mul, Instruction::FMul); + if (!BOp) + continue; + + // Compute all of the factors of this added value. + SmallVector<Value*, 8> Factors; + FindSingleUseMultiplyFactors(BOp, Factors); + assert(Factors.size() > 1 && "Bad linearize!"); + + // Add one to FactorOccurrences for each unique factor in this op. + SmallPtrSet<Value*, 8> Duplicates; + for (unsigned i = 0, e = Factors.size(); i != e; ++i) { + Value *Factor = Factors[i]; + if (!Duplicates.insert(Factor).second) + continue; + + unsigned Occ = ++FactorOccurrences[Factor]; + if (Occ > MaxOcc) { + MaxOcc = Occ; + MaxOccVal = Factor; + } + + // If Factor is a negative constant, add the negated value as a factor + // because we can percolate the negate out. Watch for minint, which + // cannot be positivified. 
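+      // For example, a factor of -3 is also counted as a factor of 3, since
+      // -3*X can be rewritten as -(3*X) with the negate percolated outward.
+      // INT_MIN is skipped because -INT_MIN overflows back to INT_MIN.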
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(Factor)) { + if (CI->isNegative() && !CI->isMinValue(true)) { + Factor = ConstantInt::get(CI->getContext(), -CI->getValue()); + if (!Duplicates.insert(Factor).second) + continue; + unsigned Occ = ++FactorOccurrences[Factor]; + if (Occ > MaxOcc) { + MaxOcc = Occ; + MaxOccVal = Factor; + } + } + } else if (ConstantFP *CF = dyn_cast<ConstantFP>(Factor)) { + if (CF->isNegative()) { + APFloat F(CF->getValueAPF()); + F.changeSign(); + Factor = ConstantFP::get(CF->getContext(), F); + if (!Duplicates.insert(Factor).second) + continue; + unsigned Occ = ++FactorOccurrences[Factor]; + if (Occ > MaxOcc) { + MaxOcc = Occ; + MaxOccVal = Factor; + } + } + } + } + } + + // If any factor occurred more than one time, we can pull it out. + if (MaxOcc > 1) { + DEBUG(dbgs() << "\nFACTORING [" << MaxOcc << "]: " << *MaxOccVal << '\n'); + ++NumFactor; + + // Create a new instruction that uses the MaxOccVal twice. If we don't do + // this, we could otherwise run into situations where removing a factor + // from an expression will drop a use of maxocc, and this can cause + // RemoveFactorFromExpression on successive values to behave differently. + Instruction *DummyInst = + I->getType()->isIntOrIntVectorTy() + ? BinaryOperator::CreateAdd(MaxOccVal, MaxOccVal) + : BinaryOperator::CreateFAdd(MaxOccVal, MaxOccVal); + + SmallVector<WeakTrackingVH, 4> NewMulOps; + for (unsigned i = 0; i != Ops.size(); ++i) { + // Only try to remove factors from expressions we're allowed to. + BinaryOperator *BOp = + isReassociableOp(Ops[i].Op, Instruction::Mul, Instruction::FMul); + if (!BOp) + continue; + + if (Value *V = RemoveFactorFromExpression(Ops[i].Op, MaxOccVal)) { + // The factorized operand may occur several times. Convert them all in + // one fell swoop. + for (unsigned j = Ops.size(); j != i;) { + --j; + if (Ops[j].Op == Ops[i].Op) { + NewMulOps.push_back(V); + Ops.erase(Ops.begin()+j); + } + } + --i; + } + } + + // No need for extra uses anymore. + DummyInst->deleteValue(); + + unsigned NumAddedValues = NewMulOps.size(); + Value *V = EmitAddTreeOfValues(I, NewMulOps); + + // Now that we have inserted the add tree, optimize it. This allows us to + // handle cases that require multiple factoring steps, such as this: + // A*A*B + A*A*C --> A*(A*B+A*C) --> A*(A*(B+C)) + assert(NumAddedValues > 1 && "Each occurrence should contribute a value"); + (void)NumAddedValues; + if (Instruction *VI = dyn_cast<Instruction>(V)) + RedoInsts.insert(VI); + + // Create the multiply. + Instruction *V2 = CreateMul(V, MaxOccVal, "reass.mul", I, I); + + // Rerun associate on the multiply in case the inner expression turned into + // a multiply. We want to make sure that we keep things in canonical form. + RedoInsts.insert(V2); + + // If every add operand included the factor (e.g. "A*B + A*C"), then the + // entire result expression is just the multiply "A*(B+C)". + if (Ops.empty()) + return V2; + + // Otherwise, we had some input that didn't have the factor, such as + // "A*B + A*C + D" -> "A*(B+C) + D". Add the new multiply to the list of + // things being added by this operation. + Ops.insert(Ops.begin(), ValueEntry(getRank(V2), V2)); + } + + return nullptr; +} + +/// \brief Build up a vector of value/power pairs factoring a product. +/// +/// Given a series of multiplication operands, build a vector of factors and +/// the powers each is raised to when forming the final product. Sort them in +/// the order of descending power. 
+///
+/// (x*x) -> [(x, 2)]
+/// ((x*x)*x) -> [(x, 3)]
+/// ((((x*y)*x)*y)*x) -> [(x, 3), (y, 2)]
+///
+/// \returns Whether any factors have a power greater than one.
+static bool collectMultiplyFactors(SmallVectorImpl<ValueEntry> &Ops,
+                                   SmallVectorImpl<Factor> &Factors) {
+  // FIXME: Have Ops be (ValueEntry, Multiplicity) pairs, simplifying this.
+  // Compute the sum of powers of simplifiable factors.
+  unsigned FactorPowerSum = 0;
+  for (unsigned Idx = 1, Size = Ops.size(); Idx < Size; ++Idx) {
+    Value *Op = Ops[Idx-1].Op;
+
+    // Count the number of occurrences of this value.
+    unsigned Count = 1;
+    for (; Idx < Size && Ops[Idx].Op == Op; ++Idx)
+      ++Count;
+    // Track for simplification all factors which occur 2 or more times.
+    if (Count > 1)
+      FactorPowerSum += Count;
+  }
+
+  // We can only simplify factors if the sum of the powers of our simplifiable
+  // factors is 4 or higher. When that is the case, we will *always* have
+  // a simplification. This is an important invariant to prevent cyclically
+  // trying to simplify already minimal formations.
+  if (FactorPowerSum < 4)
+    return false;
+
+  // Now gather the simplifiable factors, removing them from Ops.
+  FactorPowerSum = 0;
+  for (unsigned Idx = 1; Idx < Ops.size(); ++Idx) {
+    Value *Op = Ops[Idx-1].Op;
+
+    // Count the number of occurrences of this value.
+    unsigned Count = 1;
+    for (; Idx < Ops.size() && Ops[Idx].Op == Op; ++Idx)
+      ++Count;
+    if (Count == 1)
+      continue;
+    // Move an even number of occurrences to Factors.
+    Count &= ~1U;
+    Idx -= Count;
+    FactorPowerSum += Count;
+    Factors.push_back(Factor(Op, Count));
+    Ops.erase(Ops.begin()+Idx, Ops.begin()+Idx+Count);
+  }
+
+  // None of the adjustments above should have reduced the sum of factor powers
+  // below our minimum of '4'.
+  assert(FactorPowerSum >= 4);
+
+  std::stable_sort(Factors.begin(), Factors.end(),
+                   [](const Factor &LHS, const Factor &RHS) {
+                     return LHS.Power > RHS.Power;
+                   });
+  return true;
+}
+
+/// \brief Build a tree of multiplies, computing the product of Ops.
+static Value *buildMultiplyTree(IRBuilder<> &Builder,
+                                SmallVectorImpl<Value*> &Ops) {
+  if (Ops.size() == 1)
+    return Ops.back();
+
+  Value *LHS = Ops.pop_back_val();
+  do {
+    if (LHS->getType()->isIntOrIntVectorTy())
+      LHS = Builder.CreateMul(LHS, Ops.pop_back_val());
+    else
+      LHS = Builder.CreateFMul(LHS, Ops.pop_back_val());
+  } while (!Ops.empty());
+
+  return LHS;
+}
+
+/// \brief Build a minimal multiplication DAG for (a^x)*(b^y)*(c^z)*...
+///
+/// Given a vector of values raised to various powers, where no two values are
+/// equal and the powers are sorted in decreasing order, compute the minimal
+/// DAG of multiplies to compute the final product, and return that product
+/// value.
+Value *
+ReassociatePass::buildMinimalMultiplyDAG(IRBuilder<> &Builder,
+                                         SmallVectorImpl<Factor> &Factors) {
+  assert(Factors[0].Power);
+  SmallVector<Value *, 4> OuterProduct;
+  for (unsigned LastIdx = 0, Idx = 1, Size = Factors.size();
+       Idx < Size && Factors[Idx].Power > 0; ++Idx) {
+    if (Factors[Idx].Power != Factors[LastIdx].Power) {
+      LastIdx = Idx;
+      continue;
+    }
+
+    // We want to multiply across all the factors with the same power so that
+    // we can raise them to that power as a single entity. Build a mini tree
+    // for that.
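+    // For example, with Factors [(a, 4), (b, 4), (c, 2)], 'a' and 'b' share
+    // power 4, so this loop folds them into the single factor ((a*b), 4).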
+    SmallVector<Value *, 4> InnerProduct;
+    InnerProduct.push_back(Factors[LastIdx].Base);
+    do {
+      InnerProduct.push_back(Factors[Idx].Base);
+      ++Idx;
+    } while (Idx < Size && Factors[Idx].Power == Factors[LastIdx].Power);
+
+    // Reset the base value of the first factor to the new expression tree.
+    // We'll remove all the factors with the same power in a second pass.
+    Value *M = Factors[LastIdx].Base = buildMultiplyTree(Builder, InnerProduct);
+    if (Instruction *MI = dyn_cast<Instruction>(M))
+      RedoInsts.insert(MI);
+
+    LastIdx = Idx;
+  }
+  // Unique factors with equal powers -- we've folded them into the first one's
+  // base.
+  Factors.erase(std::unique(Factors.begin(), Factors.end(),
+                            [](const Factor &LHS, const Factor &RHS) {
+                              return LHS.Power == RHS.Power;
+                            }),
+                Factors.end());
+
+  // Iteratively collect the base of each factor with an odd power into the
+  // outer product, and halve each power in preparation for squaring the
+  // expression.
+  for (unsigned Idx = 0, Size = Factors.size(); Idx != Size; ++Idx) {
+    if (Factors[Idx].Power & 1)
+      OuterProduct.push_back(Factors[Idx].Base);
+    Factors[Idx].Power >>= 1;
+  }
+  if (Factors[0].Power) {
+    Value *SquareRoot = buildMinimalMultiplyDAG(Builder, Factors);
+    OuterProduct.push_back(SquareRoot);
+    OuterProduct.push_back(SquareRoot);
+  }
+  if (OuterProduct.size() == 1)
+    return OuterProduct.front();
+
+  Value *V = buildMultiplyTree(Builder, OuterProduct);
+  return V;
+}
+
+Value *ReassociatePass::OptimizeMul(BinaryOperator *I,
+                                    SmallVectorImpl<ValueEntry> &Ops) {
+  // We can only optimize the multiplies when there is a chain of more than
+  // three, such that a balanced tree might require fewer total multiplies.
+  if (Ops.size() < 4)
+    return nullptr;
+
+  // Try to turn linear trees of multiplies without other uses of the
+  // intermediate stages into minimal multiply DAGs with perfect sub-expression
+  // re-use.
+  SmallVector<Factor, 4> Factors;
+  if (!collectMultiplyFactors(Ops, Factors))
+    return nullptr; // All distinct factors, so nothing left for us to do.
+
+  IRBuilder<> Builder(I);
+  // The reassociate transformation for FP operations is performed only
+  // if unsafe algebra is permitted by FastMathFlags. Propagate those flags
+  // to the newly generated operations.
+  if (auto FPI = dyn_cast<FPMathOperator>(I))
+    Builder.setFastMathFlags(FPI->getFastMathFlags());
+
+  Value *V = buildMinimalMultiplyDAG(Builder, Factors);
+  if (Ops.empty())
+    return V;
+
+  ValueEntry NewEntry = ValueEntry(getRank(V), V);
+  Ops.insert(std::lower_bound(Ops.begin(), Ops.end(), NewEntry), NewEntry);
+  return nullptr;
+}
+
+Value *ReassociatePass::OptimizeExpression(BinaryOperator *I,
+                                           SmallVectorImpl<ValueEntry> &Ops) {
+  // Now that we have the linearized expression tree, try to optimize it.
+  // Start by folding any constants that we found.
+  Constant *Cst = nullptr;
+  unsigned Opcode = I->getOpcode();
+  while (!Ops.empty() && isa<Constant>(Ops.back().Op)) {
+    Constant *C = cast<Constant>(Ops.pop_back_val().Op);
+    Cst = Cst ? ConstantExpr::get(Opcode, C, Cst) : C;
+  }
+  // If there was nothing but constants then we are done.
+  if (Ops.empty())
+    return Cst;
+
+  // Put the combined constant back at the end of the operand list, except if
+  // there is no point. For example, an add of 0 gets dropped here, while a
+  // multiplication by zero turns the whole expression into zero.
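+  // Concretely, 0 is the identity of 'add' and the absorber of 'mul', so
+  // (X + Y + 0) keeps only [X, Y] while (X * Y * 0) folds to the constant 0.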
+ if (Cst && Cst != ConstantExpr::getBinOpIdentity(Opcode, I->getType())) { + if (Cst == ConstantExpr::getBinOpAbsorber(Opcode, I->getType())) + return Cst; + Ops.push_back(ValueEntry(0, Cst)); + } + + if (Ops.size() == 1) return Ops[0].Op; + + // Handle destructive annihilation due to identities between elements in the + // argument list here. + unsigned NumOps = Ops.size(); + switch (Opcode) { + default: break; + case Instruction::And: + case Instruction::Or: + if (Value *Result = OptimizeAndOrXor(Opcode, Ops)) + return Result; + break; + + case Instruction::Xor: + if (Value *Result = OptimizeXor(I, Ops)) + return Result; + break; + + case Instruction::Add: + case Instruction::FAdd: + if (Value *Result = OptimizeAdd(I, Ops)) + return Result; + break; + + case Instruction::Mul: + case Instruction::FMul: + if (Value *Result = OptimizeMul(I, Ops)) + return Result; + break; + } + + if (Ops.size() != NumOps) + return OptimizeExpression(I, Ops); + return nullptr; +} + +// Remove dead instructions and if any operands are trivially dead add them to +// Insts so they will be removed as well. +void ReassociatePass::RecursivelyEraseDeadInsts( + Instruction *I, SetVector<AssertingVH<Instruction>> &Insts) { + assert(isInstructionTriviallyDead(I) && "Trivially dead instructions only!"); + SmallVector<Value *, 4> Ops(I->op_begin(), I->op_end()); + ValueRankMap.erase(I); + Insts.remove(I); + RedoInsts.remove(I); + I->eraseFromParent(); + for (auto Op : Ops) + if (Instruction *OpInst = dyn_cast<Instruction>(Op)) + if (OpInst->use_empty()) + Insts.insert(OpInst); +} + +/// Zap the given instruction, adding interesting operands to the work list. +void ReassociatePass::EraseInst(Instruction *I) { + assert(isInstructionTriviallyDead(I) && "Trivially dead instructions only!"); + DEBUG(dbgs() << "Erasing dead inst: "; I->dump()); + + SmallVector<Value*, 8> Ops(I->op_begin(), I->op_end()); + // Erase the dead instruction. + ValueRankMap.erase(I); + RedoInsts.remove(I); + I->eraseFromParent(); + // Optimize its operands. + SmallPtrSet<Instruction *, 8> Visited; // Detect self-referential nodes. + for (unsigned i = 0, e = Ops.size(); i != e; ++i) + if (Instruction *Op = dyn_cast<Instruction>(Ops[i])) { + // If this is a node in an expression tree, climb to the expression root + // and add that since that's where optimization actually happens. + unsigned Opcode = Op->getOpcode(); + while (Op->hasOneUse() && Op->user_back()->getOpcode() == Opcode && + Visited.insert(Op).second) + Op = Op->user_back(); + RedoInsts.insert(Op); + } + + MadeChange = true; +} + +// Canonicalize expressions of the following form: +// x + (-Constant * y) -> x - (Constant * y) +// x - (-Constant * y) -> x + (Constant * y) +Instruction *ReassociatePass::canonicalizeNegConstExpr(Instruction *I) { + if (!I->hasOneUse() || I->getType()->isVectorTy()) + return nullptr; + + // Must be a fmul or fdiv instruction. + unsigned Opcode = I->getOpcode(); + if (Opcode != Instruction::FMul && Opcode != Instruction::FDiv) + return nullptr; + + auto *C0 = dyn_cast<ConstantFP>(I->getOperand(0)); + auto *C1 = dyn_cast<ConstantFP>(I->getOperand(1)); + + // Both operands are constant, let it get constant folded away. + if (C0 && C1) + return nullptr; + + ConstantFP *CF = C0 ? C0 : C1; + + // Must have one constant operand. + if (!CF) + return nullptr; + + // Must be a negative ConstantFP. + if (!CF->isNegative()) + return nullptr; + + // User must be a binary operator with one or more uses. 
+ Instruction *User = I->user_back(); + if (!isa<BinaryOperator>(User) || User->use_empty()) + return nullptr; + + unsigned UserOpcode = User->getOpcode(); + if (UserOpcode != Instruction::FAdd && UserOpcode != Instruction::FSub) + return nullptr; + + // Subtraction is not commutative. Explicitly, the following transform is + // not valid: (-Constant * y) - x -> x + (Constant * y) + if (!User->isCommutative() && User->getOperand(1) != I) + return nullptr; + + // Don't canonicalize x + (-Constant * y) -> x - (Constant * y), if the + // resulting subtract will be broken up later. This can get us into an + // infinite loop during reassociation. + if (UserOpcode == Instruction::FAdd && ShouldBreakUpSubtract(User)) + return nullptr; + + // Change the sign of the constant. + APFloat Val = CF->getValueAPF(); + Val.changeSign(); + I->setOperand(C0 ? 0 : 1, ConstantFP::get(CF->getContext(), Val)); + + // Canonicalize I to RHS to simplify the next bit of logic. E.g., + // ((-Const*y) + x) -> (x + (-Const*y)). + if (User->getOperand(0) == I && User->isCommutative()) + cast<BinaryOperator>(User)->swapOperands(); + + Value *Op0 = User->getOperand(0); + Value *Op1 = User->getOperand(1); + BinaryOperator *NI; + switch (UserOpcode) { + default: + llvm_unreachable("Unexpected Opcode!"); + case Instruction::FAdd: + NI = BinaryOperator::CreateFSub(Op0, Op1); + NI->setFastMathFlags(cast<FPMathOperator>(User)->getFastMathFlags()); + break; + case Instruction::FSub: + NI = BinaryOperator::CreateFAdd(Op0, Op1); + NI->setFastMathFlags(cast<FPMathOperator>(User)->getFastMathFlags()); + break; + } + + NI->insertBefore(User); + NI->setName(User->getName()); + User->replaceAllUsesWith(NI); + NI->setDebugLoc(I->getDebugLoc()); + RedoInsts.insert(I); + MadeChange = true; + return NI; +} + +/// Inspect and optimize the given instruction. Note that erasing +/// instructions is not allowed. +void ReassociatePass::OptimizeInst(Instruction *I) { + // Only consider operations that we understand. + if (!isa<BinaryOperator>(I)) + return; + + if (I->getOpcode() == Instruction::Shl && isa<ConstantInt>(I->getOperand(1))) + // If an operand of this shift is a reassociable multiply, or if the shift + // is used by a reassociable multiply or add, turn into a multiply. + if (isReassociableOp(I->getOperand(0), Instruction::Mul) || + (I->hasOneUse() && + (isReassociableOp(I->user_back(), Instruction::Mul) || + isReassociableOp(I->user_back(), Instruction::Add)))) { + Instruction *NI = ConvertShiftToMul(I); + RedoInsts.insert(I); + MadeChange = true; + I = NI; + } + + // Canonicalize negative constants out of expressions. + if (Instruction *Res = canonicalizeNegConstExpr(I)) + I = Res; + + // Commute binary operators, to canonicalize the order of their operands. + // This can potentially expose more CSE opportunities, and makes writing other + // transformations simpler. + if (I->isCommutative()) + canonicalizeOperands(I); + + // Don't optimize floating-point instructions unless they are 'fast'. + if (I->getType()->isFPOrFPVectorTy() && !I->isFast()) + return; + + // Do not reassociate boolean (i1) expressions. We want to preserve the + // original order of evaluation for short-circuited comparisons that + // SimplifyCFG has folded to AND/OR expressions. If the expression + // is not further optimized, it is likely to be transformed back to a + // short-circuited form for code gen, and the source order may have been + // optimized for the most likely conditions. 
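+  // For example, an 'and i1' produced by SimplifyCFG from 'a && b' is left
+  // in source order rather than being re-sorted by rank below.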
+ if (I->getType()->isIntegerTy(1)) + return; + + // If this is a subtract instruction which is not already in negate form, + // see if we can convert it to X+-Y. + if (I->getOpcode() == Instruction::Sub) { + if (ShouldBreakUpSubtract(I)) { + Instruction *NI = BreakUpSubtract(I, RedoInsts); + RedoInsts.insert(I); + MadeChange = true; + I = NI; + } else if (BinaryOperator::isNeg(I)) { + // Otherwise, this is a negation. See if the operand is a multiply tree + // and if this is not an inner node of a multiply tree. + if (isReassociableOp(I->getOperand(1), Instruction::Mul) && + (!I->hasOneUse() || + !isReassociableOp(I->user_back(), Instruction::Mul))) { + Instruction *NI = LowerNegateToMultiply(I); + // If the negate was simplified, revisit the users to see if we can + // reassociate further. + for (User *U : NI->users()) { + if (BinaryOperator *Tmp = dyn_cast<BinaryOperator>(U)) + RedoInsts.insert(Tmp); + } + RedoInsts.insert(I); + MadeChange = true; + I = NI; + } + } + } else if (I->getOpcode() == Instruction::FSub) { + if (ShouldBreakUpSubtract(I)) { + Instruction *NI = BreakUpSubtract(I, RedoInsts); + RedoInsts.insert(I); + MadeChange = true; + I = NI; + } else if (BinaryOperator::isFNeg(I)) { + // Otherwise, this is a negation. See if the operand is a multiply tree + // and if this is not an inner node of a multiply tree. + if (isReassociableOp(I->getOperand(1), Instruction::FMul) && + (!I->hasOneUse() || + !isReassociableOp(I->user_back(), Instruction::FMul))) { + // If the negate was simplified, revisit the users to see if we can + // reassociate further. + Instruction *NI = LowerNegateToMultiply(I); + for (User *U : NI->users()) { + if (BinaryOperator *Tmp = dyn_cast<BinaryOperator>(U)) + RedoInsts.insert(Tmp); + } + RedoInsts.insert(I); + MadeChange = true; + I = NI; + } + } + } + + // If this instruction is an associative binary operator, process it. + if (!I->isAssociative()) return; + BinaryOperator *BO = cast<BinaryOperator>(I); + + // If this is an interior node of a reassociable tree, ignore it until we + // get to the root of the tree, to avoid N^2 analysis. + unsigned Opcode = BO->getOpcode(); + if (BO->hasOneUse() && BO->user_back()->getOpcode() == Opcode) { + // During the initial run we will get to the root of the tree. + // But if we get here while we are redoing instructions, there is no + // guarantee that the root will be visited. So Redo later + if (BO->user_back() != BO && + BO->getParent() == BO->user_back()->getParent()) + RedoInsts.insert(BO->user_back()); + return; + } + + // If this is an add tree that is used by a sub instruction, ignore it + // until we process the subtract. + if (BO->hasOneUse() && BO->getOpcode() == Instruction::Add && + cast<Instruction>(BO->user_back())->getOpcode() == Instruction::Sub) + return; + if (BO->hasOneUse() && BO->getOpcode() == Instruction::FAdd && + cast<Instruction>(BO->user_back())->getOpcode() == Instruction::FSub) + return; + + ReassociateExpression(BO); +} + +void ReassociatePass::ReassociateExpression(BinaryOperator *I) { + // First, walk the expression tree, linearizing the tree, collecting the + // operand information. 
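+  // For example, ((A+B)+(C+D)) linearizes to the operand list [A, B, C, D];
+  // a repeated operand, as in A+A+B, is carried as the pair (A, 2) in Tree.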
+ SmallVector<RepeatedValue, 8> Tree; + MadeChange |= LinearizeExprTree(I, Tree); + SmallVector<ValueEntry, 8> Ops; + Ops.reserve(Tree.size()); + for (unsigned i = 0, e = Tree.size(); i != e; ++i) { + RepeatedValue E = Tree[i]; + Ops.append(E.second.getZExtValue(), + ValueEntry(getRank(E.first), E.first)); + } + + DEBUG(dbgs() << "RAIn:\t"; PrintOps(I, Ops); dbgs() << '\n'); + + // Now that we have linearized the tree to a list and have gathered all of + // the operands and their ranks, sort the operands by their rank. Use a + // stable_sort so that values with equal ranks will have their relative + // positions maintained (and so the compiler is deterministic). Note that + // this sorts so that the highest ranking values end up at the beginning of + // the vector. + std::stable_sort(Ops.begin(), Ops.end()); + + // Now that we have the expression tree in a convenient + // sorted form, optimize it globally if possible. + if (Value *V = OptimizeExpression(I, Ops)) { + if (V == I) + // Self-referential expression in unreachable code. + return; + // This expression tree simplified to something that isn't a tree, + // eliminate it. + DEBUG(dbgs() << "Reassoc to scalar: " << *V << '\n'); + I->replaceAllUsesWith(V); + if (Instruction *VI = dyn_cast<Instruction>(V)) + if (I->getDebugLoc()) + VI->setDebugLoc(I->getDebugLoc()); + RedoInsts.insert(I); + ++NumAnnihil; + return; + } + + // We want to sink immediates as deeply as possible except in the case where + // this is a multiply tree used only by an add, and the immediate is a -1. + // In this case we reassociate to put the negation on the outside so that we + // can fold the negation into the add: (-X)*Y + Z -> Z-X*Y + if (I->hasOneUse()) { + if (I->getOpcode() == Instruction::Mul && + cast<Instruction>(I->user_back())->getOpcode() == Instruction::Add && + isa<ConstantInt>(Ops.back().Op) && + cast<ConstantInt>(Ops.back().Op)->isMinusOne()) { + ValueEntry Tmp = Ops.pop_back_val(); + Ops.insert(Ops.begin(), Tmp); + } else if (I->getOpcode() == Instruction::FMul && + cast<Instruction>(I->user_back())->getOpcode() == + Instruction::FAdd && + isa<ConstantFP>(Ops.back().Op) && + cast<ConstantFP>(Ops.back().Op)->isExactlyValue(-1.0)) { + ValueEntry Tmp = Ops.pop_back_val(); + Ops.insert(Ops.begin(), Tmp); + } + } + + DEBUG(dbgs() << "RAOut:\t"; PrintOps(I, Ops); dbgs() << '\n'); + + if (Ops.size() == 1) { + if (Ops[0].Op == I) + // Self-referential expression in unreachable code. + return; + + // This expression tree simplified to something that isn't a tree, + // eliminate it. + I->replaceAllUsesWith(Ops[0].Op); + if (Instruction *OI = dyn_cast<Instruction>(Ops[0].Op)) + OI->setDebugLoc(I->getDebugLoc()); + RedoInsts.insert(I); + return; + } + + if (Ops.size() > 2 && Ops.size() <= GlobalReassociateLimit) { + // Find the pair with the highest count in the pairmap and move it to the + // back of the list so that it can later be CSE'd. 
+ // example: + // a*b*c*d*e + // if c*e is the most "popular" pair, we can express this as + // (((c*e)*d)*b)*a + unsigned Max = 1; + unsigned BestRank = 0; + std::pair<unsigned, unsigned> BestPair; + unsigned Idx = I->getOpcode() - Instruction::BinaryOpsBegin; + for (unsigned i = 0; i < Ops.size() - 1; ++i) + for (unsigned j = i + 1; j < Ops.size(); ++j) { + unsigned Score = 0; + Value *Op0 = Ops[i].Op; + Value *Op1 = Ops[j].Op; + if (std::less<Value *>()(Op1, Op0)) + std::swap(Op0, Op1); + auto it = PairMap[Idx].find({Op0, Op1}); + if (it != PairMap[Idx].end()) + Score += it->second; + + unsigned MaxRank = std::max(Ops[i].Rank, Ops[j].Rank); + if (Score > Max || (Score == Max && MaxRank < BestRank)) { + BestPair = {i, j}; + Max = Score; + BestRank = MaxRank; + } + } + if (Max > 1) { + auto Op0 = Ops[BestPair.first]; + auto Op1 = Ops[BestPair.second]; + Ops.erase(&Ops[BestPair.second]); + Ops.erase(&Ops[BestPair.first]); + Ops.push_back(Op0); + Ops.push_back(Op1); + } + } + // Now that we ordered and optimized the expressions, splat them back into + // the expression tree, removing any unneeded nodes. + RewriteExprTree(I, Ops); +} + +void +ReassociatePass::BuildPairMap(ReversePostOrderTraversal<Function *> &RPOT) { + // Make a "pairmap" of how often each operand pair occurs. + for (BasicBlock *BI : RPOT) { + for (Instruction &I : *BI) { + if (!I.isAssociative()) + continue; + + // Ignore nodes that aren't at the root of trees. + if (I.hasOneUse() && I.user_back()->getOpcode() == I.getOpcode()) + continue; + + // Collect all operands in a single reassociable expression. + // Since Reassociate has already been run once, we can assume things + // are already canonical according to Reassociation's regime. + SmallVector<Value *, 8> Worklist = { I.getOperand(0), I.getOperand(1) }; + SmallVector<Value *, 8> Ops; + while (!Worklist.empty() && Ops.size() <= GlobalReassociateLimit) { + Value *Op = Worklist.pop_back_val(); + Instruction *OpI = dyn_cast<Instruction>(Op); + if (!OpI || OpI->getOpcode() != I.getOpcode() || !OpI->hasOneUse()) { + Ops.push_back(Op); + continue; + } + // Be paranoid about self-referencing expressions in unreachable code. + if (OpI->getOperand(0) != OpI) + Worklist.push_back(OpI->getOperand(0)); + if (OpI->getOperand(1) != OpI) + Worklist.push_back(OpI->getOperand(1)); + } + // Skip extremely long expressions. + if (Ops.size() > GlobalReassociateLimit) + continue; + + // Add all pairwise combinations of operands to the pair map. + unsigned BinaryIdx = I.getOpcode() - Instruction::BinaryOpsBegin; + SmallSet<std::pair<Value *, Value*>, 32> Visited; + for (unsigned i = 0; i < Ops.size() - 1; ++i) { + for (unsigned j = i + 1; j < Ops.size(); ++j) { + // Canonicalize operand orderings. + Value *Op0 = Ops[i]; + Value *Op1 = Ops[j]; + if (std::less<Value *>()(Op1, Op0)) + std::swap(Op0, Op1); + if (!Visited.insert({Op0, Op1}).second) + continue; + auto res = PairMap[BinaryIdx].insert({{Op0, Op1}, 1}); + if (!res.second) + ++res.first->second; + } + } + } + } +} + +PreservedAnalyses ReassociatePass::run(Function &F, FunctionAnalysisManager &) { + // Get the functions basic blocks in Reverse Post Order. This order is used by + // BuildRankMap to pre calculate ranks correctly. It also excludes dead basic + // blocks (it has been seen that the analysis in this pass could hang when + // analysing dead basic blocks). + ReversePostOrderTraversal<Function *> RPOT(&F); + + // Calculate the rank map for F. + BuildRankMap(F, RPOT); + + // Build the pair map before running reassociate. 
+  // Technically this would be more accurate if we did it after one round
+  // of reassociation, but in practice it doesn't seem to help much on
+  // real-world code, so don't waste the compile time running reassociate
+  // twice.
+  // If a user wants, they could explicitly run reassociate twice in their
+  // pass pipeline for further potential gains.
+  // It might also be possible to update the pair map during runtime, but the
+  // overhead of that may be large if there are many reassociable chains.
+  BuildPairMap(RPOT);
+
+  MadeChange = false;
+
+  // Traverse the same blocks that were analysed by BuildRankMap.
+  for (BasicBlock *BI : RPOT) {
+    assert(RankMap.count(&*BI) && "BB should be ranked.");
+    // Optimize every instruction in the basic block.
+    for (BasicBlock::iterator II = BI->begin(), IE = BI->end(); II != IE;)
+      if (isInstructionTriviallyDead(&*II)) {
+        EraseInst(&*II++);
+      } else {
+        OptimizeInst(&*II);
+        assert(II->getParent() == &*BI && "Moved to a different block!");
+        ++II;
+      }
+
+    // Make a copy of all the instructions to be redone so we can remove dead
+    // instructions.
+    SetVector<AssertingVH<Instruction>> ToRedo(RedoInsts);
+    // Iterate over all instructions to be reevaluated and remove trivially dead
+    // instructions. If any operand of the trivially dead instruction becomes
+    // dead mark it for deletion as well. Continue this process until all
+    // trivially dead instructions have been removed.
+    while (!ToRedo.empty()) {
+      Instruction *I = ToRedo.pop_back_val();
+      if (isInstructionTriviallyDead(I)) {
+        RecursivelyEraseDeadInsts(I, ToRedo);
+        MadeChange = true;
+      }
+    }
+
+    // Now that we have removed dead instructions, we can reoptimize the
+    // remaining instructions.
+    while (!RedoInsts.empty()) {
+      Instruction *I = RedoInsts.pop_back_val();
+      if (isInstructionTriviallyDead(I))
+        EraseInst(I);
+      else
+        OptimizeInst(I);
+    }
+  }
+
+  // We are done with the rank map and pair map.
+ RankMap.clear(); + ValueRankMap.clear(); + for (auto &Entry : PairMap) + Entry.clear(); + + if (MadeChange) { + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + PA.preserve<GlobalsAA>(); + return PA; + } + + return PreservedAnalyses::all(); +} + +namespace { + + class ReassociateLegacyPass : public FunctionPass { + ReassociatePass Impl; + + public: + static char ID; // Pass identification, replacement for typeid + + ReassociateLegacyPass() : FunctionPass(ID) { + initializeReassociateLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + FunctionAnalysisManager DummyFAM; + auto PA = Impl.run(F, DummyFAM); + return !PA.areAllPreserved(); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } + }; + +} // end anonymous namespace + +char ReassociateLegacyPass::ID = 0; + +INITIALIZE_PASS(ReassociateLegacyPass, "reassociate", + "Reassociate expressions", false, false) + +// Public interface to the Reassociate pass +FunctionPass *llvm::createReassociatePass() { + return new ReassociateLegacyPass(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/Reg2Mem.cpp b/contrib/llvm/lib/Transforms/Scalar/Reg2Mem.cpp new file mode 100644 index 000000000000..96295683314c --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/Reg2Mem.cpp @@ -0,0 +1,128 @@ +//===- Reg2Mem.cpp - Convert registers to allocas -------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file demotes all registers to memory references. It is intended to be +// the inverse of PromoteMemoryToRegister. By converting to loads, the only +// values live across basic blocks are allocas and loads before phi nodes. +// It is intended that this should make CFG hacking much easier. +// To make later hacking easier, the entry block is split into two, such that +// all introduced allocas and nothing else are in the entry block. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/Statistic.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <list>
+using namespace llvm;
+
+#define DEBUG_TYPE "reg2mem"
+
+STATISTIC(NumRegsDemoted, "Number of registers demoted");
+STATISTIC(NumPhisDemoted, "Number of phi-nodes demoted");
+
+namespace {
+  struct RegToMem : public FunctionPass {
+    static char ID; // Pass identification, replacement for typeid
+    RegToMem() : FunctionPass(ID) {
+      initializeRegToMemPass(*PassRegistry::getPassRegistry());
+    }
+
+    void getAnalysisUsage(AnalysisUsage &AU) const override {
+      AU.addRequiredID(BreakCriticalEdgesID);
+      AU.addPreservedID(BreakCriticalEdgesID);
+    }
+
+    bool valueEscapes(const Instruction *Inst) const {
+      const BasicBlock *BB = Inst->getParent();
+      for (const User *U : Inst->users()) {
+        const Instruction *UI = cast<Instruction>(U);
+        if (UI->getParent() != BB || isa<PHINode>(UI))
+          return true;
+      }
+      return false;
+    }
+
+    bool runOnFunction(Function &F) override;
+  };
+}
+
+char RegToMem::ID = 0;
+INITIALIZE_PASS_BEGIN(RegToMem, "reg2mem", "Demote all values to stack slots",
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(BreakCriticalEdges)
+INITIALIZE_PASS_END(RegToMem, "reg2mem", "Demote all values to stack slots",
+                    false, false)
+
+bool RegToMem::runOnFunction(Function &F) {
+  if (F.isDeclaration() || skipFunction(F))
+    return false;
+
+  // Insert all new allocas into entry block.
+  BasicBlock *BBEntry = &F.getEntryBlock();
+  assert(pred_empty(BBEntry) &&
+         "Entry block to function must not have predecessors!");
+
+  // Find the first non-alloca instruction and create an insertion point.
+  // This is safe if the block is well-formed: it always has a terminator,
+  // otherwise we'll trip an assertion.
+  BasicBlock::iterator I = BBEntry->begin();
+  while (isa<AllocaInst>(I)) ++I;
+
+  CastInst *AllocaInsertionPoint = new BitCastInst(
+      Constant::getNullValue(Type::getInt32Ty(F.getContext())),
+      Type::getInt32Ty(F.getContext()), "reg2mem alloca point", &*I);
+
+  // Find the escaped instructions. But don't create stack slots for
+  // allocas in entry block.
+  std::list<Instruction*> WorkList;
+  for (BasicBlock &ibb : F)
+    for (BasicBlock::iterator iib = ibb.begin(), iie = ibb.end(); iib != iie;
+         ++iib) {
+      if (!(isa<AllocaInst>(iib) && iib->getParent() == BBEntry) &&
+          valueEscapes(&*iib)) {
+        WorkList.push_front(&*iib);
+      }
+    }
+
+  // Demote escaped instructions
+  NumRegsDemoted += WorkList.size();
+  for (Instruction *ilb : WorkList)
+    DemoteRegToStack(*ilb, false, AllocaInsertionPoint);
+
+  WorkList.clear();
+
+  // Find all phi nodes.
+  for (BasicBlock &ibb : F)
+    for (BasicBlock::iterator iib = ibb.begin(), iie = ibb.end(); iib != iie;
+         ++iib)
+      if (isa<PHINode>(iib))
+        WorkList.push_front(&*iib);
+
+  // Demote phi nodes
+  NumPhisDemoted += WorkList.size();
+  for (Instruction *ilb : WorkList)
+    DemotePHIToStack(cast<PHINode>(ilb), AllocaInsertionPoint);
+
+  return true;
+}
+
+
+// createDemoteRegisterToMemory - Provide an entry point to create this pass.
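+// Illustrative use (a sketch, not part of this file): with the legacy pass
+// manager, the demotion can be scheduled as
+//   legacy::PassManager PM;
+//   PM.add(createDemoteRegisterToMemoryPass());
+//   PM.run(M);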
+char &llvm::DemoteRegisterToMemoryID = RegToMem::ID; +FunctionPass *llvm::createDemoteRegisterToMemoryPass() { + return new RegToMem(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp b/contrib/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp new file mode 100644 index 000000000000..c44edbed8ed9 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/RewriteStatepointsForGC.cpp @@ -0,0 +1,2830 @@ +//===- RewriteStatepointsForGC.cpp - Make GC relocations explicit ---------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Rewrite call/invoke instructions so as to make potential relocations +// performed by the garbage collector explicit in the IR. +// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/RewriteStatepointsForGC.h" + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/DenseSet.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/None.h" +#include "llvm/ADT/Optional.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/TargetLibraryInfo.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/Attributes.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/CallingConv.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/MDBuilder.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Statepoint.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/PromoteMemToReg.h" +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <iterator> +#include <set> +#include <string> +#include <utility> +#include <vector> + +#define DEBUG_TYPE "rewrite-statepoints-for-gc" + +using namespace llvm; + +// Print the liveset found at the insert location +static cl::opt<bool> PrintLiveSet("spp-print-liveset", cl::Hidden, + cl::init(false)); +static cl::opt<bool> PrintLiveSetSize("spp-print-liveset-size", cl::Hidden, + cl::init(false)); + +// Print out the base pointers for debugging +static cl::opt<bool> PrintBasePointers("spp-print-base-pointers", cl::Hidden, + cl::init(false)); + +// Cost threshold measuring when it is profitable to rematerialize value 
instead +// of relocating it +static cl::opt<unsigned> +RematerializationThreshold("spp-rematerialization-threshold", cl::Hidden, + cl::init(6)); + +#ifdef EXPENSIVE_CHECKS +static bool ClobberNonLive = true; +#else +static bool ClobberNonLive = false; +#endif + +static cl::opt<bool, true> ClobberNonLiveOverride("rs4gc-clobber-non-live", + cl::location(ClobberNonLive), + cl::Hidden); + +static cl::opt<bool> + AllowStatepointWithNoDeoptInfo("rs4gc-allow-statepoint-with-no-deopt-info", + cl::Hidden, cl::init(true)); + +/// The IR fed into RewriteStatepointsForGC may have had attributes and +/// metadata implying dereferenceability that are no longer valid/correct after +/// RewriteStatepointsForGC has run. This is because semantically, after +/// RewriteStatepointsForGC runs, all calls to gc.statepoint "free" the entire +/// heap. stripNonValidData (conservatively) restores +/// correctness by erasing all attributes in the module that externally imply +/// dereferenceability. Similar reasoning also applies to the noalias +/// attributes and metadata. gc.statepoint can touch the entire heap including +/// noalias objects. +/// Apart from attributes and metadata, we also remove instructions that imply +/// constant physical memory: llvm.invariant.start. +static void stripNonValidData(Module &M); + +static bool shouldRewriteStatepointsIn(Function &F); + +PreservedAnalyses RewriteStatepointsForGC::run(Module &M, + ModuleAnalysisManager &AM) { + bool Changed = false; + auto &FAM = AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager(); + for (Function &F : M) { + // Nothing to do for declarations. + if (F.isDeclaration() || F.empty()) + continue; + + // Policy choice says not to rewrite - the most common reason is that we're + // compiling code without a GCStrategy. + if (!shouldRewriteStatepointsIn(F)) + continue; + + auto &DT = FAM.getResult<DominatorTreeAnalysis>(F); + auto &TTI = FAM.getResult<TargetIRAnalysis>(F); + auto &TLI = FAM.getResult<TargetLibraryAnalysis>(F); + Changed |= runOnFunction(F, DT, TTI, TLI); + } + if (!Changed) + return PreservedAnalyses::all(); + + // stripNonValidData asserts that shouldRewriteStatepointsIn + // returns true for at least one function in the module. Since at least + // one function changed, we know that the precondition is satisfied. + stripNonValidData(M); + + PreservedAnalyses PA; + PA.preserve<TargetIRAnalysis>(); + PA.preserve<TargetLibraryAnalysis>(); + return PA; +} + +namespace { + +class RewriteStatepointsForGCLegacyPass : public ModulePass { + RewriteStatepointsForGC Impl; + +public: + static char ID; // Pass identification, replacement for typeid + + RewriteStatepointsForGCLegacyPass() : ModulePass(ID), Impl() { + initializeRewriteStatepointsForGCLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override { + bool Changed = false; + const TargetLibraryInfo &TLI = + getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + for (Function &F : M) { + // Nothing to do for declarations. + if (F.isDeclaration() || F.empty()) + continue; + + // Policy choice says not to rewrite - the most common reason is that + // we're compiling code without a GCStrategy. 
+      if (!shouldRewriteStatepointsIn(F))
+        continue;
+
+      TargetTransformInfo &TTI =
+          getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+      auto &DT = getAnalysis<DominatorTreeWrapperPass>(F).getDomTree();
+
+      Changed |= Impl.runOnFunction(F, DT, TTI, TLI);
+    }
+
+    if (!Changed)
+      return false;
+
+    // stripNonValidData asserts that shouldRewriteStatepointsIn
+    // returns true for at least one function in the module. Since at least
+    // one function changed, we know that the precondition is satisfied.
+    stripNonValidData(M);
+    return true;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    // We add and rewrite a bunch of instructions, but don't really do much
+    // else. We could in theory preserve a lot more analyses here.
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+  }
+};
+
+} // end anonymous namespace
+
+char RewriteStatepointsForGCLegacyPass::ID = 0;
+
+ModulePass *llvm::createRewriteStatepointsForGCLegacyPass() {
+  return new RewriteStatepointsForGCLegacyPass();
+}
+
+INITIALIZE_PASS_BEGIN(RewriteStatepointsForGCLegacyPass,
+                      "rewrite-statepoints-for-gc",
+                      "Make relocations explicit at statepoints", false, false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(RewriteStatepointsForGCLegacyPass,
+                    "rewrite-statepoints-for-gc",
+                    "Make relocations explicit at statepoints", false, false)
+
+namespace {
+
+struct GCPtrLivenessData {
+  /// Values defined in this block.
+  MapVector<BasicBlock *, SetVector<Value *>> KillSet;
+
+  /// Values used in this block (and thus live); does not include values
+  /// killed within this block.
+  MapVector<BasicBlock *, SetVector<Value *>> LiveSet;
+
+  /// Values live into this basic block (i.e. used by any
+  /// instruction in this basic block or ones reachable from here)
+  MapVector<BasicBlock *, SetVector<Value *>> LiveIn;
+
+  /// Values live out of this basic block (i.e. live into
+  /// any successor block)
+  MapVector<BasicBlock *, SetVector<Value *>> LiveOut;
+};
+
+// The type of the internal cache used inside the findBasePointers family
+// of functions. From the caller's perspective, this is an opaque type and
+// should not be inspected.
+//
+// In the actual implementation this caches two relations:
+// - The base relation itself (i.e. this pointer is based on that one)
+// - The base defining value relation (i.e. before base_phi insertion)
+// Generally, after the execution of a full findBasePointer call, only the
+// base relation will remain. Internally, we add a mixture of the two
+// types, then update all the second type to the first type.
+using DefiningValueMapTy = MapVector<Value *, Value *>;
+using StatepointLiveSetTy = SetVector<Value *>;
+using RematerializedValueMapTy =
+    MapVector<AssertingVH<Instruction>, AssertingVH<Value>>;
+
+struct PartiallyConstructedSafepointRecord {
+  /// The set of values known to be live across this safepoint
+  StatepointLiveSetTy LiveSet;
+
+  /// Mapping from live pointers to a base-defining-value
+  MapVector<Value *, Value *> PointerToBase;
+
+  /// The *new* gc.statepoint instruction itself. This produces the token
+  /// that normal path gc.relocates and the gc.result are tied to.
+  Instruction *StatepointToken;
+
+  /// Instruction to which exceptional gc relocates are attached.
+  /// Makes it easier to iterate through them during relocationViaAlloca.
+  Instruction *UnwindToken;
+
+  /// Record live values we rematerialized instead of relocating.
+  /// They are not included in the 'LiveSet' field.
+  /// Maps a rematerialized copy to its original value.
+  RematerializedValueMapTy RematerializedValues;
+};
+
+} // end anonymous namespace
+
+static ArrayRef<Use> GetDeoptBundleOperands(ImmutableCallSite CS) {
+  Optional<OperandBundleUse> DeoptBundle =
+      CS.getOperandBundle(LLVMContext::OB_deopt);
+
+  if (!DeoptBundle.hasValue()) {
+    assert(AllowStatepointWithNoDeoptInfo &&
+           "Found non-leaf call without deopt info!");
+    return None;
+  }
+
+  return DeoptBundle.getValue().Inputs;
+}
+
+/// Compute the live-in set for every basic block in the function
+static void computeLiveInValues(DominatorTree &DT, Function &F,
+                                GCPtrLivenessData &Data);
+
+/// Given results from the dataflow liveness computation, find the set of live
+/// Values at a particular instruction.
+static void findLiveSetAtInst(Instruction *inst, GCPtrLivenessData &Data,
+                              StatepointLiveSetTy &out);
+
+// TODO: Once we can get to the GCStrategy, this becomes
+// Optional<bool> isGCManagedPointer(const Type *Ty) const override {
+
+static bool isGCPointerType(Type *T) {
+  if (auto *PT = dyn_cast<PointerType>(T))
+    // For the sake of this example GC, we arbitrarily pick addrspace(1) as our
+    // GC managed heap. We know that a pointer into this heap needs to be
+    // updated and that no other pointer does.
+    return PT->getAddressSpace() == 1;
+  return false;
+}
+
+// Return true if this type is one which a) is a gc pointer or contains a GC
+// pointer and b) is of a type this code expects to encounter as a live value.
+// (The insertion code will assert that a type which matches (a) and not (b)
+// is not encountered.)
+static bool isHandledGCPointerType(Type *T) {
+  // We fully support gc pointers
+  if (isGCPointerType(T))
+    return true;
+  // We partially support vectors of gc pointers. The code will assert if it
+  // can't handle something.
+  if (auto VT = dyn_cast<VectorType>(T))
+    if (isGCPointerType(VT->getElementType()))
+      return true;
+  return false;
+}
+
+#ifndef NDEBUG
+/// Returns true if this type contains a gc pointer whether we know how to
+/// handle that type or not.
+static bool containsGCPtrType(Type *Ty) {
+  if (isGCPointerType(Ty))
+    return true;
+  if (VectorType *VT = dyn_cast<VectorType>(Ty))
+    return isGCPointerType(VT->getScalarType());
+  if (ArrayType *AT = dyn_cast<ArrayType>(Ty))
+    return containsGCPtrType(AT->getElementType());
+  if (StructType *ST = dyn_cast<StructType>(Ty))
+    return llvm::any_of(ST->subtypes(), containsGCPtrType);
+  return false;
+}
+
+// Returns true if this is a type which a) is a gc pointer or contains a GC
+// pointer and b) is of a type which the code doesn't expect (i.e. first class
+// aggregates). Used to trip assertions.
+static bool isUnhandledGCPointerType(Type *Ty) {
+  return containsGCPtrType(Ty) && !isHandledGCPointerType(Ty);
+}
+#endif
+
+// Return the name of the value suffixed with the provided suffix, or the
+// specified default name if the value is unnamed.
+static std::string suffixed_name_or(Value *V, StringRef Suffix,
+                                    StringRef DefaultName) {
+  return V->hasName() ? (V->getName() + Suffix).str() : DefaultName.str();
+}
+
+// Conservatively identifies any definitions which might be live at the
+// given instruction. The analysis is performed immediately before the
+// given instruction. Values defined by that instruction are not considered
+// live. Values used by that instruction are considered live.
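+// For example, a pointer defined before a safepoint call and used after it
+// is live at the call and ends up in the statepoint's live set.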
+static void
+analyzeParsePointLiveness(DominatorTree &DT,
+ GCPtrLivenessData &OriginalLivenessData, CallSite CS,
+ PartiallyConstructedSafepointRecord &Result) {
+ Instruction *Inst = CS.getInstruction();
+
+ StatepointLiveSetTy LiveSet;
+ findLiveSetAtInst(Inst, OriginalLivenessData, LiveSet);
+
+ if (PrintLiveSet) {
+ dbgs() << "Live Variables:\n";
+ for (Value *V : LiveSet)
+ dbgs() << " " << V->getName() << " " << *V << "\n";
+ }
+ if (PrintLiveSetSize) {
+ dbgs() << "Safepoint For: " << CS.getCalledValue()->getName() << "\n";
+ dbgs() << "Number live values: " << LiveSet.size() << "\n";
+ }
+ Result.LiveSet = LiveSet;
+}
+
+static bool isKnownBaseResult(Value *V);
+
+namespace {
+
+/// A single base defining value - An immediate base defining value for an
+/// instruction 'Def' is an input to 'Def' whose base is also a base of 'Def'.
+/// For instructions which have multiple pointer [vector] inputs or that
+/// transition between vector and scalar types, there is no immediate base
+/// defining value. The 'base defining value' for 'Def' is the transitive
+/// closure of this relation, stopping at the first instruction which has no
+/// immediate base defining value. The b.d.v. might itself be a base pointer,
+/// but it can also be an arbitrary derived pointer.
+struct BaseDefiningValueResult {
+ /// Contains the value which is the base defining value.
+ Value * const BDV;
+
+ /// True if the base defining value is also known to be an actual base
+ /// pointer.
+ const bool IsKnownBase;
+
+ BaseDefiningValueResult(Value *BDV, bool IsKnownBase)
+ : BDV(BDV), IsKnownBase(IsKnownBase) {
+#ifndef NDEBUG
+ // Check consistency between new and old means of checking whether a BDV is
+ // a base.
+ bool MustBeBase = isKnownBaseResult(BDV);
+ assert(!MustBeBase || MustBeBase == IsKnownBase);
+#endif
+ }
+};
+
+} // end anonymous namespace
+
+static BaseDefiningValueResult findBaseDefiningValue(Value *I);
+
+/// Return a base defining value for the given vector value 'I', i.e. a BDV
+/// for the entire vector. As an optimization, this method will try to
+/// determine when the value is known to already be a base pointer. If this
+/// can be established, the IsKnownBase field of the returned result will be
+/// true. Note that either a vector or a pointer typed value can be returned.
+/// For the former, the vector returned is a BDV (and possibly a base) of the
+/// entire vector 'I'. If the latter, the returned pointer is a BDV (or
+/// possibly a base) for the particular element in 'I'.
+static BaseDefiningValueResult
+findBaseDefiningValueOfVector(Value *I) {
+ // Each case parallels findBaseDefiningValue below, see that code for
+ // detailed motivation.
+
+ if (isa<Argument>(I))
+ // An incoming argument to the function is a base pointer
+ return BaseDefiningValueResult(I, true);
+
+ if (isa<Constant>(I))
+ // Base of constant vector consists only of constant null pointers.
+ // For reasoning see similar case inside 'findBaseDefiningValue' function.
+ return BaseDefiningValueResult(ConstantAggregateZero::get(I->getType()),
+ true);
+
+ if (isa<LoadInst>(I))
+ return BaseDefiningValueResult(I, true);
+
+ if (isa<InsertElementInst>(I))
+ // We don't know whether this vector contains entirely base pointers or
+ // not. To be conservatively correct, we treat it as a BDV and will
+ // duplicate code as needed to construct a parallel vector of bases.
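+ //
+ // For example (illustrative IR; %base and %derived are assumed names):
+ //   %v0 = insertelement <2 x i64 addrspace(1)*> undef,
+ //                       i64 addrspace(1)* %base, i32 0
+ //   %v1 = insertelement <2 x i64 addrspace(1)*> %v0,
+ //                       i64 addrspace(1)* %derived, i32 1
+ // Lanes of %v1 may mix base and derived pointers, so %v1 itself must be
+ // treated as the BDV.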
+ return BaseDefiningValueResult(I, false);
+
+ if (isa<ShuffleVectorInst>(I))
+ // We don't know whether this vector contains entirely base pointers or
+ // not. To be conservatively correct, we treat it as a BDV and will
+ // duplicate code as needed to construct a parallel vector of bases.
+ // TODO: There are a number of local optimizations which could be applied
+ // here for particular shufflevector patterns.
+ return BaseDefiningValueResult(I, false);
+
+ // The behavior of getelementptr instructions is the same for vector and
+ // non-vector data types.
+ if (auto *GEP = dyn_cast<GetElementPtrInst>(I))
+ return findBaseDefiningValue(GEP->getPointerOperand());
+
+ // If the pointer comes through a bitcast of a vector of pointers to
+ // a vector of another type of pointer, then look through the bitcast
+ if (auto *BC = dyn_cast<BitCastInst>(I))
+ return findBaseDefiningValue(BC->getOperand(0));
+
+ // A PHI or Select is a base defining value. The outer findBasePointer
+ // algorithm is responsible for constructing a base value for this BDV.
+ assert((isa<SelectInst>(I) || isa<PHINode>(I)) &&
+ "unknown vector instruction - no base found for vector element");
+ return BaseDefiningValueResult(I, false);
+}
+
+/// Helper function for findBasePointer - Will return a value which either a)
+/// defines the base pointer for the input, b) blocks the simple search
+/// (i.e. a PHI or Select of two derived pointers), or c) involves a change
+/// from pointer to vector type or back.
+static BaseDefiningValueResult findBaseDefiningValue(Value *I) {
+ assert(I->getType()->isPtrOrPtrVectorTy() &&
+ "Illegal to ask for the base pointer of a non-pointer type");
+
+ if (I->getType()->isVectorTy())
+ return findBaseDefiningValueOfVector(I);
+
+ if (isa<Argument>(I))
+ // An incoming argument to the function is a base pointer.
+ // We should never have reached here if this argument isn't a gc value.
+ return BaseDefiningValueResult(I, true);
+
+ if (isa<Constant>(I)) {
+ // We assume that objects with a constant base (e.g. a global) can't move
+ // and don't need to be reported to the collector because they are always
+ // live. Besides global references, all kinds of constants (e.g. undef,
+ // constant expressions, null pointers) can be introduced by the inliner or
+ // the optimizer, especially on dynamically dead paths.
+ // Here we treat all of them as having a single null base. By doing this we
+ // try to avoid problems reporting various conflicts in the form of
+ // "phi (const1, const2)" or "phi (const, regular gc ptr)".
+ // See constant.ll file for relevant test cases.
+
+ return BaseDefiningValueResult(
+ ConstantPointerNull::get(cast<PointerType>(I->getType())), true);
+ }
+
+ if (CastInst *CI = dyn_cast<CastInst>(I)) {
+ Value *Def = CI->stripPointerCasts();
+ // If stripping pointer casts changes the address space there is an
+ // addrspacecast in between.
+ assert(cast<PointerType>(Def->getType())->getAddressSpace() ==
+ cast<PointerType>(CI->getType())->getAddressSpace() &&
+ "unsupported addrspacecast");
+ // If we find a cast instruction here, it means we've found a cast which is
+ // not simply a pointer cast (i.e. an inttoptr). We don't know how to
+ // handle int->ptr conversion.
+ assert(!isa<CastInst>(Def) && "shouldn't find another cast here");
+ return findBaseDefiningValue(Def);
+ }
+
+ if (isa<LoadInst>(I))
+ // The value loaded is a gc base itself
+ return BaseDefiningValueResult(I, true);
+
+ if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I))
+ // The base of this GEP is the base
+ return findBaseDefiningValue(GEP->getPointerOperand());
+
+ if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
+ switch (II->getIntrinsicID()) {
+ default:
+ // fall through to general call handling
+ break;
+ case Intrinsic::experimental_gc_statepoint:
+ llvm_unreachable("statepoints don't produce pointers");
+ case Intrinsic::experimental_gc_relocate:
+ // Rerunning safepoint insertion after safepoints are already
+ // inserted is not supported. It could probably be made to work,
+ // but why are you doing this? There's no good reason.
+ llvm_unreachable("repeat safepoint insertion is not supported");
+ case Intrinsic::gcroot:
+ // Currently, this mechanism hasn't been extended to work with gcroot.
+ // There's no reason it couldn't be, but I haven't thought about the
+ // implications much.
+ llvm_unreachable(
+ "interaction with the gcroot mechanism is not supported");
+ }
+ }
+ // We assume that functions in the source language only return base
+ // pointers. This should probably be generalized via attributes to support
+ // both source language and internal functions.
+ if (isa<CallInst>(I) || isa<InvokeInst>(I))
+ return BaseDefiningValueResult(I, true);
+
+ // TODO: I have absolutely no idea how to implement this part yet. It's not
+ // necessarily hard, I just haven't really looked at it yet.
+ assert(!isa<LandingPadInst>(I) && "Landing Pad is unimplemented");
+
+ if (isa<AtomicCmpXchgInst>(I))
+ // A CAS is effectively an atomic store and load combined under a
+ // predicate. From the perspective of base pointers, we just treat it
+ // like a load.
+ return BaseDefiningValueResult(I, true);
+
+ assert(!isa<AtomicRMWInst>(I) && "Xchg handled above, all others are "
+ "binary ops which don't apply to pointers");
+
+ // The aggregate ops. Aggregates can either be in the heap or on the
+ // stack, but in either case, this is simply a field load. As a result,
+ // this is a defining definition of the base just like a load is.
+ if (isa<ExtractValueInst>(I))
+ return BaseDefiningValueResult(I, true);
+
+ // We should never see an insert vector since that would require we be
+ // tracing back a struct value not a pointer value.
+ assert(!isa<InsertValueInst>(I) &&
+ "Base pointer for a struct is meaningless");
+
+ // An extractelement produces a base result exactly when its input does.
+ // We may need to insert a parallel instruction to extract the appropriate
+ // element out of the base vector corresponding to the input. Given this,
+ // it's analogous to the phi and select case even though it's not a merge.
+ if (isa<ExtractElementInst>(I))
+ // Note: There are a lot of obvious peephole cases here. These are
+ // deliberately handled after the main base pointer inference algorithm to
+ // make writing test cases to exercise that code easier.
+ return BaseDefiningValueResult(I, false);
+
+ // The last two cases here don't return a base pointer. Instead, they
+ // return a value which dynamically selects from among several base
+ // derived pointers (each potentially with its own base). It's the job of
+ // the caller to resolve these.
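+ //
+ // For example (illustrative IR):
+ //   %sel = select i1 %c, i64 addrspace(1)* %d1, i64 addrspace(1)* %d2
+ // dynamically picks between the (possibly distinct) bases of %d1 and %d2;
+ // findBasePointer must materialize a parallel select over those bases.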
+ assert((isa<SelectInst>(I) || isa<PHINode>(I)) &&
+ "missing instruction case in findBaseDefiningValue");
+ return BaseDefiningValueResult(I, false);
+}
+
+/// Returns the base defining value for this value.
+static Value *findBaseDefiningValueCached(Value *I, DefiningValueMapTy &Cache) {
+ Value *&Cached = Cache[I];
+ if (!Cached) {
+ Cached = findBaseDefiningValue(I).BDV;
+ DEBUG(dbgs() << "fBDV-cached: " << I->getName() << " -> "
+ << Cached->getName() << "\n");
+ }
+ assert(Cache[I] != nullptr);
+ return Cached;
+}
+
+/// Return a base pointer for this value if known. Otherwise, return its
+/// base defining value.
+static Value *findBaseOrBDV(Value *I, DefiningValueMapTy &Cache) {
+ Value *Def = findBaseDefiningValueCached(I, Cache);
+ auto Found = Cache.find(Def);
+ if (Found != Cache.end()) {
+ // Either a base-of relation, or a self reference. Caller must check.
+ return Found->second;
+ }
+ // Only a BDV available
+ return Def;
+}
+
+/// Given the result of a call to findBaseDefiningValue, or findBaseOrBDV,
+/// is it known to be a base pointer? Or do we need to continue searching?
+static bool isKnownBaseResult(Value *V) {
+ if (!isa<PHINode>(V) && !isa<SelectInst>(V) &&
+ !isa<ExtractElementInst>(V) && !isa<InsertElementInst>(V) &&
+ !isa<ShuffleVectorInst>(V)) {
+ // no recursion possible
+ return true;
+ }
+ if (isa<Instruction>(V) &&
+ cast<Instruction>(V)->getMetadata("is_base_value")) {
+ // This is a previously inserted base phi or select. We know
+ // that this is a base value.
+ return true;
+ }
+
+ // We need to keep searching
+ return false;
+}
+
+namespace {
+
+/// Models the state of a single base defining value in the findBasePointer
+/// algorithm for determining where a new instruction is needed to propagate
+/// the base of this BDV.
+class BDVState {
+public:
+ enum Status { Unknown, Base, Conflict };
+
+ BDVState() : BaseValue(nullptr) {}
+
+ explicit BDVState(Status Status, Value *BaseValue = nullptr)
+ : Status(Status), BaseValue(BaseValue) {
+ assert(Status != Base || BaseValue);
+ }
+
+ explicit BDVState(Value *BaseValue) : Status(Base), BaseValue(BaseValue) {}
+
+ Status getStatus() const { return Status; }
+ Value *getBaseValue() const { return BaseValue; }
+
+ bool isBase() const { return getStatus() == Base; }
+ bool isUnknown() const { return getStatus() == Unknown; }
+ bool isConflict() const { return getStatus() == Conflict; }
+
+ bool operator==(const BDVState &Other) const {
+ return BaseValue == Other.BaseValue && Status == Other.Status;
+ }
+
+ bool operator!=(const BDVState &other) const { return !(*this == other); }
+
+ LLVM_DUMP_METHOD
+ void dump() const {
+ print(dbgs());
+ dbgs() << '\n';
+ }
+
+ void print(raw_ostream &OS) const {
+ switch (getStatus()) {
+ case Unknown:
+ OS << "U";
+ break;
+ case Base:
+ OS << "B";
+ break;
+ case Conflict:
+ OS << "C";
+ break;
+ }
+ OS << " (" << getBaseValue() << " - "
+ << (getBaseValue() ? getBaseValue()->getName() : "nullptr") << "): ";
+ }
+
+private:
+ Status Status = Unknown;
+ AssertingVH<Value> BaseValue; // Non-null only if Status == Base.
+};
+
+} // end anonymous namespace
+
+#ifndef NDEBUG
+static raw_ostream &operator<<(raw_ostream &OS, const BDVState &State) {
+ State.print(OS);
+ return OS;
+}
+#endif
+
+static BDVState meetBDVStateImpl(const BDVState &LHS, const BDVState &RHS) {
+ switch (LHS.getStatus()) {
+ case BDVState::Unknown:
+ return RHS;
+
+ case BDVState::Base:
+ assert(LHS.getBaseValue() && "can't be null");
+ if (RHS.isUnknown())
+ return LHS;
+
+ if (RHS.isBase()) {
+ if (LHS.getBaseValue() == RHS.getBaseValue()) {
+ assert(LHS == RHS && "equality broken!");
+ return LHS;
+ }
+ return BDVState(BDVState::Conflict);
+ }
+ assert(RHS.isConflict() && "only three states!");
+ return BDVState(BDVState::Conflict);
+
+ case BDVState::Conflict:
+ return LHS;
+ }
+ llvm_unreachable("only three states!");
+}
+
+// Values of type BDVState form a lattice, and this function implements the
+// meet operation.
+static BDVState meetBDVState(const BDVState &LHS, const BDVState &RHS) {
+ BDVState Result = meetBDVStateImpl(LHS, RHS);
+ assert(Result == meetBDVStateImpl(RHS, LHS) &&
+ "Math is wrong: meet does not commute!");
+ return Result;
+}
+
+/// For a given value or instruction, figure out what base ptr it's derived
+/// from. For gc objects, this is simply itself. On success, returns a value
+/// which is the base pointer. (This is reliable and can be used for
+/// relocation.) On failure, returns nullptr.
+static Value *findBasePointer(Value *I, DefiningValueMapTy &Cache) {
+ Value *Def = findBaseOrBDV(I, Cache);
+
+ if (isKnownBaseResult(Def))
+ return Def;
+
+ // Here's the rough algorithm:
+ // - For every SSA value, construct a mapping to either an actual base
+ // pointer or a PHI which obscures the base pointer.
+ // - Construct a mapping from PHI to unknown TOP state. Use an
+ // optimistic algorithm to propagate base pointer information. Lattice
+ // looks like:
+ // UNKNOWN
+ // b1 b2 b3 b4
+ // CONFLICT
+ // When the algorithm terminates, all PHIs will either have a single
+ // concrete base or be in a conflict state.
+ // - For every conflict, insert a dummy PHI node without arguments. Add
+ // these to the base[Instruction] = BasePtr mapping. For every
+ // non-conflict, add the actual base.
+ // - For every conflict, add arguments for the base[a] of each input
+ // argument.
+ //
+ // Note: A simpler form of this would be to add the conflict form of all
+ // PHIs without running the optimistic algorithm. This would be
+ // analogous to pessimistic data flow and would likely lead to an
+ // overall worse solution.
+
+#ifndef NDEBUG
+ auto isExpectedBDVType = [](Value *BDV) {
+ return isa<PHINode>(BDV) || isa<SelectInst>(BDV) ||
+ isa<ExtractElementInst>(BDV) || isa<InsertElementInst>(BDV) ||
+ isa<ShuffleVectorInst>(BDV);
+ };
+#endif
+
+ // Once populated, will contain a mapping from each potentially non-base BDV
+ // to a lattice value (described above) which corresponds to that BDV.
+ // We use the order of insertion (DFS over the def/use graph) to provide a
+ // stable deterministic ordering for visiting DenseMaps (which are unordered)
+ // below. This is important for deterministic compilation.
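+ //
+ // For example (illustrative IR):
+ //   %bdv = phi i64 addrspace(1)* [ %b1, %left ], [ %d2, %right ]
+ // starts in the Unknown state; meeting the state of %b1 (Base) and the
+ // state of %d2's BDV either keeps a single concrete base or drops to
+ // Conflict, in which case a new base phi must be inserted below.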
+ MapVector<Value *, BDVState> States;
+
+ // Recursively fill in all base defining values reachable from the initial
+ // one for which we don't already know a definite base value
+ /* scope */ {
+ SmallVector<Value*, 16> Worklist;
+ Worklist.push_back(Def);
+ States.insert({Def, BDVState()});
+ while (!Worklist.empty()) {
+ Value *Current = Worklist.pop_back_val();
+ assert(!isKnownBaseResult(Current) && "why did it get added?");
+
+ auto visitIncomingValue = [&](Value *InVal) {
+ Value *Base = findBaseOrBDV(InVal, Cache);
+ if (isKnownBaseResult(Base))
+ // Known bases won't need new instructions introduced and can be
+ // ignored safely
+ return;
+ assert(isExpectedBDVType(Base) && "the only non-base values "
+ "we see should be base defining values");
+ if (States.insert(std::make_pair(Base, BDVState())).second)
+ Worklist.push_back(Base);
+ };
+ if (PHINode *PN = dyn_cast<PHINode>(Current)) {
+ for (Value *InVal : PN->incoming_values())
+ visitIncomingValue(InVal);
+ } else if (SelectInst *SI = dyn_cast<SelectInst>(Current)) {
+ visitIncomingValue(SI->getTrueValue());
+ visitIncomingValue(SI->getFalseValue());
+ } else if (auto *EE = dyn_cast<ExtractElementInst>(Current)) {
+ visitIncomingValue(EE->getVectorOperand());
+ } else if (auto *IE = dyn_cast<InsertElementInst>(Current)) {
+ visitIncomingValue(IE->getOperand(0)); // vector operand
+ visitIncomingValue(IE->getOperand(1)); // scalar operand
+ } else if (auto *SV = dyn_cast<ShuffleVectorInst>(Current)) {
+ visitIncomingValue(SV->getOperand(0));
+ visitIncomingValue(SV->getOperand(1));
+ } else {
+ llvm_unreachable("Unimplemented instruction case");
+ }
+ }
+ }
+
+#ifndef NDEBUG
+ DEBUG(dbgs() << "States after initialization:\n");
+ for (auto Pair : States) {
+ DEBUG(dbgs() << " " << Pair.second << " for " << *Pair.first << "\n");
+ }
+#endif
+
+ // Return a phi state for a base defining value. We'll generate a new
+ // base state for known bases and expect to find a cached state otherwise.
+ auto getStateForBDV = [&](Value *baseValue) {
+ if (isKnownBaseResult(baseValue))
+ return BDVState(baseValue);
+ auto I = States.find(baseValue);
+ assert(I != States.end() && "lookup failed!");
+ return I->second;
+ };
+
+ bool Progress = true;
+ while (Progress) {
+#ifndef NDEBUG
+ const size_t OldSize = States.size();
+#endif
+ Progress = false;
+ // We're only changing values in this loop, thus safe to keep iterators.
+ // Since this is computing a fixed point, the order of visit does not
+ // affect the result. TODO: We could use a worklist here and make this run
+ // much faster.
+ for (auto Pair : States) {
+ Value *BDV = Pair.first;
+ assert(!isKnownBaseResult(BDV) && "why did it get added?");
+
+ // Given an input value for the current instruction, return a BDVState
+ // instance which represents the BDV of that value.
+ auto getStateForInput = [&](Value *V) mutable {
+ Value *BDV = findBaseOrBDV(V, Cache);
+ return getStateForBDV(BDV);
+ };
+
+ BDVState NewState;
+ if (SelectInst *SI = dyn_cast<SelectInst>(BDV)) {
+ NewState = meetBDVState(NewState, getStateForInput(SI->getTrueValue()));
+ NewState =
+ meetBDVState(NewState, getStateForInput(SI->getFalseValue()));
+ } else if (PHINode *PN = dyn_cast<PHINode>(BDV)) {
+ for (Value *Val : PN->incoming_values())
+ NewState = meetBDVState(NewState, getStateForInput(Val));
+ } else if (auto *EE = dyn_cast<ExtractElementInst>(BDV)) {
+ // The 'meet' for an extractelement is slightly trivial, but it's still
+ // useful in that it drives us to conflict if our input is.
+ NewState =
+ meetBDVState(NewState, getStateForInput(EE->getVectorOperand()));
+ } else if (auto *IE = dyn_cast<InsertElementInst>(BDV)) {
+ // Given the inherent type mismatch between the operands, this will
+ // *always* produce Conflict.
+ NewState = meetBDVState(NewState, getStateForInput(IE->getOperand(0)));
+ NewState = meetBDVState(NewState, getStateForInput(IE->getOperand(1)));
+ } else {
+ // The only instance this does not return a Conflict is when both the
+ // vector operands are the same vector.
+ auto *SV = cast<ShuffleVectorInst>(BDV);
+ NewState = meetBDVState(NewState, getStateForInput(SV->getOperand(0)));
+ NewState = meetBDVState(NewState, getStateForInput(SV->getOperand(1)));
+ }
+
+ BDVState OldState = States[BDV];
+ if (OldState != NewState) {
+ Progress = true;
+ States[BDV] = NewState;
+ }
+ }
+
+ assert(OldSize == States.size() &&
+ "fixed point shouldn't be adding any new nodes to state");
+ }
+
+#ifndef NDEBUG
+ DEBUG(dbgs() << "States after meet iteration:\n");
+ for (auto Pair : States) {
+ DEBUG(dbgs() << " " << Pair.second << " for " << *Pair.first << "\n");
+ }
+#endif
+
+ // Insert Phis for all conflicts
+ // TODO: adjust naming patterns to avoid this order of iteration dependency
+ for (auto Pair : States) {
+ Instruction *I = cast<Instruction>(Pair.first);
+ BDVState State = Pair.second;
+ assert(!isKnownBaseResult(I) && "why did it get added?");
+ assert(!State.isUnknown() && "Optimistic algorithm didn't complete!");
+
+ // extractelement instructions are a bit special in that we may need to
+ // insert an extract even when we know an exact base for the instruction.
+ // The problem is that we need to convert from a vector base to a scalar
+ // base for the particular index we're interested in.
+ if (State.isBase() && isa<ExtractElementInst>(I) &&
+ isa<VectorType>(State.getBaseValue()->getType())) {
+ auto *EE = cast<ExtractElementInst>(I);
+ // TODO: In many cases, the new instruction is just EE itself. We should
+ // exploit this, but can't do it here since it would break the invariant
+ // about the BDV not being known to be a base.
+ auto *BaseInst = ExtractElementInst::Create(
+ State.getBaseValue(), EE->getIndexOperand(), "base_ee", EE);
+ BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
+ States[I] = BDVState(BDVState::Base, BaseInst);
+ }
+
+ // Since we're joining a vector and scalar base, they can never be the
+ // same. As a result, we should always see insert element having reached
+ // the conflict state.
+ assert(!isa<InsertElementInst>(I) || State.isConflict());
+
+ if (!State.isConflict())
+ continue;
+
+ /// Create and insert a new instruction which will represent the base of
+ /// the given instruction 'I'.
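+ ///
+ /// For example (illustrative IR), a conflicted
+ ///   %bdv = phi i64 addrspace(1)* [ %d1, %a ], [ %d2, %b ]
+ /// gets a companion placeholder
+ ///   %bdv.base = phi i64 addrspace(1)* ...   ; tagged !is_base_value
+ /// whose incoming values are filled in with base pointers further below.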
+ auto MakeBaseInstPlaceholder = [](Instruction *I) -> Instruction* {
+ if (isa<PHINode>(I)) {
+ BasicBlock *BB = I->getParent();
+ int NumPreds = std::distance(pred_begin(BB), pred_end(BB));
+ assert(NumPreds > 0 && "how did we reach here");
+ std::string Name = suffixed_name_or(I, ".base", "base_phi");
+ return PHINode::Create(I->getType(), NumPreds, Name, I);
+ } else if (SelectInst *SI = dyn_cast<SelectInst>(I)) {
+ // The undef will be replaced later
+ UndefValue *Undef = UndefValue::get(SI->getType());
+ std::string Name = suffixed_name_or(I, ".base", "base_select");
+ return SelectInst::Create(SI->getCondition(), Undef, Undef, Name, SI);
+ } else if (auto *EE = dyn_cast<ExtractElementInst>(I)) {
+ UndefValue *Undef = UndefValue::get(EE->getVectorOperand()->getType());
+ std::string Name = suffixed_name_or(I, ".base", "base_ee");
+ return ExtractElementInst::Create(Undef, EE->getIndexOperand(), Name,
+ EE);
+ } else if (auto *IE = dyn_cast<InsertElementInst>(I)) {
+ UndefValue *VecUndef = UndefValue::get(IE->getOperand(0)->getType());
+ UndefValue *ScalarUndef = UndefValue::get(IE->getOperand(1)->getType());
+ std::string Name = suffixed_name_or(I, ".base", "base_ie");
+ return InsertElementInst::Create(VecUndef, ScalarUndef,
+ IE->getOperand(2), Name, IE);
+ } else {
+ auto *SV = cast<ShuffleVectorInst>(I);
+ UndefValue *VecUndef = UndefValue::get(SV->getOperand(0)->getType());
+ std::string Name = suffixed_name_or(I, ".base", "base_sv");
+ return new ShuffleVectorInst(VecUndef, VecUndef, SV->getOperand(2),
+ Name, SV);
+ }
+ };
+ Instruction *BaseInst = MakeBaseInstPlaceholder(I);
+ // Add metadata marking this as a base value
+ BaseInst->setMetadata("is_base_value", MDNode::get(I->getContext(), {}));
+ States[I] = BDVState(BDVState::Conflict, BaseInst);
+ }
+
+ // Returns an instruction which produces the base pointer for a given
+ // instruction. The instruction is assumed to be an input to one of the BDVs
+ // seen in the inference algorithm above. As such, we must either already
+ // know its base defining value is a base, or have inserted a new
+ // instruction to propagate the base of its BDV and have entered that newly
+ // introduced instruction into the state table. In either case, we are
+ // assured to be able to determine an instruction which produces its base
+ // pointer.
+ auto getBaseForInput = [&](Value *Input, Instruction *InsertPt) {
+ Value *BDV = findBaseOrBDV(Input, Cache);
+ Value *Base = nullptr;
+ if (isKnownBaseResult(BDV)) {
+ Base = BDV;
+ } else {
+ // Either conflict or base.
+ assert(States.count(BDV));
+ Base = States[BDV].getBaseValue();
+ }
+ assert(Base && "Can't be null");
+ // The cast is needed since base traversal may strip away bitcasts
+ if (Base->getType() != Input->getType() && InsertPt)
+ Base = new BitCastInst(Base, Input->getType(), "cast", InsertPt);
+ return Base;
+ };
+
+ // Fixup all the inputs of the new PHIs. Visit order needs to be
+ // deterministic and predictable because we're naming newly created
+ // instructions.
+ for (auto Pair : States) { + Instruction *BDV = cast<Instruction>(Pair.first); + BDVState State = Pair.second; + + assert(!isKnownBaseResult(BDV) && "why did it get added?"); + assert(!State.isUnknown() && "Optimistic algorithm didn't complete!"); + if (!State.isConflict()) + continue; + + if (PHINode *BasePHI = dyn_cast<PHINode>(State.getBaseValue())) { + PHINode *PN = cast<PHINode>(BDV); + unsigned NumPHIValues = PN->getNumIncomingValues(); + for (unsigned i = 0; i < NumPHIValues; i++) { + Value *InVal = PN->getIncomingValue(i); + BasicBlock *InBB = PN->getIncomingBlock(i); + + // If we've already seen InBB, add the same incoming value + // we added for it earlier. The IR verifier requires phi + // nodes with multiple entries from the same basic block + // to have the same incoming value for each of those + // entries. If we don't do this check here and basephi + // has a different type than base, we'll end up adding two + // bitcasts (and hence two distinct values) as incoming + // values for the same basic block. + + int BlockIndex = BasePHI->getBasicBlockIndex(InBB); + if (BlockIndex != -1) { + Value *OldBase = BasePHI->getIncomingValue(BlockIndex); + BasePHI->addIncoming(OldBase, InBB); + +#ifndef NDEBUG + Value *Base = getBaseForInput(InVal, nullptr); + // In essence this assert states: the only way two values + // incoming from the same basic block may be different is by + // being different bitcasts of the same value. A cleanup + // that remains TODO is changing findBaseOrBDV to return an + // llvm::Value of the correct type (and still remain pure). + // This will remove the need to add bitcasts. + assert(Base->stripPointerCasts() == OldBase->stripPointerCasts() && + "Sanity -- findBaseOrBDV should be pure!"); +#endif + continue; + } + + // Find the instruction which produces the base for each input. We may + // need to insert a bitcast in the incoming block. + // TODO: Need to split critical edges if insertion is needed + Value *Base = getBaseForInput(InVal, InBB->getTerminator()); + BasePHI->addIncoming(Base, InBB); + } + assert(BasePHI->getNumIncomingValues() == NumPHIValues); + } else if (SelectInst *BaseSI = + dyn_cast<SelectInst>(State.getBaseValue())) { + SelectInst *SI = cast<SelectInst>(BDV); + + // Find the instruction which produces the base for each input. + // We may need to insert a bitcast. + BaseSI->setTrueValue(getBaseForInput(SI->getTrueValue(), BaseSI)); + BaseSI->setFalseValue(getBaseForInput(SI->getFalseValue(), BaseSI)); + } else if (auto *BaseEE = + dyn_cast<ExtractElementInst>(State.getBaseValue())) { + Value *InVal = cast<ExtractElementInst>(BDV)->getVectorOperand(); + // Find the instruction which produces the base for each input. We may + // need to insert a bitcast. 
+ BaseEE->setOperand(0, getBaseForInput(InVal, BaseEE));
+ } else if (auto *BaseIE = dyn_cast<InsertElementInst>(State.getBaseValue())) {
+ auto *BdvIE = cast<InsertElementInst>(BDV);
+ auto UpdateOperand = [&](int OperandIdx) {
+ Value *InVal = BdvIE->getOperand(OperandIdx);
+ Value *Base = getBaseForInput(InVal, BaseIE);
+ BaseIE->setOperand(OperandIdx, Base);
+ };
+ UpdateOperand(0); // vector operand
+ UpdateOperand(1); // scalar operand
+ } else {
+ auto *BaseSV = cast<ShuffleVectorInst>(State.getBaseValue());
+ auto *BdvSV = cast<ShuffleVectorInst>(BDV);
+ auto UpdateOperand = [&](int OperandIdx) {
+ Value *InVal = BdvSV->getOperand(OperandIdx);
+ Value *Base = getBaseForInput(InVal, BaseSV);
+ BaseSV->setOperand(OperandIdx, Base);
+ };
+ UpdateOperand(0); // vector operand
+ UpdateOperand(1); // vector operand
+ }
+ }
+
+ // Cache all of our results so we can cheaply reuse them
+ // NOTE: This is actually two caches: one of the base defining value
+ // relation and one of the base pointer relation! FIXME
+ for (auto Pair : States) {
+ auto *BDV = Pair.first;
+ Value *Base = Pair.second.getBaseValue();
+ assert(BDV && Base);
+ assert(!isKnownBaseResult(BDV) && "why did it get added?");
+
+ DEBUG(dbgs() << "Updating base value cache"
+ << " for: " << BDV->getName() << " from: "
+ << (Cache.count(BDV) ? Cache[BDV]->getName().str() : "none")
+ << " to: " << Base->getName() << "\n");
+
+ if (Cache.count(BDV)) {
+ assert(isKnownBaseResult(Base) &&
+ "must be something we 'know' is a base pointer");
+ // Once we transition from the BDV relation being stored in the Cache to
+ // the base relation being stored, it must be stable
+ assert((!isKnownBaseResult(Cache[BDV]) || Cache[BDV] == Base) &&
+ "base relation should be stable");
+ }
+ Cache[BDV] = Base;
+ }
+ assert(Cache.count(Def));
+ return Cache[Def];
+}
+
+// For a set of live pointers (base and/or derived), identify the base
+// pointer of the object which they are derived from. This routine will
+// mutate the IR graph as needed to make the 'base' pointer live at the
+// definition site of 'derived'. This ensures that any use of 'derived' can
+// also use 'base'. This may involve the insertion of a number of
+// additional PHI nodes.
+//
+// preconditions: live is a set of pointer type Values
+//
+// side effects: may insert PHI nodes into the existing CFG, will preserve
+// CFG, will not remove or mutate any existing nodes
+//
+// post condition: PointerToBase contains one (derived, base) pair for every
+// pointer in live. Note that derived can be equal to base if the original
+// pointer was a base pointer.
+static void
+findBasePointers(const StatepointLiveSetTy &live,
+ MapVector<Value *, Value *> &PointerToBase,
+ DominatorTree *DT, DefiningValueMapTy &DVCache) {
+ for (Value *ptr : live) {
+ Value *base = findBasePointer(ptr, DVCache);
+ assert(base && "failed to find base pointer");
+ PointerToBase[ptr] = base;
+ assert((!isa<Instruction>(base) || !isa<Instruction>(ptr) ||
+ DT->dominates(cast<Instruction>(base)->getParent(),
+ cast<Instruction>(ptr)->getParent())) &&
+ "The base we found better dominate the derived pointer");
+ }
+}
+
+/// Find the required base pointers (and adjust the live set) for the given
+/// parse point.
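+///
+/// For example (illustrative IR), if the live set is { %derived } with
+///   %derived = getelementptr i32, i32 addrspace(1)* %obj, i64 4
+/// the resulting PointerToBase map will contain the pair (%derived, %obj).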
+static void findBasePointers(DominatorTree &DT, DefiningValueMapTy &DVCache,
+ CallSite CS,
+ PartiallyConstructedSafepointRecord &result) {
+ MapVector<Value *, Value *> PointerToBase;
+ findBasePointers(result.LiveSet, PointerToBase, &DT, DVCache);
+
+ if (PrintBasePointers) {
+ errs() << "Base Pairs (w/o Relocation):\n";
+ for (auto &Pair : PointerToBase) {
+ errs() << " derived ";
+ Pair.first->printAsOperand(errs(), false);
+ errs() << " base ";
+ Pair.second->printAsOperand(errs(), false);
+ errs() << "\n";
+ }
+ }
+
+ result.PointerToBase = PointerToBase;
+}
+
+/// Given an updated version of the dataflow liveness results, update the
+/// liveset and base pointer maps for the call site CS.
+static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData,
+ CallSite CS,
+ PartiallyConstructedSafepointRecord &result);
+
+static void recomputeLiveInValues(
+ Function &F, DominatorTree &DT, ArrayRef<CallSite> toUpdate,
+ MutableArrayRef<struct PartiallyConstructedSafepointRecord> records) {
+ // TODO-PERF: reuse the original liveness, then simply run the dataflow
+ // again. The old values are still live and will help it stabilize quickly.
+ GCPtrLivenessData RevisedLivenessData;
+ computeLiveInValues(DT, F, RevisedLivenessData);
+ for (size_t i = 0; i < records.size(); i++) {
+ struct PartiallyConstructedSafepointRecord &info = records[i];
+ recomputeLiveInValues(RevisedLivenessData, toUpdate[i], info);
+ }
+}
+
+// When inserting gc.relocate and gc.result calls, we need to ensure there are
+// no uses of the original value / return value between the gc.statepoint and
+// the gc.relocate / gc.result call. One case which can arise is a phi node
+// at the start of one of the successor blocks. We also need to be able to
+// insert the gc.relocates only on the path which goes through the statepoint.
+// We might need to split an edge to make this possible.
+static BasicBlock *
+normalizeForInvokeSafepoint(BasicBlock *BB, BasicBlock *InvokeParent,
+ DominatorTree &DT) {
+ BasicBlock *Ret = BB;
+ if (!BB->getUniquePredecessor())
+ Ret = SplitBlockPredecessors(BB, InvokeParent, "", &DT);
+
+ // Now that 'Ret' has a unique predecessor we can safely remove all phi
+ // nodes from it
+ FoldSingleEntryPHINodes(Ret);
+ assert(!isa<PHINode>(Ret->begin()) &&
+ "All PHI nodes should have been removed!");
+
+ // At this point, we can safely insert a gc.relocate or gc.result as the
+ // first instruction in Ret if needed.
+ return Ret;
+}
+
+// Create a new attribute set containing only attributes which can be
+// transferred from the original call to the safepoint.
+static AttributeList legalizeCallAttributes(AttributeList AL) {
+ if (AL.isEmpty())
+ return AL;
+
+ // Remove the readonly, readnone, and statepoint function attributes.
+ AttrBuilder FnAttrs = AL.getFnAttributes();
+ FnAttrs.removeAttribute(Attribute::ReadNone);
+ FnAttrs.removeAttribute(Attribute::ReadOnly);
+ for (Attribute A : AL.getFnAttributes()) {
+ if (isStatepointDirectiveAttr(A))
+ FnAttrs.remove(A);
+ }
+
+ // Just skip parameter and return attributes for now
+ LLVMContext &Ctx = AL.getContext();
+ return AttributeList::get(Ctx, AttributeList::FunctionIndex,
+ AttributeSet::get(Ctx, FnAttrs));
+}
+
+/// Helper function to place all gc relocates necessary for the given
+/// statepoint.
+/// Inputs:
+/// liveVariables - list of variables to be relocated.
+/// liveStart - index of the first live variable.
+/// basePtrs - base pointers.
+/// statepointToken - statepoint instruction to which relocates should be
+/// bound.
+/// Builder - LLVM IR builder to be used to construct new calls.
+static void CreateGCRelocates(ArrayRef<Value *> LiveVariables,
+ const int LiveStart,
+ ArrayRef<Value *> BasePtrs,
+ Instruction *StatepointToken,
+ IRBuilder<> Builder) {
+ if (LiveVariables.empty())
+ return;
+
+ auto FindIndex = [](ArrayRef<Value *> LiveVec, Value *Val) {
+ auto ValIt = llvm::find(LiveVec, Val);
+ assert(ValIt != LiveVec.end() && "Val not found in LiveVec!");
+ size_t Index = std::distance(LiveVec.begin(), ValIt);
+ assert(Index < LiveVec.size() && "Bug in std::find?");
+ return Index;
+ };
+ Module *M = StatepointToken->getModule();
+
+ // All gc_relocate are generated as i8 addrspace(1)* (or a vector type whose
+ // element type is i8 addrspace(1)*). We originally generated unique
+ // declarations for each pointer type, but this proved problematic because
+ // the intrinsic mangling code is incomplete and fragile. Since we're moving
+ // towards a single unified pointer type anyway, we can just cast everything
+ // to an i8* of the right address space. A bitcast is added later to convert
+ // gc_relocate to the actual value's type.
+ auto getGCRelocateDecl = [&] (Type *Ty) {
+ assert(isHandledGCPointerType(Ty));
+ auto AS = Ty->getScalarType()->getPointerAddressSpace();
+ Type *NewTy = Type::getInt8PtrTy(M->getContext(), AS);
+ if (auto *VT = dyn_cast<VectorType>(Ty))
+ NewTy = VectorType::get(NewTy, VT->getNumElements());
+ return Intrinsic::getDeclaration(M, Intrinsic::experimental_gc_relocate,
+ {NewTy});
+ };
+
+ // Lazily populated map from input types to the canonicalized form mentioned
+ // in the comment above. This should probably be cached somewhere more
+ // broadly.
+ DenseMap<Type*, Value*> TypeToDeclMap;
+
+ for (unsigned i = 0; i < LiveVariables.size(); i++) {
+ // Generate the gc.relocate call and save the result
+ Value *BaseIdx =
+ Builder.getInt32(LiveStart + FindIndex(LiveVariables, BasePtrs[i]));
+ Value *LiveIdx = Builder.getInt32(LiveStart + i);
+
+ Type *Ty = LiveVariables[i]->getType();
+ if (!TypeToDeclMap.count(Ty))
+ TypeToDeclMap[Ty] = getGCRelocateDecl(Ty);
+ Value *GCRelocateDecl = TypeToDeclMap[Ty];
+
+ // only specify a debug name if we can give a useful one
+ CallInst *Reloc = Builder.CreateCall(
+ GCRelocateDecl, {StatepointToken, BaseIdx, LiveIdx},
+ suffixed_name_or(LiveVariables[i], ".relocated", ""));
+ // Trick CodeGen into thinking there are lots of free registers at this
+ // fake call.
+ Reloc->setCallingConv(CallingConv::Cold);
+ }
+}
+
+namespace {
+
+/// This struct is used to defer RAUWs and `eraseFromParent`s. Using this
+/// avoids having to worry about keeping around dangling pointers to Values.
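+///
+/// Typical usage (as in makeStatepointExplicitImpl below): collect
+/// replacements while statepoints are rewritten, then apply them once the
+/// live sets are explicit, e.g.
+///   Replacements.push_back(DeferredReplacement::createRAUW(OldCall, Result));
+///   ...
+///   for (DeferredReplacement &R : Replacements)
+///     R.doReplacement();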
+class DeferredReplacement {
+ AssertingVH<Instruction> Old;
+ AssertingVH<Instruction> New;
+ bool IsDeoptimize = false;
+
+ DeferredReplacement() = default;
+
+public:
+ static DeferredReplacement createRAUW(Instruction *Old, Instruction *New) {
+ assert(Old != New && Old && New &&
+ "Cannot RAUW equal values or to / from null!");
+
+ DeferredReplacement D;
+ D.Old = Old;
+ D.New = New;
+ return D;
+ }
+
+ static DeferredReplacement createDelete(Instruction *ToErase) {
+ DeferredReplacement D;
+ D.Old = ToErase;
+ return D;
+ }
+
+ static DeferredReplacement createDeoptimizeReplacement(Instruction *Old) {
+#ifndef NDEBUG
+ auto *F = cast<CallInst>(Old)->getCalledFunction();
+ assert(F && F->getIntrinsicID() == Intrinsic::experimental_deoptimize &&
+ "Only way to construct a deoptimize deferred replacement");
+#endif
+ DeferredReplacement D;
+ D.Old = Old;
+ D.IsDeoptimize = true;
+ return D;
+ }
+
+ /// Does the task represented by this instance.
+ void doReplacement() {
+ Instruction *OldI = Old;
+ Instruction *NewI = New;
+
+ assert(OldI != NewI && "Disallowed at construction?!");
+ assert((!IsDeoptimize || !New) &&
+ "Deoptimize intrinsics are not replaced!");
+
+ Old = nullptr;
+ New = nullptr;
+
+ if (NewI)
+ OldI->replaceAllUsesWith(NewI);
+
+ if (IsDeoptimize) {
+ // Note: we've inserted instructions, so the call to llvm.deoptimize may
+ // not necessarily be followed by the matching return.
+ auto *RI = cast<ReturnInst>(OldI->getParent()->getTerminator());
+ new UnreachableInst(RI->getContext(), RI);
+ RI->eraseFromParent();
+ }
+
+ OldI->eraseFromParent();
+ }
+};
+
+} // end anonymous namespace
+
+static StringRef getDeoptLowering(CallSite CS) {
+ const char *DeoptLowering = "deopt-lowering";
+ if (CS.hasFnAttr(DeoptLowering)) {
+ // FIXME: CallSite has a *really* confusing interface around attributes
+ // with values.
+ const AttributeList &CSAS = CS.getAttributes();
+ if (CSAS.hasAttribute(AttributeList::FunctionIndex, DeoptLowering))
+ return CSAS.getAttribute(AttributeList::FunctionIndex, DeoptLowering)
+ .getValueAsString();
+ Function *F = CS.getCalledFunction();
+ assert(F && F->hasFnAttribute(DeoptLowering));
+ return F->getFnAttribute(DeoptLowering).getValueAsString();
+ }
+ return "live-through";
+}
+
+static void
+makeStatepointExplicitImpl(const CallSite CS, /* to replace */
+ const SmallVectorImpl<Value *> &BasePtrs,
+ const SmallVectorImpl<Value *> &LiveVariables,
+ PartiallyConstructedSafepointRecord &Result,
+ std::vector<DeferredReplacement> &Replacements) {
+ assert(BasePtrs.size() == LiveVariables.size());
+
+ // Then go ahead and use the builder to actually do the inserts. We insert
+ // immediately before the previous instruction under the assumption that all
+ // arguments will be available here. We can't insert afterwards since we may
+ // be replacing a terminator.
+ Instruction *InsertBefore = CS.getInstruction();
+ IRBuilder<> Builder(InsertBefore);
+
+ ArrayRef<Value *> GCArgs(LiveVariables);
+ uint64_t StatepointID = StatepointDirectives::DefaultStatepointID;
+ uint32_t NumPatchBytes = 0;
+ uint32_t Flags = uint32_t(StatepointFlags::None);
+
+ ArrayRef<Use> CallArgs(CS.arg_begin(), CS.arg_end());
+ ArrayRef<Use> DeoptArgs = GetDeoptBundleOperands(CS);
+ ArrayRef<Use> TransitionArgs;
+ if (auto TransitionBundle =
+ CS.getOperandBundle(LLVMContext::OB_gc_transition)) {
+ Flags |= uint32_t(StatepointFlags::GCTransition);
+ TransitionArgs = TransitionBundle->Inputs;
+ }
+
+ // Instead of lowering calls to @llvm.experimental.deoptimize as normal calls
+ // with a return value, we lower them as never returning calls to
+ // __llvm_deoptimize that are followed by unreachable to get better codegen.
+ bool IsDeoptimize = false;
+
+ StatepointDirectives SD =
+ parseStatepointDirectivesFromAttrs(CS.getAttributes());
+ if (SD.NumPatchBytes)
+ NumPatchBytes = *SD.NumPatchBytes;
+ if (SD.StatepointID)
+ StatepointID = *SD.StatepointID;
+
+ // Pass through the requested lowering if any. The default is live-through.
+ StringRef DeoptLowering = getDeoptLowering(CS);
+ if (DeoptLowering.equals("live-in"))
+ Flags |= uint32_t(StatepointFlags::DeoptLiveIn);
+ else {
+ assert(DeoptLowering.equals("live-through") && "Unsupported value!");
+ }
+
+ Value *CallTarget = CS.getCalledValue();
+ if (Function *F = dyn_cast<Function>(CallTarget)) {
+ if (F->getIntrinsicID() == Intrinsic::experimental_deoptimize) {
+ // Calls to llvm.experimental.deoptimize are lowered to calls to the
+ // __llvm_deoptimize symbol. We want to resolve this now, since the
+ // verifier does not allow taking the address of an intrinsic function.
+
+ SmallVector<Type *, 8> DomainTy;
+ for (Value *Arg : CallArgs)
+ DomainTy.push_back(Arg->getType());
+ auto *FTy = FunctionType::get(Type::getVoidTy(F->getContext()), DomainTy,
+ /* isVarArg = */ false);
+
+ // Note: CallTarget can be a bitcast instruction of a symbol if there are
+ // calls to @llvm.experimental.deoptimize with different argument types in
+ // the same module. This is fine -- we assume the frontend knew what it
+ // was doing when generating this kind of IR.
+ CallTarget =
+ F->getParent()->getOrInsertFunction("__llvm_deoptimize", FTy);
+
+ IsDeoptimize = true;
+ }
+ }
+
+ // Create the statepoint given all the arguments
+ Instruction *Token = nullptr;
+ if (CS.isCall()) {
+ CallInst *ToReplace = cast<CallInst>(CS.getInstruction());
+ CallInst *Call = Builder.CreateGCStatepointCall(
+ StatepointID, NumPatchBytes, CallTarget, Flags, CallArgs,
+ TransitionArgs, DeoptArgs, GCArgs, "safepoint_token");
+
+ Call->setTailCallKind(ToReplace->getTailCallKind());
+ Call->setCallingConv(ToReplace->getCallingConv());
+
+ // Currently we will fail on parameter attributes and on certain
+ // function attributes. To the extent we can handle this set of
+ // attributes, set up the function attrs directly on the statepoint;
+ // return attrs are attached later to the gc_result intrinsic.
+ Call->setAttributes(legalizeCallAttributes(ToReplace->getAttributes()));
+
+ Token = Call;
+
+ // Put the following gc_result and gc_relocate calls immediately after
+ // the old call (which we're about to delete)
+ assert(ToReplace->getNextNode() && "Not a terminator, must have next!");
+ Builder.SetInsertPoint(ToReplace->getNextNode());
+ Builder.SetCurrentDebugLocation(ToReplace->getNextNode()->getDebugLoc());
+ } else {
+ InvokeInst *ToReplace = cast<InvokeInst>(CS.getInstruction());
+
+ // Insert the new invoke into the old block. We'll remove the old one in a
+ // moment at which point this will become the new terminator for the
+ // original block.
+ InvokeInst *Invoke = Builder.CreateGCStatepointInvoke(
+ StatepointID, NumPatchBytes, CallTarget, ToReplace->getNormalDest(),
+ ToReplace->getUnwindDest(), Flags, CallArgs, TransitionArgs, DeoptArgs,
+ GCArgs, "statepoint_token");
+
+ Invoke->setCallingConv(ToReplace->getCallingConv());
+
+ // Currently we will fail on parameter attributes and on certain
+ // function attributes. To the extent we can handle this set of
+ // attributes, set up the function attrs directly on the statepoint;
+ // return attrs are attached later to the gc_result intrinsic.
+ Invoke->setAttributes(legalizeCallAttributes(ToReplace->getAttributes()));
+
+ Token = Invoke;
+
+ // Generate gc relocates in exceptional path
+ BasicBlock *UnwindBlock = ToReplace->getUnwindDest();
+ assert(!isa<PHINode>(UnwindBlock->begin()) &&
+ UnwindBlock->getUniquePredecessor() &&
+ "can't safely insert in this block!");
+
+ Builder.SetInsertPoint(&*UnwindBlock->getFirstInsertionPt());
+ Builder.SetCurrentDebugLocation(ToReplace->getDebugLoc());
+
+ // Attach exceptional gc relocates to the landingpad.
+ Instruction *ExceptionalToken = UnwindBlock->getLandingPadInst();
+ Result.UnwindToken = ExceptionalToken;
+
+ const unsigned LiveStartIdx = Statepoint(Token).gcArgsStartIdx();
+ CreateGCRelocates(LiveVariables, LiveStartIdx, BasePtrs, ExceptionalToken,
+ Builder);
+
+ // Generate gc relocates and returns for normal block
+ BasicBlock *NormalDest = ToReplace->getNormalDest();
+ assert(!isa<PHINode>(NormalDest->begin()) &&
+ NormalDest->getUniquePredecessor() &&
+ "can't safely insert in this block!");
+
+ Builder.SetInsertPoint(&*NormalDest->getFirstInsertionPt());
+
+ // gc relocates will be generated later as if it were regular call
+ // statepoint
+ }
+ assert(Token && "Should be set in one of the above branches!");
+
+ if (IsDeoptimize) {
+ // If we're wrapping an @llvm.experimental.deoptimize in a statepoint, we
+ // transform the tail-call like structure to a call to a void function
+ // followed by unreachable to get better codegen.
+ Replacements.push_back(
+ DeferredReplacement::createDeoptimizeReplacement(CS.getInstruction()));
+ } else {
+ Token->setName("statepoint_token");
+ if (!CS.getType()->isVoidTy() && !CS.getInstruction()->use_empty()) {
+ StringRef Name =
+ CS.getInstruction()->hasName() ? CS.getInstruction()->getName() : "";
+ CallInst *GCResult = Builder.CreateGCResult(Token, CS.getType(), Name);
+ GCResult->setAttributes(
+ AttributeList::get(GCResult->getContext(), AttributeList::ReturnIndex,
+ CS.getAttributes().getRetAttributes()));
+
+ // We cannot RAUW or delete CS.getInstruction() because it could be in the
+ // live set of some other safepoint, in which case that safepoint's
+ // PartiallyConstructedSafepointRecord will hold a raw pointer to this
+ // llvm::Instruction.
+ // Instead, we defer the replacement and deletion to
+ // after the live sets have been made explicit in the IR, and we no longer
+ // have raw pointers to worry about.
+ Replacements.emplace_back(
+ DeferredReplacement::createRAUW(CS.getInstruction(), GCResult));
+ } else {
+ Replacements.emplace_back(
+ DeferredReplacement::createDelete(CS.getInstruction()));
+ }
+ }
+
+ Result.StatepointToken = Token;
+
+ // Second, create a gc.relocate for every live variable
+ const unsigned LiveStartIdx = Statepoint(Token).gcArgsStartIdx();
+ CreateGCRelocates(LiveVariables, LiveStartIdx, BasePtrs, Token, Builder);
+}
+
+// Replace an existing gc.statepoint with a new one and a set of gc.relocates
+// which make the relocations happening at this safepoint explicit.
+//
+// WARNING: Does not do any fixup to adjust users of the original live
+// values. That's the caller's responsibility.
+static void
+makeStatepointExplicit(DominatorTree &DT, CallSite CS,
+ PartiallyConstructedSafepointRecord &Result,
+ std::vector<DeferredReplacement> &Replacements) {
+ const auto &LiveSet = Result.LiveSet;
+ const auto &PointerToBase = Result.PointerToBase;
+
+ // Convert to vector for efficient cross referencing.
+ SmallVector<Value *, 64> BaseVec, LiveVec;
+ LiveVec.reserve(LiveSet.size());
+ BaseVec.reserve(LiveSet.size());
+ for (Value *L : LiveSet) {
+ LiveVec.push_back(L);
+ assert(PointerToBase.count(L));
+ Value *Base = PointerToBase.find(L)->second;
+ BaseVec.push_back(Base);
+ }
+ assert(LiveVec.size() == BaseVec.size());
+
+ // Do the actual rewriting and delete the old statepoint
+ makeStatepointExplicitImpl(CS, BaseVec, LiveVec, Result, Replacements);
+}
+
+// Helper function for relocationViaAlloca.
+//
+// It receives an iterator range over the statepoint's gc relocates and emits a
+// store to the assigned location (via AllocaMap) for each one of them. It adds
+// the visited values into the VisitedLiveValues set, which we will later use
+// for sanity checking.
+static void
+insertRelocationStores(iterator_range<Value::user_iterator> GCRelocs,
+ DenseMap<Value *, Value *> &AllocaMap,
+ DenseSet<Value *> &VisitedLiveValues) {
+ for (User *U : GCRelocs) {
+ GCRelocateInst *Relocate = dyn_cast<GCRelocateInst>(U);
+ if (!Relocate)
+ continue;
+
+ Value *OriginalValue = Relocate->getDerivedPtr();
+ assert(AllocaMap.count(OriginalValue));
+ Value *Alloca = AllocaMap[OriginalValue];
+
+ // Emit store into the related alloca.
+ // All gc_relocates are i8 addrspace(1)* typed, and must be bitcast to
+ // the correct type according to the alloca.
+ assert(Relocate->getNextNode() &&
+ "Should always have one since it's not a terminator");
+ IRBuilder<> Builder(Relocate->getNextNode());
+ Value *CastedRelocatedValue =
+ Builder.CreateBitCast(Relocate,
+ cast<AllocaInst>(Alloca)->getAllocatedType(),
+ suffixed_name_or(Relocate, ".casted", ""));
+
+ StoreInst *Store = new StoreInst(CastedRelocatedValue, Alloca);
+ Store->insertAfter(cast<Instruction>(CastedRelocatedValue));
+
+#ifndef NDEBUG
+ VisitedLiveValues.insert(OriginalValue);
+#endif
+ }
+}
+
+// Helper function for relocationViaAlloca. Similar to insertRelocationStores
+// but works for rematerialized values.
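+//
+// For example (illustrative names), a rematerialized copy %derived.remat of
+// an original live value %derived is spilled with a store of %derived.remat
+// into the alloca assigned to %derived, right after the copy's definition,
+// mirroring what insertRelocationStores does for gc.relocate results.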
+static void insertRematerializationStores(
+ const RematerializedValueMapTy &RematerializedValues,
+ DenseMap<Value *, Value *> &AllocaMap,
+ DenseSet<Value *> &VisitedLiveValues) {
+ for (auto RematerializedValuePair: RematerializedValues) {
+ Instruction *RematerializedValue = RematerializedValuePair.first;
+ Value *OriginalValue = RematerializedValuePair.second;
+
+ assert(AllocaMap.count(OriginalValue) &&
+ "Cannot find alloca for rematerialized value");
+ Value *Alloca = AllocaMap[OriginalValue];
+
+ StoreInst *Store = new StoreInst(RematerializedValue, Alloca);
+ Store->insertAfter(RematerializedValue);
+
+#ifndef NDEBUG
+ VisitedLiveValues.insert(OriginalValue);
+#endif
+ }
+}
+
+/// Do all the relocation update via allocas and mem2reg
+static void relocationViaAlloca(
+ Function &F, DominatorTree &DT, ArrayRef<Value *> Live,
+ ArrayRef<PartiallyConstructedSafepointRecord> Records) {
+#ifndef NDEBUG
+ // record initial number of (static) allocas; we'll check we have the same
+ // number when we get done.
+ int InitialAllocaNum = 0;
+ for (Instruction &I : F.getEntryBlock())
+ if (isa<AllocaInst>(I))
+ InitialAllocaNum++;
+#endif
+
+ // TODO-PERF: change data structures, reserve
+ DenseMap<Value *, Value *> AllocaMap;
+ SmallVector<AllocaInst *, 200> PromotableAllocas;
+ // Used later to check that we have enough allocas to store all values
+ std::size_t NumRematerializedValues = 0;
+ PromotableAllocas.reserve(Live.size());
+
+ // Emit alloca for "LiveValue" and record it in "AllocaMap" and
+ // "PromotableAllocas"
+ const DataLayout &DL = F.getParent()->getDataLayout();
+ auto emitAllocaFor = [&](Value *LiveValue) {
+ AllocaInst *Alloca = new AllocaInst(LiveValue->getType(),
+ DL.getAllocaAddrSpace(), "",
+ F.getEntryBlock().getFirstNonPHI());
+ AllocaMap[LiveValue] = Alloca;
+ PromotableAllocas.push_back(Alloca);
+ };
+
+ // Emit alloca for each live gc pointer
+ for (Value *V : Live)
+ emitAllocaFor(V);
+
+ // Emit allocas for rematerialized values
+ for (const auto &Info : Records)
+ for (auto RematerializedValuePair : Info.RematerializedValues) {
+ Value *OriginalValue = RematerializedValuePair.second;
+ if (AllocaMap.count(OriginalValue) != 0)
+ continue;
+
+ emitAllocaFor(OriginalValue);
+ ++NumRematerializedValues;
+ }
+
+ // The next two loops are part of the same conceptual operation. We need to
+ // insert a store to the alloca after the original def and at each
+ // redefinition. We need to insert a load before each use. These are split
+ // into distinct loops for performance reasons.
+
+ // Update gc pointer after each statepoint: either store a relocated value or
+ // null (if no relocated value was found for this gc pointer and it is not a
+ // gc_result). This must happen before we update the statepoint with load of
+ // alloca otherwise we lose the link between statepoint and old def.
+ for (const auto &Info : Records) {
+ Value *Statepoint = Info.StatepointToken;
+
+ // This will be used for consistency check
+ DenseSet<Value *> VisitedLiveValues;
+
+ // Insert stores for normal statepoint gc relocates
+ insertRelocationStores(Statepoint->users(), AllocaMap, VisitedLiveValues);
+
+ // If it was an invoke statepoint,
+ // we will also insert stores for the exceptional path gc relocates.
+ if (isa<InvokeInst>(Statepoint)) {
+ insertRelocationStores(Info.UnwindToken->users(), AllocaMap,
+ VisitedLiveValues);
+ }
+
+ // Do the same for rematerialized values
+ insertRematerializationStores(Info.RematerializedValues, AllocaMap,
+ VisitedLiveValues);
+
+ if (ClobberNonLive) {
+ // As a debugging aid, pretend that an unrelocated pointer becomes null at
+ // the gc.statepoint. This will turn some subtle GC problems into
+ // slightly easier to debug SEGVs. Note that on large IR files with
+ // lots of gc.statepoints this is extremely costly both memory and time
+ // wise.
+ SmallVector<AllocaInst *, 64> ToClobber;
+ for (auto Pair : AllocaMap) {
+ Value *Def = Pair.first;
+ AllocaInst *Alloca = cast<AllocaInst>(Pair.second);
+
+ // This value was relocated
+ if (VisitedLiveValues.count(Def)) {
+ continue;
+ }
+ ToClobber.push_back(Alloca);
+ }
+
+ auto InsertClobbersAt = [&](Instruction *IP) {
+ for (auto *AI : ToClobber) {
+ auto PT = cast<PointerType>(AI->getAllocatedType());
+ Constant *CPN = ConstantPointerNull::get(PT);
+ StoreInst *Store = new StoreInst(CPN, AI);
+ Store->insertBefore(IP);
+ }
+ };
+
+ // Insert the clobbering stores. These may get intermixed with the
+ // gc.results and gc.relocates, but that's fine.
+ if (auto II = dyn_cast<InvokeInst>(Statepoint)) {
+ InsertClobbersAt(&*II->getNormalDest()->getFirstInsertionPt());
+ InsertClobbersAt(&*II->getUnwindDest()->getFirstInsertionPt());
+ } else {
+ InsertClobbersAt(cast<Instruction>(Statepoint)->getNextNode());
+ }
+ }
+ }
+
+ // Update uses with loads from the allocas and add a store for each
+ // gc_relocated value.
+ for (auto Pair : AllocaMap) {
+ Value *Def = Pair.first;
+ Value *Alloca = Pair.second;
+
+ // We pre-record the uses of allocas so that we don't have to worry about
+ // a later update that changes the user information.
+
+ SmallVector<Instruction *, 20> Uses;
+ // PERF: trade a linear scan for repeated reallocation
+ Uses.reserve(std::distance(Def->user_begin(), Def->user_end()));
+ for (User *U : Def->users()) {
+ if (!isa<ConstantExpr>(U)) {
+ // If the def has a ConstantExpr use, then the def is either a
+ // ConstantExpr use itself or null. In either case
+ // (recursively in the first, directly in the second), the oop
+ // it is ultimately dependent on is null and this particular
+ // use does not need to be fixed up.
+ Uses.push_back(cast<Instruction>(U));
+ }
+ }
+
+ std::sort(Uses.begin(), Uses.end());
+ auto Last = std::unique(Uses.begin(), Uses.end());
+ Uses.erase(Last, Uses.end());
+
+ for (Instruction *Use : Uses) {
+ if (isa<PHINode>(Use)) {
+ PHINode *Phi = cast<PHINode>(Use);
+ for (unsigned i = 0; i < Phi->getNumIncomingValues(); i++) {
+ if (Def == Phi->getIncomingValue(i)) {
+ LoadInst *Load = new LoadInst(
+ Alloca, "", Phi->getIncomingBlock(i)->getTerminator());
+ Phi->setIncomingValue(i, Load);
+ }
+ }
+ } else {
+ LoadInst *Load = new LoadInst(Alloca, "", Use);
+ Use->replaceUsesOfWith(Def, Load);
+ }
+ }
+
+ // Emit store for the initial gc value. Store must be inserted after load,
+ // otherwise store will be in alloca's use list and an extra load will be
+ // inserted before it.
+ StoreInst *Store = new StoreInst(Def, Alloca);
+ if (Instruction *Inst = dyn_cast<Instruction>(Def)) {
+ if (InvokeInst *Invoke = dyn_cast<InvokeInst>(Inst)) {
+ // InvokeInst is a TerminatorInst, so the store needs to be inserted
+ // into its normal destination block.
+        BasicBlock *NormalDest = Invoke->getNormalDest();
+        Store->insertBefore(NormalDest->getFirstNonPHI());
+      } else {
+        assert(!Inst->isTerminator() &&
+               "The only TerminatorInst that can produce a value is "
+               "InvokeInst which is handled above.");
+        Store->insertAfter(Inst);
+      }
+    } else {
+      assert(isa<Argument>(Def));
+      Store->insertAfter(cast<Instruction>(Alloca));
+    }
+  }
+
+  assert(PromotableAllocas.size() == Live.size() + NumRematerializedValues &&
+         "we must have one alloca for each live value");
+  if (!PromotableAllocas.empty()) {
+    // Apply mem2reg to promote the allocas to SSA
+    PromoteMemToReg(PromotableAllocas, DT);
+  }
+
+#ifndef NDEBUG
+  for (auto &I : F.getEntryBlock())
+    if (isa<AllocaInst>(I))
+      InitialAllocaNum--;
+  assert(InitialAllocaNum == 0 && "We must not introduce any extra allocas");
+#endif
+}
+
+/// Implement a unique function which doesn't require we sort the input
+/// vector. Doing so has the effect of changing the output of a couple of
+/// tests in ways which make them less useful in testing fused safepoints.
+template <typename T> static void unique_unsorted(SmallVectorImpl<T> &Vec) {
+  SmallSet<T, 8> Seen;
+  Vec.erase(remove_if(Vec, [&](const T &V) { return !Seen.insert(V).second; }),
+            Vec.end());
+}
+
+/// Insert holders so that each Value is obviously live through the entire
+/// lifetime of the call.
+static void insertUseHolderAfter(CallSite &CS, const ArrayRef<Value *> Values,
+                                 SmallVectorImpl<CallInst *> &Holders) {
+  if (Values.empty())
+    // No values to hold live, might as well not insert the empty holder
+    return;
+
+  Module *M = CS.getInstruction()->getModule();
+  // Use a dummy vararg function to actually hold the values live
+  Function *Func = cast<Function>(M->getOrInsertFunction(
+      "__tmp_use", FunctionType::get(Type::getVoidTy(M->getContext()), true)));
+  if (CS.isCall()) {
+    // For call safepoints insert dummy calls right after the safepoint
+    Holders.push_back(CallInst::Create(Func, Values, "",
+                                       &*++CS.getInstruction()->getIterator()));
+    return;
+  }
+  // For invoke safepoints insert dummy calls in both the normal and
+  // exceptional destination blocks
+  auto *II = cast<InvokeInst>(CS.getInstruction());
+  Holders.push_back(CallInst::Create(
+      Func, Values, "", &*II->getNormalDest()->getFirstInsertionPt()));
+  Holders.push_back(CallInst::Create(
+      Func, Values, "", &*II->getUnwindDest()->getFirstInsertionPt()));
+}
+
+static void findLiveReferences(
+    Function &F, DominatorTree &DT, ArrayRef<CallSite> toUpdate,
+    MutableArrayRef<struct PartiallyConstructedSafepointRecord> records) {
+  GCPtrLivenessData OriginalLivenessData;
+  computeLiveInValues(DT, F, OriginalLivenessData);
+  for (size_t i = 0; i < records.size(); i++) {
+    struct PartiallyConstructedSafepointRecord &info = records[i];
+    analyzeParsePointLiveness(DT, OriginalLivenessData, toUpdate[i], info);
+  }
+}
+
+// Helper function for "rematerializeLiveValues". It walks the use chain
+// starting from "CurrentValue" until it reaches the root of the chain, i.e.
+// the base or a value it cannot process. Only "simple" values are processed
+// (currently GEPs and noop casts). The returned root is examined by the
+// callers of findRematerializableChainToBasePointer. Fills the "ChainToBase"
+// array with all visited values.
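+// A hypothetical illustration (values invented for exposition): given
+//   %base = ...                                     ; some live gc pointer
+//   %gep1 = getelementptr i32, i32 addrspace(1)* %base, i64 1
+//   %gep2 = getelementptr i32, i32 addrspace(1)* %gep1, i64 2
+// a walk starting from %gep2 records [%gep2, %gep1] in ChainToBase and
+// returns %base as the root of the chain.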
+static Value* findRematerializableChainToBasePointer(
+  SmallVectorImpl<Instruction*> &ChainToBase,
+  Value *CurrentValue) {
+  if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(CurrentValue)) {
+    ChainToBase.push_back(GEP);
+    return findRematerializableChainToBasePointer(ChainToBase,
+                                                  GEP->getPointerOperand());
+  }
+
+  if (CastInst *CI = dyn_cast<CastInst>(CurrentValue)) {
+    if (!CI->isNoopCast(CI->getModule()->getDataLayout()))
+      return CI;
+
+    ChainToBase.push_back(CI);
+    return findRematerializableChainToBasePointer(ChainToBase,
+                                                  CI->getOperand(0));
+  }
+
+  // We have reached the root of the chain, which is either equal to the base
+  // or is the first unsupported value along the use chain.
+  return CurrentValue;
+}
+
+// Helper function for "rematerializeLiveValues". Compute the cost of the use
+// chain we are going to rematerialize.
+static unsigned
+chainToBasePointerCost(SmallVectorImpl<Instruction*> &Chain,
+                       TargetTransformInfo &TTI) {
+  unsigned Cost = 0;
+
+  for (Instruction *Instr : Chain) {
+    if (CastInst *CI = dyn_cast<CastInst>(Instr)) {
+      assert(CI->isNoopCast(CI->getModule()->getDataLayout()) &&
+             "non-noop cast found during rematerialization");
+
+      Type *SrcTy = CI->getOperand(0)->getType();
+      Cost += TTI.getCastInstrCost(CI->getOpcode(), CI->getType(), SrcTy, CI);
+
+    } else if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Instr)) {
+      // Cost of the address calculation
+      Type *ValTy = GEP->getSourceElementType();
+      Cost += TTI.getAddressComputationCost(ValTy);
+
+      // And the cost of the GEP itself
+      // TODO: Use TTI->getGEPCost here (it exists, but appears not to be
+      //       allowed for external usage)
+      if (!GEP->hasAllConstantIndices())
+        Cost += 2;
+
+    } else {
+      llvm_unreachable("unsupported instruction type during rematerialization");
+    }
+  }
+
+  return Cost;
+}
+
+static bool AreEquivalentPhiNodes(PHINode &OrigRootPhi, PHINode &AlternateRootPhi) {
+  unsigned PhiNum = OrigRootPhi.getNumIncomingValues();
+  if (PhiNum != AlternateRootPhi.getNumIncomingValues() ||
+      OrigRootPhi.getParent() != AlternateRootPhi.getParent())
+    return false;
+  // Map of incoming values and their corresponding basic blocks of
+  // OrigRootPhi.
+  SmallDenseMap<Value *, BasicBlock *, 8> CurrentIncomingValues;
+  for (unsigned i = 0; i < PhiNum; i++)
+    CurrentIncomingValues[OrigRootPhi.getIncomingValue(i)] =
+        OrigRootPhi.getIncomingBlock(i);
+
+  // Both current and base PHIs should have the same incoming values and
+  // the same basic blocks corresponding to the incoming values.
+  for (unsigned i = 0; i < PhiNum; i++) {
+    auto CIVI =
+        CurrentIncomingValues.find(AlternateRootPhi.getIncomingValue(i));
+    if (CIVI == CurrentIncomingValues.end())
+      return false;
+    BasicBlock *CurrentIncomingBB = CIVI->second;
+    if (CurrentIncomingBB != AlternateRootPhi.getIncomingBlock(i))
+      return false;
+  }
+  return true;
+}
+
+// From the statepoint live set pick values that are cheaper to recompute than
+// to relocate. Remove these values from the live set, rematerialize them after
+// the statepoint and record them in the "Info" structure. Note that, similar
+// to relocated values, we don't do any user adjustments here.
+static void rematerializeLiveValues(CallSite CS,
+                                    PartiallyConstructedSafepointRecord &Info,
+                                    TargetTransformInfo &TTI) {
+  const unsigned int ChainLengthThreshold = 10;
+
+  // Record values we are going to delete from this statepoint live set.
+  // We cannot do this in the following loop due to iterator invalidation.
+  SmallVector<Value *, 32> LiveValuesToBeDeleted;
+
+  for (Value *LiveValue : Info.LiveSet) {
+    // For each live pointer find its defining chain
+    SmallVector<Instruction *, 3> ChainToBase;
+    assert(Info.PointerToBase.count(LiveValue));
+    Value *RootOfChain =
+      findRematerializableChainToBasePointer(ChainToBase,
+                                             LiveValue);
+
+    // Nothing to do, or chain is too long
+    if (ChainToBase.size() == 0 ||
+        ChainToBase.size() > ChainLengthThreshold)
+      continue;
+
+    // Handle the scenario where the RootOfChain is not equal to the
+    // Base Value, but they are essentially the same phi values.
+    if (RootOfChain != Info.PointerToBase[LiveValue]) {
+      PHINode *OrigRootPhi = dyn_cast<PHINode>(RootOfChain);
+      PHINode *AlternateRootPhi = dyn_cast<PHINode>(Info.PointerToBase[LiveValue]);
+      if (!OrigRootPhi || !AlternateRootPhi)
+        continue;
+      // PHI nodes that have the same incoming values, and belong to the same
+      // basic block, are essentially the same SSA value. When the original phi
+      // has incoming values with different base pointers, the original phi is
+      // marked as conflict, and an additional `AlternateRootPhi` with the same
+      // incoming values gets generated by the findBasePointer function. We
+      // need to verify that the newly generated AlternateRootPhi (the .base
+      // version of the phi) and RootOfChain (the original phi node itself) are
+      // the same, so that we can rematerialize the gep and casts. This is a
+      // workaround for the deficiency in the findBasePointer algorithm.
+      if (!AreEquivalentPhiNodes(*OrigRootPhi, *AlternateRootPhi))
+        continue;
+      // Now that the phi nodes are proved to be the same, assert that
+      // findBasePointer's newly generated AlternateRootPhi is present in the
+      // liveset of the call.
+      assert(Info.LiveSet.count(AlternateRootPhi));
+    }
+    // Compute the cost of this chain
+    unsigned Cost = chainToBasePointerCost(ChainToBase, TTI);
+    // TODO: We can also account for cases when we will be able to remove some
+    //       of the rematerialized values by later optimization passes, i.e. if
+    //       we rematerialized several intersecting chains, or if the original
+    //       values don't have any uses besides this statepoint.
+
+    // For invokes we need to rematerialize each chain twice: for the normal
+    // and for the unwind basic blocks. Model this by multiplying the cost by
+    // two.
+    if (CS.isInvoke()) {
+      Cost *= 2;
+    }
+    // If it's too expensive, skip it
+    if (Cost >= RematerializationThreshold)
+      continue;
+
+    // Remove the value from the live set
+    LiveValuesToBeDeleted.push_back(LiveValue);
+
+    // Clone instructions and record them inside the "Info" structure
+
+    // Walk backwards to visit the top-most instructions first
+    std::reverse(ChainToBase.begin(), ChainToBase.end());
+
+    // Utility function which clones all instructions from "ChainToBase"
+    // and inserts them before "InsertBefore". Returns the rematerialized value
+    // which should be used after the statepoint.
+    auto rematerializeChain = [&ChainToBase](
+        Instruction *InsertBefore, Value *RootOfChain, Value *AlternateLiveBase) {
+      Instruction *LastClonedValue = nullptr;
+      Instruction *LastValue = nullptr;
+      for (Instruction *Instr : ChainToBase) {
+        // Only GEPs and casts are supported as we need to be careful not to
+        // introduce any new uses of pointers not in the liveset.
+        // Note that it's fine to introduce new uses of pointers which were
+        // otherwise not used after this statepoint.
+        assert(isa<GetElementPtrInst>(Instr) || isa<CastInst>(Instr));
+
+        Instruction *ClonedValue = Instr->clone();
+        ClonedValue->insertBefore(InsertBefore);
+        ClonedValue->setName(Instr->getName() + ".remat");
+
+        // If it is not the first instruction in the chain, then it uses the
+        // previously cloned value; update it to use the clone.
+        if (LastClonedValue) {
+          assert(LastValue);
+          ClonedValue->replaceUsesOfWith(LastValue, LastClonedValue);
+#ifndef NDEBUG
+          for (auto OpValue : ClonedValue->operand_values()) {
+            // Assert that the cloned instruction does not use any instructions
+            // from this chain other than LastClonedValue
+            assert(!is_contained(ChainToBase, OpValue) &&
+                   "incorrect use in rematerialization chain");
+            // Assert that the cloned instruction does not use the RootOfChain
+            // or the AlternateLiveBase.
+            assert(OpValue != RootOfChain && OpValue != AlternateLiveBase);
+          }
+#endif
+        } else {
+          // For the first instruction, replace the use of the unrelocated
+          // base, i.e. RootOfChain/OrigRootPhi, with the corresponding PHI
+          // present in the live set. They have been proved to be the same PHI
+          // nodes. Note that the *only* use of the RootOfChain in the
+          // ChainToBase list is the first Value in the list.
+          if (RootOfChain != AlternateLiveBase)
+            ClonedValue->replaceUsesOfWith(RootOfChain, AlternateLiveBase);
+        }
+
+        LastClonedValue = ClonedValue;
+        LastValue = Instr;
+      }
+      assert(LastClonedValue);
+      return LastClonedValue;
+    };
+
+    // Different cases for calls and invokes. For invokes we need to clone
+    // instructions on both the normal and unwind paths.
+    if (CS.isCall()) {
+      Instruction *InsertBefore = CS.getInstruction()->getNextNode();
+      assert(InsertBefore);
+      Instruction *RematerializedValue = rematerializeChain(
+          InsertBefore, RootOfChain, Info.PointerToBase[LiveValue]);
+      Info.RematerializedValues[RematerializedValue] = LiveValue;
+    } else {
+      InvokeInst *Invoke = cast<InvokeInst>(CS.getInstruction());
+
+      Instruction *NormalInsertBefore =
+          &*Invoke->getNormalDest()->getFirstInsertionPt();
+      Instruction *UnwindInsertBefore =
+          &*Invoke->getUnwindDest()->getFirstInsertionPt();
+
+      Instruction *NormalRematerializedValue = rematerializeChain(
+          NormalInsertBefore, RootOfChain, Info.PointerToBase[LiveValue]);
+      Instruction *UnwindRematerializedValue = rematerializeChain(
+          UnwindInsertBefore, RootOfChain, Info.PointerToBase[LiveValue]);
+
+      Info.RematerializedValues[NormalRematerializedValue] = LiveValue;
+      Info.RematerializedValues[UnwindRematerializedValue] = LiveValue;
+    }
+  }
+
+  // Remove rematerialized values from the live set
+  for (auto LiveValue : LiveValuesToBeDeleted) {
+    Info.LiveSet.remove(LiveValue);
+  }
+}
+
+static bool insertParsePoints(Function &F, DominatorTree &DT,
+                              TargetTransformInfo &TTI,
+                              SmallVectorImpl<CallSite> &ToUpdate) {
+#ifndef NDEBUG
+  // sanity check the input
+  std::set<CallSite> Uniqued;
+  Uniqued.insert(ToUpdate.begin(), ToUpdate.end());
+  assert(Uniqued.size() == ToUpdate.size() && "no duplicates please!");
+
+  for (CallSite CS : ToUpdate)
+    assert(CS.getInstruction()->getFunction() == &F);
+#endif
+
+  // When inserting gc.relocates for invokes, we need to be able to insert at
+  // the top of the successor blocks. See the comment on
+  // normalizeForInvokeSafepoint for exactly what is needed. Note that this
+  // step may restructure the CFG.
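+  // A rough sketch of the intent (see the helper's definition for the exact
+  // behavior): if an invoke's normal or unwind destination has other
+  // predecessors, the destination is split so that relocates can be placed at
+  // the front of a block reached only from this invoke; conceptually,
+  //   invoke ... to label %norm unwind label %lpad
+  // becomes
+  //   invoke ... to label %norm.split unwind label %lpad.split
+  // where each split block branches on to the original destination.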
+ for (CallSite CS : ToUpdate) { + if (!CS.isInvoke()) + continue; + auto *II = cast<InvokeInst>(CS.getInstruction()); + normalizeForInvokeSafepoint(II->getNormalDest(), II->getParent(), DT); + normalizeForInvokeSafepoint(II->getUnwindDest(), II->getParent(), DT); + } + + // A list of dummy calls added to the IR to keep various values obviously + // live in the IR. We'll remove all of these when done. + SmallVector<CallInst *, 64> Holders; + + // Insert a dummy call with all of the deopt operands we'll need for the + // actual safepoint insertion as arguments. This ensures reference operands + // in the deopt argument list are considered live through the safepoint (and + // thus makes sure they get relocated.) + for (CallSite CS : ToUpdate) { + SmallVector<Value *, 64> DeoptValues; + + for (Value *Arg : GetDeoptBundleOperands(CS)) { + assert(!isUnhandledGCPointerType(Arg->getType()) && + "support for FCA unimplemented"); + if (isHandledGCPointerType(Arg->getType())) + DeoptValues.push_back(Arg); + } + + insertUseHolderAfter(CS, DeoptValues, Holders); + } + + SmallVector<PartiallyConstructedSafepointRecord, 64> Records(ToUpdate.size()); + + // A) Identify all gc pointers which are statically live at the given call + // site. + findLiveReferences(F, DT, ToUpdate, Records); + + // B) Find the base pointers for each live pointer + /* scope for caching */ { + // Cache the 'defining value' relation used in the computation and + // insertion of base phis and selects. This ensures that we don't insert + // large numbers of duplicate base_phis. + DefiningValueMapTy DVCache; + + for (size_t i = 0; i < Records.size(); i++) { + PartiallyConstructedSafepointRecord &info = Records[i]; + findBasePointers(DT, DVCache, ToUpdate[i], info); + } + } // end of cache scope + + // The base phi insertion logic (for any safepoint) may have inserted new + // instructions which are now live at some safepoint. The simplest such + // example is: + // loop: + // phi a <-- will be a new base_phi here + // safepoint 1 <-- that needs to be live here + // gep a + 1 + // safepoint 2 + // br loop + // We insert some dummy calls after each safepoint to definitely hold live + // the base pointers which were identified for that safepoint. We'll then + // ask liveness for _every_ base inserted to see what is now live. Then we + // remove the dummy calls. + Holders.reserve(Holders.size() + Records.size()); + for (size_t i = 0; i < Records.size(); i++) { + PartiallyConstructedSafepointRecord &Info = Records[i]; + + SmallVector<Value *, 128> Bases; + for (auto Pair : Info.PointerToBase) + Bases.push_back(Pair.second); + + insertUseHolderAfter(ToUpdate[i], Bases, Holders); + } + + // By selecting base pointers, we've effectively inserted new uses. Thus, we + // need to rerun liveness. We may *also* have inserted new defs, but that's + // not the key issue. + recomputeLiveInValues(F, DT, ToUpdate, Records); + + if (PrintBasePointers) { + for (auto &Info : Records) { + errs() << "Base Pairs: (w/Relocation)\n"; + for (auto Pair : Info.PointerToBase) { + errs() << " derived "; + Pair.first->printAsOperand(errs(), false); + errs() << " base "; + Pair.second->printAsOperand(errs(), false); + errs() << "\n"; + } + } + } + + // It is possible that non-constant live variables have a constant base. For + // example, a GEP with a variable offset from a global. In this case we can + // remove it from the liveset. 
We already don't add constants to the liveset
+  // because we assume they won't move at runtime and the GC doesn't need to be
+  // informed about them. The same reasoning applies if the base is constant.
+  // Note that the relocation placement code relies on this filtering for
+  // correctness as it expects the base to be in the liveset, which isn't true
+  // if the base is constant.
+  for (auto &Info : Records)
+    for (auto &BasePair : Info.PointerToBase)
+      if (isa<Constant>(BasePair.second))
+        Info.LiveSet.remove(BasePair.first);
+
+  for (CallInst *CI : Holders)
+    CI->eraseFromParent();
+
+  Holders.clear();
+
+  // In order to reduce the live set of a statepoint we might choose to
+  // rematerialize some values instead of relocating them. This is purely an
+  // optimization and does not influence correctness.
+  for (size_t i = 0; i < Records.size(); i++)
+    rematerializeLiveValues(ToUpdate[i], Records[i], TTI);
+
+  // We need this to safely RAUW and delete call or invoke return values that
+  // may themselves be live over a statepoint. For details, please see usage in
+  // makeStatepointExplicitImpl.
+  std::vector<DeferredReplacement> Replacements;
+
+  // Now run through and replace the existing statepoints with new ones with
+  // the live variables listed. We do not yet update uses of the values being
+  // relocated. We have references to live variables that need to
+  // survive to the last iteration of this loop. (By construction, the
+  // previous statepoint cannot be a live variable, thus we can and do remove
+  // the old statepoint calls as we go.)
+  for (size_t i = 0; i < Records.size(); i++)
+    makeStatepointExplicit(DT, ToUpdate[i], Records[i], Replacements);
+
+  ToUpdate.clear(); // prevent accidental use of invalid CallSites
+
+  for (auto &PR : Replacements)
+    PR.doReplacement();
+
+  Replacements.clear();
+
+  for (auto &Info : Records) {
+    // These live sets may contain stale Value pointers, since we replaced
+    // calls with operand bundles with calls wrapped in gc.statepoint, and some
+    // of those calls may have been def'ing live gc pointers. Clear these out
+    // to avoid accidentally using them.
+    //
+    // TODO: We should create a separate data structure that does not contain
+    // these live sets, and migrate to using that data structure from this
+    // point onward.
+    Info.LiveSet.clear();
+    Info.PointerToBase.clear();
+  }
+
+  // Do all the fixups of the original live variables to their relocated selves
+  SmallVector<Value *, 128> Live;
+  for (size_t i = 0; i < Records.size(); i++) {
+    PartiallyConstructedSafepointRecord &Info = Records[i];
+
+    // We can't simply save the live set from the original insertion. One of
+    // the live values might be the result of a call which needs a safepoint.
+    // That Value* no longer exists and we need to use the new gc_result.
+    // Thankfully, the live set is embedded in the statepoint (and updated), so
+    // we just grab that.
+    Statepoint Statepoint(Info.StatepointToken);
+    Live.insert(Live.end(), Statepoint.gc_args_begin(),
+                Statepoint.gc_args_end());
+#ifndef NDEBUG
+    // Do some basic sanity checks on our liveness results before performing
+    // relocation. Relocation can and will turn mistakes in liveness results
+    // into nonsensical code which is much harder to debug.
+    // TODO: It would be nice to test consistency as well
+    assert(DT.isReachableFromEntry(Info.StatepointToken->getParent()) &&
+           "statepoint must be reachable or liveness is meaningless");
+    for (Value *V : Statepoint.gc_args()) {
+      if (!isa<Instruction>(V))
+        // Non-instruction values trivially dominate all possible uses
+        continue;
+      auto *LiveInst = cast<Instruction>(V);
+      assert(DT.isReachableFromEntry(LiveInst->getParent()) &&
+             "unreachable values should never be live");
+      assert(DT.dominates(LiveInst, Info.StatepointToken) &&
+             "basic SSA liveness expectation violated by liveness analysis");
+    }
+#endif
+  }
+  unique_unsorted(Live);
+
+#ifndef NDEBUG
+  // sanity check
+  for (auto *Ptr : Live)
+    assert(isHandledGCPointerType(Ptr->getType()) &&
+           "must be a gc pointer type");
+#endif
+
+  relocationViaAlloca(F, DT, Live, Records);
+  return !Records.empty();
+}
+
+// Handles both return values and arguments for Functions and CallSites.
+template <typename AttrHolder>
+static void RemoveNonValidAttrAtIndex(LLVMContext &Ctx, AttrHolder &AH,
+                                      unsigned Index) {
+  AttrBuilder R;
+  if (AH.getDereferenceableBytes(Index))
+    R.addAttribute(Attribute::get(Ctx, Attribute::Dereferenceable,
+                                  AH.getDereferenceableBytes(Index)));
+  if (AH.getDereferenceableOrNullBytes(Index))
+    R.addAttribute(Attribute::get(Ctx, Attribute::DereferenceableOrNull,
+                                  AH.getDereferenceableOrNullBytes(Index)));
+  if (AH.getAttributes().hasAttribute(Index, Attribute::NoAlias))
+    R.addAttribute(Attribute::NoAlias);
+
+  if (!R.empty())
+    AH.setAttributes(AH.getAttributes().removeAttributes(Ctx, Index, R));
+}
+
+static void stripNonValidAttributesFromPrototype(Function &F) {
+  LLVMContext &Ctx = F.getContext();
+
+  for (Argument &A : F.args())
+    if (isa<PointerType>(A.getType()))
+      RemoveNonValidAttrAtIndex(Ctx, F,
+                                A.getArgNo() + AttributeList::FirstArgIndex);
+
+  if (isa<PointerType>(F.getReturnType()))
+    RemoveNonValidAttrAtIndex(Ctx, F, AttributeList::ReturnIndex);
+}
+
+/// Certain metadata on instructions are invalid after running RS4GC.
+/// Optimizations that run after RS4GC can incorrectly use this metadata to
+/// optimize functions. We drop such metadata on the instruction.
+static void stripInvalidMetadataFromInstruction(Instruction &I) {
+  if (!isa<LoadInst>(I) && !isa<StoreInst>(I))
+    return;
+  // These are the attributes that are still valid on loads and stores after
+  // RS4GC.
+  // The metadata implying dereferenceability and noalias are (conservatively)
+  // dropped. This is because semantically, after RewriteStatepointsForGC runs,
+  // all calls to gc.statepoint "free" the entire heap. Also, gc.statepoint can
+  // touch the entire heap including noalias objects. Note: The reasoning is
+  // the same as for stripping the dereferenceability and noalias attributes
+  // that are analogous to the metadata counterparts.
+  // We also drop the invariant.load metadata on the load because that metadata
+  // implies the address operand to the load points to memory that is never
+  // changed once it became dereferenceable. This is no longer true after
+  // RS4GC. Similar reasoning applies to invariant.group metadata, which
+  // applies to loads within a group.
+  unsigned ValidMetadataAfterRS4GC[] = {LLVMContext::MD_tbaa,
+                                        LLVMContext::MD_range,
+                                        LLVMContext::MD_alias_scope,
+                                        LLVMContext::MD_nontemporal,
+                                        LLVMContext::MD_nonnull,
+                                        LLVMContext::MD_align,
+                                        LLVMContext::MD_type};
+
+  // Drops all metadata on the instruction other than ValidMetadataAfterRS4GC.
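+  // For example (illustrative IR), a load such as
+  //   %v = load i32, i32* %p, !invariant.load !0, !tbaa !1
+  // keeps its !tbaa node (it is in the list above) but loses !invariant.load,
+  // since after RS4GC a gc.statepoint may conceptually rewrite any heap
+  // location.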
+  I.dropUnknownNonDebugMetadata(ValidMetadataAfterRS4GC);
+}
+
+static void stripNonValidDataFromBody(Function &F) {
+  if (F.empty())
+    return;
+
+  LLVMContext &Ctx = F.getContext();
+  MDBuilder Builder(Ctx);
+
+  // Set of invariant.start instructions that we need to remove.
+  // Use this to avoid invalidating the instruction iterator.
+  SmallVector<IntrinsicInst*, 12> InvariantStartInstructions;
+
+  for (Instruction &I : instructions(F)) {
+    // invariant.start on a memory location implies that the referenced memory
+    // location is constant and unchanging. This is no longer true after
+    // RewriteStatepointsForGC runs because there can be calls to gc.statepoint
+    // which free the entire heap, and the presence of invariant.start allows
+    // the optimizer to sink the load of a memory location past a statepoint,
+    // which is incorrect.
+    if (auto *II = dyn_cast<IntrinsicInst>(&I))
+      if (II->getIntrinsicID() == Intrinsic::invariant_start) {
+        InvariantStartInstructions.push_back(II);
+        continue;
+      }
+
+    if (const MDNode *MD = I.getMetadata(LLVMContext::MD_tbaa)) {
+      assert(MD->getNumOperands() < 5 && "unrecognized metadata shape!");
+      bool IsImmutableTBAA =
+          MD->getNumOperands() == 4 &&
+          mdconst::extract<ConstantInt>(MD->getOperand(3))->getValue() == 1;
+
+      if (!IsImmutableTBAA)
+        continue; // no work to do, MD_tbaa is already marked mutable
+
+      MDNode *Base = cast<MDNode>(MD->getOperand(0));
+      MDNode *Access = cast<MDNode>(MD->getOperand(1));
+      uint64_t Offset =
+          mdconst::extract<ConstantInt>(MD->getOperand(2))->getZExtValue();
+
+      MDNode *MutableTBAA =
+          Builder.createTBAAStructTagNode(Base, Access, Offset);
+      I.setMetadata(LLVMContext::MD_tbaa, MutableTBAA);
+    }
+
+    stripInvalidMetadataFromInstruction(I);
+
+    if (CallSite CS = CallSite(&I)) {
+      for (int i = 0, e = CS.arg_size(); i != e; i++)
+        if (isa<PointerType>(CS.getArgument(i)->getType()))
+          RemoveNonValidAttrAtIndex(Ctx, CS, i + AttributeList::FirstArgIndex);
+      if (isa<PointerType>(CS.getType()))
+        RemoveNonValidAttrAtIndex(Ctx, CS, AttributeList::ReturnIndex);
+    }
+  }
+
+  // Delete the invariant.start instructions and RAUW undef.
+  for (auto *II : InvariantStartInstructions) {
+    II->replaceAllUsesWith(UndefValue::get(II->getType()));
+    II->eraseFromParent();
+  }
+}
+
+/// Returns true if this function should be rewritten by this pass. The main
+/// point of this function is as an extension point for custom logic.
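+/// For example, a function defined as
+///   define void @f() gc "statepoint-example" { ... }
+/// is rewritten, while a function with no "gc" attribute (or an unrelated
+/// strategy name) is left untouched.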
+static bool shouldRewriteStatepointsIn(Function &F) {
+  // TODO: This should check the GCStrategy
+  if (F.hasGC()) {
+    const auto &FunctionGCName = F.getGC();
+    const StringRef StatepointExampleName("statepoint-example");
+    const StringRef CoreCLRName("coreclr");
+    return (StatepointExampleName == FunctionGCName) ||
+           (CoreCLRName == FunctionGCName);
+  } else
+    return false;
+}
+
+static void stripNonValidData(Module &M) {
+#ifndef NDEBUG
+  assert(llvm::any_of(M, shouldRewriteStatepointsIn) && "precondition!");
+#endif
+
+  for (Function &F : M)
+    stripNonValidAttributesFromPrototype(F);
+
+  for (Function &F : M)
+    stripNonValidDataFromBody(F);
+}
+
+bool RewriteStatepointsForGC::runOnFunction(Function &F, DominatorTree &DT,
+                                            TargetTransformInfo &TTI,
+                                            const TargetLibraryInfo &TLI) {
+  assert(!F.isDeclaration() && !F.empty() &&
+         "need function body to rewrite statepoints in");
+  assert(shouldRewriteStatepointsIn(F) && "mismatch in rewrite decision");
+
+  auto NeedsRewrite = [&TLI](Instruction &I) {
+    if (ImmutableCallSite CS = ImmutableCallSite(&I))
+      return !callsGCLeafFunction(CS, TLI) && !isStatepoint(CS);
+    return false;
+  };
+
+  // Gather all the statepoints which need to be rewritten. Be careful to only
+  // consider those in reachable code since we need to ask dominance queries
+  // when rewriting. We'll delete the unreachable ones in a moment.
+  SmallVector<CallSite, 64> ParsePointNeeded;
+  bool HasUnreachableStatepoint = false;
+  for (Instruction &I : instructions(F)) {
+    // TODO: only the ones with the flag set!
+    if (NeedsRewrite(I)) {
+      if (DT.isReachableFromEntry(I.getParent()))
+        ParsePointNeeded.push_back(CallSite(&I));
+      else
+        HasUnreachableStatepoint = true;
+    }
+  }
+
+  bool MadeChange = false;
+
+  // Delete any unreachable statepoints so that we don't have unrewritten
+  // statepoints surviving this pass. This makes testing easier and the
+  // resulting IR less confusing to human readers. Rather than be fancy, we
+  // just reuse a utility function which removes the unreachable blocks.
+  if (HasUnreachableStatepoint)
+    MadeChange |= removeUnreachableBlocks(F);
+
+  // Return early if there is no work to do.
+  if (ParsePointNeeded.empty())
+    return MadeChange;
+
+  // As a prepass, go ahead and aggressively destroy single-entry phi nodes.
+  // These are created by LCSSA. They have the effect of increasing the size
+  // of liveness sets for no good reason. It may be harder to do this post
+  // insertion since relocations and base phis can confuse things.
+  for (BasicBlock &BB : F)
+    if (BB.getUniquePredecessor()) {
+      MadeChange = true;
+      FoldSingleEntryPHINodes(&BB);
+    }
+
+  // Before we start introducing relocations, we want to tweak the IR a bit to
+  // avoid unfortunate code generation effects. The main example is that we
+  // want to try to make sure the comparison feeding a branch is after any
+  // safepoints. Otherwise, we end up with a comparison of pre-relocation
+  // values feeding a branch after relocation. This is semantically correct,
+  // but results in extra register pressure since both the pre-relocation and
+  // post-relocation copies must be available in registers. For code without
+  // relocations this is handled elsewhere, but teaching the scheduler to
+  // reverse the transform we're about to do would be slightly complex.
+  // Note: This may extend the live range of the inputs to the icmp and thus
+  // increase the liveset of any statepoint we move over. This is profitable
+  // as long as all statepoints are in rare blocks.
If we had in-register
+  // lowering for live values this would be a much safer transform.
+  auto getConditionInst = [](TerminatorInst *TI) -> Instruction* {
+    if (auto *BI = dyn_cast<BranchInst>(TI))
+      if (BI->isConditional())
+        return dyn_cast<Instruction>(BI->getCondition());
+    // TODO: Extend this to handle switches
+    return nullptr;
+  };
+  for (BasicBlock &BB : F) {
+    TerminatorInst *TI = BB.getTerminator();
+    if (auto *Cond = getConditionInst(TI))
+      // TODO: Handle more than just ICmps here. We should be able to move
+      // most instructions without side effects or memory access.
+      if (isa<ICmpInst>(Cond) && Cond->hasOneUse()) {
+        MadeChange = true;
+        Cond->moveBefore(TI);
+      }
+  }
+
+  MadeChange |= insertParsePoints(F, DT, TTI, ParsePointNeeded);
+  return MadeChange;
+}
+
+// liveness computation via standard dataflow
+// -------------------------------------------------------------------
+
+// TODO: Consider using bitvectors for liveness, the set of potentially
+// interesting values should be small and easy to pre-compute.
+
+/// Compute the live-in set for the range [Begin, End) starting from
+/// the live-out set of the basic block
+static void computeLiveInValues(BasicBlock::reverse_iterator Begin,
+                                BasicBlock::reverse_iterator End,
+                                SetVector<Value *> &LiveTmp) {
+  for (auto &I : make_range(Begin, End)) {
+    // KILL/Def - Remove this definition from LiveIn
+    LiveTmp.remove(&I);
+
+    // Don't consider *uses* in PHI nodes, we handle their contribution to
+    // predecessor blocks when we seed the LiveOut sets
+    if (isa<PHINode>(I))
+      continue;
+
+    // USE - Add to the LiveIn set for this instruction
+    for (Value *V : I.operands()) {
+      assert(!isUnhandledGCPointerType(V->getType()) &&
+             "support for FCA unimplemented");
+      if (isHandledGCPointerType(V->getType()) && !isa<Constant>(V)) {
+        // The choice to exclude all things constant here is slightly subtle.
+        // There are two independent reasons:
+        // - We assume that things which are constant (from LLVM's definition)
+        // do not move at runtime. For example, the address of a global
+        // variable is fixed, even though its contents may not be.
+        // - Second, we can't disallow arbitrary inttoptr constants even
+        // if the language frontend does. Optimization passes are free to
+        // locally exploit facts without respect to global reachability. This
+        // can create sections of code which are dynamically unreachable and
+        // contain just about anything. (see constants.ll in tests)
+        LiveTmp.insert(V);
+      }
+    }
+  }
+}
+
+static void computeLiveOutSeed(BasicBlock *BB, SetVector<Value *> &LiveTmp) {
+  for (BasicBlock *Succ : successors(BB)) {
+    for (auto &I : *Succ) {
+      PHINode *PN = dyn_cast<PHINode>(&I);
+      if (!PN)
+        break;
+
+      Value *V = PN->getIncomingValueForBlock(BB);
+      assert(!isUnhandledGCPointerType(V->getType()) &&
+             "support for FCA unimplemented");
+      if (isHandledGCPointerType(V->getType()) && !isa<Constant>(V))
+        LiveTmp.insert(V);
+    }
+  }
+}
+
+static SetVector<Value *> computeKillSet(BasicBlock *BB) {
+  SetVector<Value *> KillSet;
+  for (Instruction &I : *BB)
+    if (isHandledGCPointerType(I.getType()))
+      KillSet.insert(&I);
+  return KillSet;
+}
+
+#ifndef NDEBUG
+/// Check that the items in 'Live' dominate 'TI'. This is used as a basic
+/// sanity check for the liveness computation.
+static void checkBasicSSA(DominatorTree &DT, SetVector<Value *> &Live,
+                          TerminatorInst *TI, bool TermOkay = false) {
+  for (Value *V : Live) {
+    if (auto *I = dyn_cast<Instruction>(V)) {
+      // The terminator can be a member of the LiveOut set.
LLVM's definition
+      // of instruction dominance states that V does not dominate itself. As
+      // such, we need to special case this to allow it.
+      if (TermOkay && TI == I)
+        continue;
+      assert(DT.dominates(I, TI) &&
+             "basic SSA liveness expectation violated by liveness analysis");
+    }
+  }
+}
+
+/// Check that all the liveness sets used during the computation of liveness
+/// obey basic SSA properties. This is useful for finding cases where we miss
+/// a def.
+static void checkBasicSSA(DominatorTree &DT, GCPtrLivenessData &Data,
+                          BasicBlock &BB) {
+  checkBasicSSA(DT, Data.LiveSet[&BB], BB.getTerminator());
+  checkBasicSSA(DT, Data.LiveOut[&BB], BB.getTerminator(), true);
+  checkBasicSSA(DT, Data.LiveIn[&BB], BB.getTerminator());
+}
+#endif
+
+static void computeLiveInValues(DominatorTree &DT, Function &F,
+                                GCPtrLivenessData &Data) {
+  SmallSetVector<BasicBlock *, 32> Worklist;
+
+  // Seed the liveness for each individual block
+  for (BasicBlock &BB : F) {
+    Data.KillSet[&BB] = computeKillSet(&BB);
+    Data.LiveSet[&BB].clear();
+    computeLiveInValues(BB.rbegin(), BB.rend(), Data.LiveSet[&BB]);
+
+#ifndef NDEBUG
+    for (Value *Kill : Data.KillSet[&BB])
+      assert(!Data.LiveSet[&BB].count(Kill) && "live set contains kill");
+#endif
+
+    Data.LiveOut[&BB] = SetVector<Value *>();
+    computeLiveOutSeed(&BB, Data.LiveOut[&BB]);
+    Data.LiveIn[&BB] = Data.LiveSet[&BB];
+    Data.LiveIn[&BB].set_union(Data.LiveOut[&BB]);
+    Data.LiveIn[&BB].set_subtract(Data.KillSet[&BB]);
+    if (!Data.LiveIn[&BB].empty())
+      Worklist.insert(pred_begin(&BB), pred_end(&BB));
+  }
+
+  // Propagate that liveness until stable
+  while (!Worklist.empty()) {
+    BasicBlock *BB = Worklist.pop_back_val();
+
+    // Compute our new liveout set, then exit early if it hasn't changed
+    // despite the contribution of our successors.
+    SetVector<Value *> LiveOut = Data.LiveOut[BB];
+    const auto OldLiveOutSize = LiveOut.size();
+    for (BasicBlock *Succ : successors(BB)) {
+      assert(Data.LiveIn.count(Succ));
+      LiveOut.set_union(Data.LiveIn[Succ]);
+    }
+    // assert: OldLiveOut is a subset of LiveOut
+    if (OldLiveOutSize == LiveOut.size()) {
+      // If the sets are the same size, then we didn't actually add anything
+      // when unioning our successors' LiveIn. Thus, the LiveIn of this block
+      // hasn't changed.
+      continue;
+    }
+    Data.LiveOut[BB] = LiveOut;
+
+    // Apply the effects of this basic block
+    SetVector<Value *> LiveTmp = LiveOut;
+    LiveTmp.set_union(Data.LiveSet[BB]);
+    LiveTmp.set_subtract(Data.KillSet[BB]);
+
+    assert(Data.LiveIn.count(BB));
+    const SetVector<Value *> &OldLiveIn = Data.LiveIn[BB];
+    // assert: OldLiveIn is a subset of LiveTmp
+    if (OldLiveIn.size() != LiveTmp.size()) {
+      Data.LiveIn[BB] = LiveTmp;
+      Worklist.insert(pred_begin(BB), pred_end(BB));
+    }
+  } // while (!Worklist.empty())
+
+#ifndef NDEBUG
+  // Sanity check our output against SSA properties. This helps catch any
+  // missing kills during the above iteration.
+  for (BasicBlock &BB : F)
+    checkBasicSSA(DT, Data, BB);
+#endif
+}
+
+static void findLiveSetAtInst(Instruction *Inst, GCPtrLivenessData &Data,
+                              StatepointLiveSetTy &Out) {
+  BasicBlock *BB = Inst->getParent();
+
+  // Note: The copy is intentional and required
+  assert(Data.LiveOut.count(BB));
+  SetVector<Value *> LiveOut = Data.LiveOut[BB];
+
+  // We want to handle the statepoint itself oddly. Its
+  // call result is not live (normal), nor are its arguments
+  // (unless they're used again later).
This adjustment is
+  // specifically what we need to relocate
+  computeLiveInValues(BB->rbegin(), ++Inst->getIterator().getReverse(),
+                      LiveOut);
+  LiveOut.remove(Inst);
+  Out.insert(LiveOut.begin(), LiveOut.end());
+}
+
+static void recomputeLiveInValues(GCPtrLivenessData &RevisedLivenessData,
+                                  CallSite CS,
+                                  PartiallyConstructedSafepointRecord &Info) {
+  Instruction *Inst = CS.getInstruction();
+  StatepointLiveSetTy Updated;
+  findLiveSetAtInst(Inst, RevisedLivenessData, Updated);
+
+  // We may have base pointers which are now live that weren't before. We need
+  // to update the PointerToBase structure to reflect this.
+  for (auto V : Updated)
+    if (Info.PointerToBase.insert({V, V}).second) {
+      assert(isKnownBaseResult(V) &&
+             "Can't find base for unexpected live value!");
+      continue;
+    }
+
+#ifndef NDEBUG
+  for (auto V : Updated)
+    assert(Info.PointerToBase.count(V) &&
+           "Must be able to find base for live value!");
+#endif
+
+  // Remove any stale base mappings; this can happen since our liveness is
+  // more precise than the one inherent in the base pointer analysis.
+  DenseSet<Value *> ToErase;
+  for (auto KVPair : Info.PointerToBase)
+    if (!Updated.count(KVPair.first))
+      ToErase.insert(KVPair.first);
+
+  for (auto *V : ToErase)
+    Info.PointerToBase.erase(V);
+
+#ifndef NDEBUG
+  for (auto KVPair : Info.PointerToBase)
+    assert(Updated.count(KVPair.first) && "record for non-live value");
+#endif
+
+  Info.LiveSet = Updated;
+}
diff --git a/contrib/llvm/lib/Transforms/Scalar/SCCP.cpp b/contrib/llvm/lib/Transforms/Scalar/SCCP.cpp
new file mode 100644
index 000000000000..9dc550ceaeca
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/SCCP.cpp
@@ -0,0 +1,2068 @@
+//===- SCCP.cpp - Sparse Conditional Constant Propagation -----------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements sparse conditional constant propagation and merging:
+//
+// Specifically, this:
+//   * Assumes values are constant unless proven otherwise
+//   * Assumes BasicBlocks are dead unless proven otherwise
+//   * Proves values to be constant, and replaces them with constants
+//   * Proves conditional branches to be unconditional
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/IPO/SCCP.h"
+#include "llvm/Transforms/Scalar/SCCP.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/ConstantFolding.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/ValueLattice.h"
+#include "llvm/Analysis/ValueLatticeUtils.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/CallSite.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GlobalVariable.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PassManager.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/IPO.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <cassert>
+#include <utility>
+#include <vector>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "sccp"
+
+STATISTIC(NumInstRemoved, "Number of instructions removed");
+STATISTIC(NumDeadBlocks, "Number of basic blocks unreachable");
+
+STATISTIC(IPNumInstRemoved, "Number of instructions removed by IPSCCP");
+STATISTIC(IPNumArgsElimed, "Number of arguments constant propagated by IPSCCP");
+STATISTIC(IPNumGlobalConst, "Number of globals found to be constant by IPSCCP");
+STATISTIC(IPNumRangeInfoUsed, "Number of times constant range info was used "
+                              "by IPSCCP");
+
+namespace {
+
+/// LatticeVal class - This class represents the different lattice values that
+/// an LLVM value may occupy. It is a simple class with value semantics.
+///
+class LatticeVal {
+  enum LatticeValueTy {
+    /// unknown - This LLVM Value has no known value yet.
+    unknown,
+
+    /// constant - This LLVM Value has a specific constant value.
+    constant,
+
+    /// forcedconstant - This LLVM Value was thought to be undef until
+    /// ResolvedUndefsIn. This is treated just like 'constant', but if merged
+    /// with another (different) constant, it goes to overdefined, instead of
+    /// asserting.
+    forcedconstant,
+
+    /// overdefined - This instruction is not known to be constant, and we know
+    /// it has a value.
+    overdefined
+  };
+
+  /// Val: This stores the current lattice value along with the Constant* for
+  /// the constant if this is a 'constant' or 'forcedconstant' value.
+  PointerIntPair<Constant *, 2, LatticeValueTy> Val;
+
+  LatticeValueTy getLatticeValue() const {
+    return Val.getInt();
+  }
+
+public:
+  LatticeVal() : Val(nullptr, unknown) {}
+
+  bool isUnknown() const { return getLatticeValue() == unknown; }
+
+  bool isConstant() const {
+    return getLatticeValue() == constant || getLatticeValue() == forcedconstant;
+  }
+
+  bool isOverdefined() const { return getLatticeValue() == overdefined; }
+
+  Constant *getConstant() const {
+    assert(isConstant() && "Cannot get the constant of a non-constant!");
+    return Val.getPointer();
+  }
+
+  /// markOverdefined - Return true if this is a change in status.
+  bool markOverdefined() {
+    if (isOverdefined())
+      return false;
+
+    Val.setInt(overdefined);
+    return true;
+  }
+
+  /// markConstant - Return true if this is a change in status.
+  bool markConstant(Constant *V) {
+    if (getLatticeValue() == constant) { // Constant but not forcedconstant.
+      assert(getConstant() == V && "Marking constant with different value");
+      return false;
+    }
+
+    if (isUnknown()) {
+      Val.setInt(constant);
+      assert(V && "Marking constant with NULL");
+      Val.setPointer(V);
+    } else {
+      assert(getLatticeValue() == forcedconstant &&
+             "Cannot move from overdefined to constant!");
+      // Stay at forcedconstant if the constant is the same.
+      if (V == getConstant()) return false;
+
+      // Otherwise, we go to overdefined. Assumptions made based on the
+      // forced value are possibly wrong. Assuming this is another constant
+      // could expose a contradiction.
+      Val.setInt(overdefined);
+    }
+    return true;
+  }
+
+  /// getConstantInt - If this is a constant with a ConstantInt value, return
+  /// it, otherwise return null.
+  ConstantInt *getConstantInt() const {
+    if (isConstant())
+      return dyn_cast<ConstantInt>(getConstant());
+    return nullptr;
+  }
+
+  /// getBlockAddress - If this is a constant with a BlockAddress value, return
+  /// it, otherwise return null.
+  BlockAddress *getBlockAddress() const {
+    if (isConstant())
+      return dyn_cast<BlockAddress>(getConstant());
+    return nullptr;
+  }
+
+  void markForcedConstant(Constant *V) {
+    assert(isUnknown() && "Can't force a defined value!");
+    Val.setInt(forcedconstant);
+    Val.setPointer(V);
+  }
+
+  ValueLatticeElement toValueLattice() const {
+    if (isOverdefined())
+      return ValueLatticeElement::getOverdefined();
+    if (isConstant())
+      return ValueLatticeElement::get(getConstant());
+    return ValueLatticeElement();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+//
+/// SCCPSolver - This class is a general purpose solver for Sparse Conditional
+/// Constant Propagation.
+///
+class SCCPSolver : public InstVisitor<SCCPSolver> {
+  const DataLayout &DL;
+  const TargetLibraryInfo *TLI;
+  SmallPtrSet<BasicBlock *, 8> BBExecutable; // The BBs that are executable.
+  DenseMap<Value *, LatticeVal> ValueState;  // The state each value is in.
+  // The state each parameter is in.
+  DenseMap<Value *, ValueLatticeElement> ParamState;
+
+  /// StructValueState - This maintains ValueState for values that have
+  /// StructType, for example for formal arguments, calls, insertelement, etc.
+  DenseMap<std::pair<Value *, unsigned>, LatticeVal> StructValueState;
+
+  /// TrackedGlobals - If we are tracking any values for the contents of a
+  /// global variable, we keep a mapping from the constant accessor to the
+  /// element of the global, to the currently known value. If the value
+  /// becomes overdefined, its entry is simply removed from this map.
+  DenseMap<GlobalVariable *, LatticeVal> TrackedGlobals;
+
+  /// TrackedRetVals - If we are tracking arguments into and the return
+  /// value out of a function, it will have an entry in this map, indicating
+  /// what the known return value for the function is.
+  DenseMap<Function *, LatticeVal> TrackedRetVals;
+
+  /// TrackedMultipleRetVals - Same as TrackedRetVals, but used for functions
+  /// that return multiple values.
+  DenseMap<std::pair<Function *, unsigned>, LatticeVal> TrackedMultipleRetVals;
+
+  /// MRVFunctionsTracked - Each function in TrackedMultipleRetVals is
+  /// represented here for efficient lookup.
+  SmallPtrSet<Function *, 16> MRVFunctionsTracked;
+
+  /// TrackingIncomingArguments - This is the set of functions for whose
+  /// arguments we make optimistic assumptions and which we try to prove to be
+  /// constants.
+  SmallPtrSet<Function *, 16> TrackingIncomingArguments;
+
+  /// The reason for two worklists is that overdefined is the lowest state
+  /// on the lattice, and moving things to overdefined as fast as possible
+  /// makes SCCP converge much faster.
+  ///
+  /// By having a separate worklist, we accomplish this because everything
+  /// possibly overdefined will become overdefined at the soonest possible
+  /// point.
+  SmallVector<Value *, 64> OverdefinedInstWorkList;
+  SmallVector<Value *, 64> InstWorkList;
+
+  // The BasicBlock work list
+  SmallVector<BasicBlock *, 64> BBWorkList;
+
+  /// KnownFeasibleEdges - Entries in this set are edges which have already had
+  /// PHI nodes retriggered.
+  using Edge = std::pair<BasicBlock *, BasicBlock *>;
+  DenseSet<Edge> KnownFeasibleEdges;
+
+public:
+  SCCPSolver(const DataLayout &DL, const TargetLibraryInfo *tli)
+      : DL(DL), TLI(tli) {}
+
+  /// MarkBlockExecutable - This method can be used by clients to mark all of
+  /// the blocks that are known to be intrinsically live in the processed unit.
+  ///
+  /// This returns true if the block was not considered live before.
+  bool MarkBlockExecutable(BasicBlock *BB) {
+    if (!BBExecutable.insert(BB).second)
+      return false;
+    DEBUG(dbgs() << "Marking Block Executable: " << BB->getName() << '\n');
+    BBWorkList.push_back(BB); // Add the block to the work list!
+    return true;
+  }
+
+  /// TrackValueOfGlobalVariable - Clients can use this method to
+  /// inform the SCCPSolver that it should track loads and stores to the
+  /// specified global variable if it can. This is only legal to call if
+  /// performing Interprocedural SCCP.
+  void TrackValueOfGlobalVariable(GlobalVariable *GV) {
+    // We only track the contents of scalar globals.
+    if (GV->getValueType()->isSingleValueType()) {
+      LatticeVal &IV = TrackedGlobals[GV];
+      if (!isa<UndefValue>(GV->getInitializer()))
+        IV.markConstant(GV->getInitializer());
+    }
+  }
+
+  /// AddTrackedFunction - If the SCCP solver is supposed to track calls into
+  /// and out of the specified function (which cannot have its address taken),
+  /// this method must be called.
+  void AddTrackedFunction(Function *F) {
+    // Add an entry, F -> undef.
+    if (auto *STy = dyn_cast<StructType>(F->getReturnType())) {
+      MRVFunctionsTracked.insert(F);
+      for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i)
+        TrackedMultipleRetVals.insert(std::make_pair(std::make_pair(F, i),
+                                                     LatticeVal()));
+    } else
+      TrackedRetVals.insert(std::make_pair(F, LatticeVal()));
+  }
+
+  void AddArgumentTrackedFunction(Function *F) {
+    TrackingIncomingArguments.insert(F);
+  }
+
+  /// Returns true if the given function is in the solver's set of
+  /// argument-tracked functions.
+  bool isArgumentTrackedFunction(Function *F) {
+    return TrackingIncomingArguments.count(F);
+  }
+
+  /// Solve - Solve for constants and executable blocks.
+  void Solve();
+
+  /// ResolvedUndefsIn - While solving the dataflow for a function, we assume
+  /// that branches on undef values cannot reach any of their successors.
+  /// However, this is not a safe assumption. After we solve dataflow, this
+  /// method should be used to handle this. If this returns true, the solver
+  /// should be rerun.
+  bool ResolvedUndefsIn(Function &F);
+
+  bool isBlockExecutable(BasicBlock *BB) const {
+    return BBExecutable.count(BB);
+  }
+
+  std::vector<LatticeVal> getStructLatticeValueFor(Value *V) const {
+    std::vector<LatticeVal> StructValues;
+    auto *STy = dyn_cast<StructType>(V->getType());
+    assert(STy && "getStructLatticeValueFor() can be called only on structs");
+    for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
+      auto I = StructValueState.find(std::make_pair(V, i));
+      assert(I != StructValueState.end() && "Value not in valuemap!");
+      StructValues.push_back(I->second);
+    }
+    return StructValues;
+  }
+
+  ValueLatticeElement getLatticeValueFor(Value *V) {
+    assert(!V->getType()->isStructTy() &&
+           "Should use getStructLatticeValueFor");
+    std::pair<DenseMap<Value*, ValueLatticeElement>::iterator, bool>
+        PI = ParamState.insert(std::make_pair(V, ValueLatticeElement()));
+    ValueLatticeElement &LV = PI.first->second;
+    if (PI.second) {
+      DenseMap<Value*, LatticeVal>::const_iterator I = ValueState.find(V);
+      assert(I != ValueState.end() &&
+             "V not found in ValueState nor ParamState map!");
+      LV = I->second.toValueLattice();
+    }
+
+    return LV;
+  }
+
+  /// getTrackedRetVals - Get the inferred return value map.
+  const DenseMap<Function*, LatticeVal> &getTrackedRetVals() {
+    return TrackedRetVals;
+  }
+
+  /// getTrackedGlobals - Get and return the set of inferred initializers for
+  /// global variables.
+  const DenseMap<GlobalVariable*, LatticeVal> &getTrackedGlobals() {
+    return TrackedGlobals;
+  }
+
+  /// getMRVFunctionsTracked - Get the set of functions which return multiple
+  /// values tracked by the pass.
+  const SmallPtrSet<Function *, 16> getMRVFunctionsTracked() {
+    return MRVFunctionsTracked;
+  }
+
+  /// markOverdefined - Mark the specified value overdefined. This
+  /// works with both scalars and structs.
+  void markOverdefined(Value *V) {
+    if (auto *STy = dyn_cast<StructType>(V->getType()))
+      for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i)
+        markOverdefined(getStructValueState(V, i), V);
+    else
+      markOverdefined(ValueState[V], V);
+  }
+
+  // isStructLatticeConstant - Return true if all the lattice values
+  // corresponding to elements of the structure are not overdefined,
+  // false otherwise.
+  bool isStructLatticeConstant(Function *F, StructType *STy) {
+    for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
+      const auto &It = TrackedMultipleRetVals.find(std::make_pair(F, i));
+      assert(It != TrackedMultipleRetVals.end());
+      LatticeVal LV = It->second;
+      if (LV.isOverdefined())
+        return false;
+    }
+    return true;
+  }
+
+private:
+  // pushToWorkList - Helper for markConstant/markForcedConstant/markOverdefined
+  void pushToWorkList(LatticeVal &IV, Value *V) {
+    if (IV.isOverdefined())
+      return OverdefinedInstWorkList.push_back(V);
+    InstWorkList.push_back(V);
+  }
+
+  // markConstant - Make a value be marked as "constant". If the value
+  // is not already a constant, add it to the instruction work list so that
+  // the users of the instruction are updated later.
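+  // Example transitions (illustrative): a value starts at unknown; the first
+  // markConstant moves it to constant. A forcedconstant value merged with a
+  // different constant later moves to overdefined rather than asserting.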
+ void markConstant(LatticeVal &IV, Value *V, Constant *C) { + if (!IV.markConstant(C)) return; + DEBUG(dbgs() << "markConstant: " << *C << ": " << *V << '\n'); + pushToWorkList(IV, V); + } + + void markConstant(Value *V, Constant *C) { + assert(!V->getType()->isStructTy() && "structs should use mergeInValue"); + markConstant(ValueState[V], V, C); + } + + void markForcedConstant(Value *V, Constant *C) { + assert(!V->getType()->isStructTy() && "structs should use mergeInValue"); + LatticeVal &IV = ValueState[V]; + IV.markForcedConstant(C); + DEBUG(dbgs() << "markForcedConstant: " << *C << ": " << *V << '\n'); + pushToWorkList(IV, V); + } + + // markOverdefined - Make a value be marked as "overdefined". If the + // value is not already overdefined, add it to the overdefined instruction + // work list so that the users of the instruction are updated later. + void markOverdefined(LatticeVal &IV, Value *V) { + if (!IV.markOverdefined()) return; + + DEBUG(dbgs() << "markOverdefined: "; + if (auto *F = dyn_cast<Function>(V)) + dbgs() << "Function '" << F->getName() << "'\n"; + else + dbgs() << *V << '\n'); + // Only instructions go on the work list + pushToWorkList(IV, V); + } + + void mergeInValue(LatticeVal &IV, Value *V, LatticeVal MergeWithV) { + if (IV.isOverdefined() || MergeWithV.isUnknown()) + return; // Noop. + if (MergeWithV.isOverdefined()) + return markOverdefined(IV, V); + if (IV.isUnknown()) + return markConstant(IV, V, MergeWithV.getConstant()); + if (IV.getConstant() != MergeWithV.getConstant()) + return markOverdefined(IV, V); + } + + void mergeInValue(Value *V, LatticeVal MergeWithV) { + assert(!V->getType()->isStructTy() && + "non-structs should use markConstant"); + mergeInValue(ValueState[V], V, MergeWithV); + } + + /// getValueState - Return the LatticeVal object that corresponds to the + /// value. This function handles the case when the value hasn't been seen yet + /// by properly seeding constants etc. + LatticeVal &getValueState(Value *V) { + assert(!V->getType()->isStructTy() && "Should use getStructValueState"); + + std::pair<DenseMap<Value*, LatticeVal>::iterator, bool> I = + ValueState.insert(std::make_pair(V, LatticeVal())); + LatticeVal &LV = I.first->second; + + if (!I.second) + return LV; // Common case, already in the map. + + if (auto *C = dyn_cast<Constant>(V)) { + // Undef values remain unknown. + if (!isa<UndefValue>(V)) + LV.markConstant(C); // Constants are constant + } + + // All others are underdefined by default. + return LV; + } + + ValueLatticeElement &getParamState(Value *V) { + assert(!V->getType()->isStructTy() && "Should use getStructValueState"); + + std::pair<DenseMap<Value*, ValueLatticeElement>::iterator, bool> + PI = ParamState.insert(std::make_pair(V, ValueLatticeElement())); + ValueLatticeElement &LV = PI.first->second; + if (PI.second) + LV = getValueState(V).toValueLattice(); + + return LV; + } + + /// getStructValueState - Return the LatticeVal object that corresponds to the + /// value/field pair. This function handles the case when the value hasn't + /// been seen yet by properly seeding constants etc. 
+ LatticeVal &getStructValueState(Value *V, unsigned i) { + assert(V->getType()->isStructTy() && "Should use getValueState"); + assert(i < cast<StructType>(V->getType())->getNumElements() && + "Invalid element #"); + + std::pair<DenseMap<std::pair<Value*, unsigned>, LatticeVal>::iterator, + bool> I = StructValueState.insert( + std::make_pair(std::make_pair(V, i), LatticeVal())); + LatticeVal &LV = I.first->second; + + if (!I.second) + return LV; // Common case, already in the map. + + if (auto *C = dyn_cast<Constant>(V)) { + Constant *Elt = C->getAggregateElement(i); + + if (!Elt) + LV.markOverdefined(); // Unknown sort of constant. + else if (isa<UndefValue>(Elt)) + ; // Undef values remain unknown. + else + LV.markConstant(Elt); // Constants are constant. + } + + // All others are underdefined by default. + return LV; + } + + /// markEdgeExecutable - Mark a basic block as executable, adding it to the BB + /// work list if it is not already executable. + void markEdgeExecutable(BasicBlock *Source, BasicBlock *Dest) { + if (!KnownFeasibleEdges.insert(Edge(Source, Dest)).second) + return; // This edge is already known to be executable! + + if (!MarkBlockExecutable(Dest)) { + // If the destination is already executable, we just made an *edge* + // feasible that wasn't before. Revisit the PHI nodes in the block + // because they have potentially new operands. + DEBUG(dbgs() << "Marking Edge Executable: " << Source->getName() + << " -> " << Dest->getName() << '\n'); + + for (PHINode &PN : Dest->phis()) + visitPHINode(PN); + } + } + + // getFeasibleSuccessors - Return a vector of booleans to indicate which + // successors are reachable from a given terminator instruction. + void getFeasibleSuccessors(TerminatorInst &TI, SmallVectorImpl<bool> &Succs); + + // isEdgeFeasible - Return true if the control flow edge from the 'From' basic + // block to the 'To' basic block is currently feasible. + bool isEdgeFeasible(BasicBlock *From, BasicBlock *To); + + // OperandChangedState - This method is invoked on all of the users of an + // instruction that was just changed state somehow. Based on this + // information, we need to update the specified user of this instruction. + void OperandChangedState(Instruction *I) { + if (BBExecutable.count(I->getParent())) // Inst is executable? + visit(*I); + } + +private: + friend class InstVisitor<SCCPSolver>; + + // visit implementations - Something changed in this instruction. Either an + // operand made a transition, or the instruction is newly executable. Change + // the value type of I to reflect these changes if appropriate. + void visitPHINode(PHINode &I); + + // Terminators + + void visitReturnInst(ReturnInst &I); + void visitTerminatorInst(TerminatorInst &TI); + + void visitCastInst(CastInst &I); + void visitSelectInst(SelectInst &I); + void visitBinaryOperator(Instruction &I); + void visitCmpInst(CmpInst &I); + void visitExtractValueInst(ExtractValueInst &EVI); + void visitInsertValueInst(InsertValueInst &IVI); + + void visitCatchSwitchInst(CatchSwitchInst &CPI) { + markOverdefined(&CPI); + visitTerminatorInst(CPI); + } + + // Instructions that cannot be folded away. 
+
+  void visitStoreInst     (StoreInst &I);
+  void visitLoadInst      (LoadInst &I);
+  void visitGetElementPtrInst(GetElementPtrInst &I);
+
+  void visitCallInst      (CallInst &I) {
+    visitCallSite(&I);
+  }
+
+  void visitInvokeInst    (InvokeInst &II) {
+    visitCallSite(&II);
+    visitTerminatorInst(II);
+  }
+
+  void visitCallSite      (CallSite CS);
+  void visitResumeInst    (TerminatorInst &I) { /*returns void*/ }
+  void visitUnreachableInst(TerminatorInst &I) { /*returns void*/ }
+  void visitFenceInst     (FenceInst &I) { /*returns void*/ }
+
+  void visitInstruction(Instruction &I) {
+    // All the instructions we don't do any special handling for just
+    // go to overdefined.
+    DEBUG(dbgs() << "SCCP: Don't know how to handle: " << I << '\n');
+    markOverdefined(&I);
+  }
+};
+
+} // end anonymous namespace
+
+// getFeasibleSuccessors - Return a vector of booleans to indicate which
+// successors are reachable from a given terminator instruction.
+void SCCPSolver::getFeasibleSuccessors(TerminatorInst &TI,
+                                       SmallVectorImpl<bool> &Succs) {
+  Succs.resize(TI.getNumSuccessors());
+  if (auto *BI = dyn_cast<BranchInst>(&TI)) {
+    if (BI->isUnconditional()) {
+      Succs[0] = true;
+      return;
+    }
+
+    LatticeVal BCValue = getValueState(BI->getCondition());
+    ConstantInt *CI = BCValue.getConstantInt();
+    if (!CI) {
+      // Overdefined condition variables, and branches on unfoldable constant
+      // conditions, mean the branch could go either way.
+      if (!BCValue.isUnknown())
+        Succs[0] = Succs[1] = true;
+      return;
+    }
+
+    // Constant condition variables mean the branch can only go a single way.
+    Succs[CI->isZero()] = true;
+    return;
+  }
+
+  // The successors of an unwinding instruction are always executable.
+  if (TI.isExceptional()) {
+    Succs.assign(TI.getNumSuccessors(), true);
+    return;
+  }
+
+  if (auto *SI = dyn_cast<SwitchInst>(&TI)) {
+    if (!SI->getNumCases()) {
+      Succs[0] = true;
+      return;
+    }
+    LatticeVal SCValue = getValueState(SI->getCondition());
+    ConstantInt *CI = SCValue.getConstantInt();
+
+    if (!CI) {   // Overdefined or unknown condition?
+      // All destinations are executable!
+      if (!SCValue.isUnknown())
+        Succs.assign(TI.getNumSuccessors(), true);
+      return;
+    }
+
+    Succs[SI->findCaseValue(CI)->getSuccessorIndex()] = true;
+    return;
+  }
+
+  // For an indirect branch whose address is a blockaddress, we mark the
+  // target as executable.
+  if (auto *IBR = dyn_cast<IndirectBrInst>(&TI)) {
+    // Casts are folded by visitCastInst.
+    LatticeVal IBRValue = getValueState(IBR->getAddress());
+    BlockAddress *Addr = IBRValue.getBlockAddress();
+    if (!Addr) {   // Overdefined or unknown condition?
+      // All destinations are executable!
+      if (!IBRValue.isUnknown())
+        Succs.assign(TI.getNumSuccessors(), true);
+      return;
+    }
+
+    BasicBlock* T = Addr->getBasicBlock();
+    assert(Addr->getFunction() == T->getParent() &&
+           "Block address of a different function?");
+    for (unsigned i = 0; i < IBR->getNumSuccessors(); ++i) {
+      // This is the target.
+      if (IBR->getDestination(i) == T) {
+        Succs[i] = true;
+        return;
+      }
+    }
+
+    // If we didn't find our destination in the IBR successor list, then we
+    // have undefined behavior. It's OK to assume no successor is executable.
+    return;
+  }
+
+  DEBUG(dbgs() << "Unknown terminator instruction: " << TI << '\n');
+  llvm_unreachable("SCCP: Don't know how to handle this terminator!");
+}
+
+// isEdgeFeasible - Return true if the control flow edge from the 'From' basic
+// block to the 'To' basic block is currently feasible.
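
// Before isEdgeFeasible is defined below, a rough standalone illustration
// of the branch case of getFeasibleSuccessors above, using a toy state enum
// rather than the real lattice (names are hypothetical):

#include <array>

enum class ToyState { Unknown, Constant, Overdefined };

// Successor 0 is the true edge and successor 1 the false edge, matching
// Succs[CI->isZero()] above.
std::array<bool, 2> feasibleSuccs(ToyState CondState, bool CondValue) {
  if (CondState == ToyState::Unknown)
    return {false, false};           // Optimistic: wait for the condition.
  if (CondState == ToyState::Constant)
    return {CondValue, !CondValue};  // Exactly one edge can be taken.
  return {true, true};               // Overdefined: either way is possible.
}
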
+bool SCCPSolver::isEdgeFeasible(BasicBlock *From, BasicBlock *To) { + assert(BBExecutable.count(To) && "Dest should always be alive!"); + + // Make sure the source basic block is executable!! + if (!BBExecutable.count(From)) return false; + + // Check to make sure this edge itself is actually feasible now. + TerminatorInst *TI = From->getTerminator(); + if (auto *BI = dyn_cast<BranchInst>(TI)) { + if (BI->isUnconditional()) + return true; + + LatticeVal BCValue = getValueState(BI->getCondition()); + + // Overdefined condition variables mean the branch could go either way, + // undef conditions mean that neither edge is feasible yet. + ConstantInt *CI = BCValue.getConstantInt(); + if (!CI) + return !BCValue.isUnknown(); + + // Constant condition variables mean the branch can only go a single way. + return BI->getSuccessor(CI->isZero()) == To; + } + + // Unwinding instructions successors are always executable. + if (TI->isExceptional()) + return true; + + if (auto *SI = dyn_cast<SwitchInst>(TI)) { + if (SI->getNumCases() < 1) + return true; + + LatticeVal SCValue = getValueState(SI->getCondition()); + ConstantInt *CI = SCValue.getConstantInt(); + + if (!CI) + return !SCValue.isUnknown(); + + return SI->findCaseValue(CI)->getCaseSuccessor() == To; + } + + // In case of indirect branch and its address is a blockaddress, we mark + // the target as executable. + if (auto *IBR = dyn_cast<IndirectBrInst>(TI)) { + LatticeVal IBRValue = getValueState(IBR->getAddress()); + BlockAddress *Addr = IBRValue.getBlockAddress(); + + if (!Addr) + return !IBRValue.isUnknown(); + + // At this point, the indirectbr is branching on a blockaddress. + return Addr->getBasicBlock() == To; + } + + DEBUG(dbgs() << "Unknown terminator instruction: " << *TI << '\n'); + llvm_unreachable("SCCP: Don't know how to handle this terminator!"); +} + +// visit Implementations - Something changed in this instruction, either an +// operand made a transition, or the instruction is newly executable. Change +// the value type of I to reflect these changes if appropriate. This method +// makes sure to do the following actions: +// +// 1. If a phi node merges two constants in, and has conflicting value coming +// from different branches, or if the PHI node merges in an overdefined +// value, then the PHI node becomes overdefined. +// 2. If a phi node merges only constants in, and they all agree on value, the +// PHI node becomes a constant value equal to that. +// 3. If V <- x (op) y && isConstant(x) && isConstant(y) V = Constant +// 4. If V <- x (op) y && (isOverdefined(x) || isOverdefined(y)) V = Overdefined +// 5. If V <- MEM or V <- CALL or V <- (unknown) then V = Overdefined +// 6. If a conditional branch has a value that is constant, make the selected +// destination executable +// 7. If a conditional branch has a value that is overdefined, make all +// successors executable. +void SCCPSolver::visitPHINode(PHINode &PN) { + // If this PN returns a struct, just mark the result overdefined. + // TODO: We could do a lot better than this if code actually uses this. + if (PN.getType()->isStructTy()) + return markOverdefined(&PN); + + if (getValueState(&PN).isOverdefined()) + return; // Quick exit + + // Super-extra-high-degree PHI nodes are unlikely to ever be marked constant, + // and slow us down a lot. Just mark them overdefined. + if (PN.getNumIncomingValues() > 64) + return markOverdefined(&PN); + + // Look at all of the executable operands of the PHI node. If any of them + // are overdefined, the PHI becomes overdefined as well. 
If they are all
+  // constant, and they agree with each other, the PHI becomes the identical
+  // constant. If they are constant and don't agree, the PHI is overdefined.
+  // If there are no executable operands, the PHI remains unknown.
+  Constant *OperandVal = nullptr;
+  for (unsigned i = 0, e = PN.getNumIncomingValues(); i != e; ++i) {
+    LatticeVal IV = getValueState(PN.getIncomingValue(i));
+    if (IV.isUnknown()) continue;  // Doesn't influence PHI node.
+
+    if (!isEdgeFeasible(PN.getIncomingBlock(i), PN.getParent()))
+      continue;
+
+    if (IV.isOverdefined())    // PHI node becomes overdefined!
+      return markOverdefined(&PN);
+
+    if (!OperandVal) {   // Grab the first value.
+      OperandVal = IV.getConstant();
+      continue;
+    }
+
+    // There is already a reachable operand. If we conflict with it,
+    // then the PHI node becomes overdefined. If we agree with it, we
+    // can continue on.
+
+    // Check to see if there are two different constants merging; if so, the
+    // PHI node is overdefined.
+    if (IV.getConstant() != OperandVal)
+      return markOverdefined(&PN);
+  }
+
+  // If we exited the loop, this means that the PHI node only has constant
+  // arguments that agree with each other (and OperandVal is the constant), or
+  // OperandVal is null because there are no defined incoming arguments. If
+  // this is the case, the PHI remains unknown.
+  if (OperandVal)
+    markConstant(&PN, OperandVal);  // Acquire operand value
+}
+
+void SCCPSolver::visitReturnInst(ReturnInst &I) {
+  if (I.getNumOperands() == 0) return;  // ret void
+
+  Function *F = I.getParent()->getParent();
+  Value *ResultOp = I.getOperand(0);
+
+  // If we are tracking the return value of this function, merge it in.
+  if (!TrackedRetVals.empty() && !ResultOp->getType()->isStructTy()) {
+    DenseMap<Function*, LatticeVal>::iterator TFRVI =
+      TrackedRetVals.find(F);
+    if (TFRVI != TrackedRetVals.end()) {
+      mergeInValue(TFRVI->second, F, getValueState(ResultOp));
+      return;
+    }
+  }
+
+  // Handle functions that return multiple values.
+  if (!TrackedMultipleRetVals.empty()) {
+    if (auto *STy = dyn_cast<StructType>(ResultOp->getType()))
+      if (MRVFunctionsTracked.count(F))
+        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i)
+          mergeInValue(TrackedMultipleRetVals[std::make_pair(F, i)], F,
+                       getStructValueState(ResultOp, i));
+  }
+}
+
+void SCCPSolver::visitTerminatorInst(TerminatorInst &TI) {
+  SmallVector<bool, 16> SuccFeasible;
+  getFeasibleSuccessors(TI, SuccFeasible);
+
+  BasicBlock *BB = TI.getParent();
+
+  // Mark all feasible successors executable.
+  for (unsigned i = 0, e = SuccFeasible.size(); i != e; ++i)
+    if (SuccFeasible[i])
+      markEdgeExecutable(BB, TI.getSuccessor(i));
+}
+
+void SCCPSolver::visitCastInst(CastInst &I) {
+  LatticeVal OpSt = getValueState(I.getOperand(0));
+  if (OpSt.isOverdefined())          // Inherit overdefinedness of operand
+    markOverdefined(&I);
+  else if (OpSt.isConstant()) {
+    // Fold the constant as we build.
+    Constant *C = ConstantFoldCastOperand(I.getOpcode(), OpSt.getConstant(),
+                                          I.getType(), DL);
+    if (isa<UndefValue>(C))
+      return;
+    // Propagate constant value
+    markConstant(&I, C);
+  }
+}
+
+void SCCPSolver::visitExtractValueInst(ExtractValueInst &EVI) {
+  // If this returns a struct, mark all elements overdefined; we don't track
+  // structs in structs.
+  if (EVI.getType()->isStructTy())
+    return markOverdefined(&EVI);
+
+  // If this is extracting from more than one level of struct, we don't know.
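
// Struct values get one lattice cell per (value, element index) pair, as in
// getStructValueState earlier; that per-field bookkeeping is what lets the
// single-index extractvalue handling below stay precise. A sketch with toy
// types (hypothetical names, std::map standing in for DenseMap):

#include <map>
#include <utility>

enum class ToyState { Unknown, Constant, Overdefined };

struct ToyCell {
  ToyState S = ToyState::Unknown;
  int C = 0;
};

static std::map<std::pair<const void *, unsigned>, ToyCell> StructCells;

// Default-constructs an Unknown cell on first lookup, mirroring how the
// solver lazily seeds states for values it has not visited yet.
static ToyCell &structCell(const void *V, unsigned Element) {
  return StructCells[{V, Element}];
}
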
+ if (EVI.getNumIndices() != 1) + return markOverdefined(&EVI); + + Value *AggVal = EVI.getAggregateOperand(); + if (AggVal->getType()->isStructTy()) { + unsigned i = *EVI.idx_begin(); + LatticeVal EltVal = getStructValueState(AggVal, i); + mergeInValue(getValueState(&EVI), &EVI, EltVal); + } else { + // Otherwise, must be extracting from an array. + return markOverdefined(&EVI); + } +} + +void SCCPSolver::visitInsertValueInst(InsertValueInst &IVI) { + auto *STy = dyn_cast<StructType>(IVI.getType()); + if (!STy) + return markOverdefined(&IVI); + + // If this has more than one index, we can't handle it, drive all results to + // undef. + if (IVI.getNumIndices() != 1) + return markOverdefined(&IVI); + + Value *Aggr = IVI.getAggregateOperand(); + unsigned Idx = *IVI.idx_begin(); + + // Compute the result based on what we're inserting. + for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) { + // This passes through all values that aren't the inserted element. + if (i != Idx) { + LatticeVal EltVal = getStructValueState(Aggr, i); + mergeInValue(getStructValueState(&IVI, i), &IVI, EltVal); + continue; + } + + Value *Val = IVI.getInsertedValueOperand(); + if (Val->getType()->isStructTy()) + // We don't track structs in structs. + markOverdefined(getStructValueState(&IVI, i), &IVI); + else { + LatticeVal InVal = getValueState(Val); + mergeInValue(getStructValueState(&IVI, i), &IVI, InVal); + } + } +} + +void SCCPSolver::visitSelectInst(SelectInst &I) { + // If this select returns a struct, just mark the result overdefined. + // TODO: We could do a lot better than this if code actually uses this. + if (I.getType()->isStructTy()) + return markOverdefined(&I); + + LatticeVal CondValue = getValueState(I.getCondition()); + if (CondValue.isUnknown()) + return; + + if (ConstantInt *CondCB = CondValue.getConstantInt()) { + Value *OpVal = CondCB->isZero() ? I.getFalseValue() : I.getTrueValue(); + mergeInValue(&I, getValueState(OpVal)); + return; + } + + // Otherwise, the condition is overdefined or a constant we can't evaluate. + // See if we can produce something better than overdefined based on the T/F + // value. + LatticeVal TVal = getValueState(I.getTrueValue()); + LatticeVal FVal = getValueState(I.getFalseValue()); + + // select ?, C, C -> C. + if (TVal.isConstant() && FVal.isConstant() && + TVal.getConstant() == FVal.getConstant()) + return markConstant(&I, FVal.getConstant()); + + if (TVal.isUnknown()) // select ?, undef, X -> X. + return mergeInValue(&I, FVal); + if (FVal.isUnknown()) // select ?, X, undef -> X. + return mergeInValue(&I, TVal); + markOverdefined(&I); +} + +// Handle Binary Operators. +void SCCPSolver::visitBinaryOperator(Instruction &I) { + LatticeVal V1State = getValueState(I.getOperand(0)); + LatticeVal V2State = getValueState(I.getOperand(1)); + + LatticeVal &IV = ValueState[&I]; + if (IV.isOverdefined()) return; + + if (V1State.isConstant() && V2State.isConstant()) { + Constant *C = ConstantExpr::get(I.getOpcode(), V1State.getConstant(), + V2State.getConstant()); + // X op Y -> undef. + if (isa<UndefValue>(C)) + return; + return markConstant(IV, &I, C); + } + + // If something is undef, wait for it to resolve. + if (!V1State.isOverdefined() && !V2State.isOverdefined()) + return; + + // Otherwise, one of our operands is overdefined. Try to produce something + // better than overdefined with some tricks. + // If this is 0 / Y, it doesn't matter that the second operand is + // overdefined, and we can replace it with zero. 
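
// The "0 / Y" shortcut described above is one of several absorbing-operand
// identities exploited by the checks that follow below; a plain-integer
// spot check of the algebra (the pass itself reasons over lattice values,
// not raw ints):

#include <cassert>

int main() {
  for (int Y = 1; Y < 16; ++Y) {
    assert(0 / Y == 0);      // 0 / Y is 0 no matter what nonzero Y is.
    assert((Y & 0) == 0);    // AND with 0 absorbs the other operand.
    assert(Y * 0 == 0);      // So does MUL with 0...
    assert((Y | -1) == -1);  // ...and OR with all-ones.
  }
}
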
+ if (I.getOpcode() == Instruction::UDiv || I.getOpcode() == Instruction::SDiv) + if (V1State.isConstant() && V1State.getConstant()->isNullValue()) + return markConstant(IV, &I, V1State.getConstant()); + + // If this is: + // -> AND/MUL with 0 + // -> OR with -1 + // it doesn't matter that the other operand is overdefined. + if (I.getOpcode() == Instruction::And || I.getOpcode() == Instruction::Mul || + I.getOpcode() == Instruction::Or) { + LatticeVal *NonOverdefVal = nullptr; + if (!V1State.isOverdefined()) + NonOverdefVal = &V1State; + else if (!V2State.isOverdefined()) + NonOverdefVal = &V2State; + + if (NonOverdefVal) { + if (NonOverdefVal->isUnknown()) + return; + + if (I.getOpcode() == Instruction::And || + I.getOpcode() == Instruction::Mul) { + // X and 0 = 0 + // X * 0 = 0 + if (NonOverdefVal->getConstant()->isNullValue()) + return markConstant(IV, &I, NonOverdefVal->getConstant()); + } else { + // X or -1 = -1 + if (ConstantInt *CI = NonOverdefVal->getConstantInt()) + if (CI->isMinusOne()) + return markConstant(IV, &I, NonOverdefVal->getConstant()); + } + } + } + + markOverdefined(&I); +} + +// Handle ICmpInst instruction. +void SCCPSolver::visitCmpInst(CmpInst &I) { + LatticeVal V1State = getValueState(I.getOperand(0)); + LatticeVal V2State = getValueState(I.getOperand(1)); + + LatticeVal &IV = ValueState[&I]; + if (IV.isOverdefined()) return; + + if (V1State.isConstant() && V2State.isConstant()) { + Constant *C = ConstantExpr::getCompare( + I.getPredicate(), V1State.getConstant(), V2State.getConstant()); + if (isa<UndefValue>(C)) + return; + return markConstant(IV, &I, C); + } + + // If operands are still unknown, wait for it to resolve. + if (!V1State.isOverdefined() && !V2State.isOverdefined()) + return; + + markOverdefined(&I); +} + +// Handle getelementptr instructions. If all operands are constants then we +// can turn this into a getelementptr ConstantExpr. +void SCCPSolver::visitGetElementPtrInst(GetElementPtrInst &I) { + if (ValueState[&I].isOverdefined()) return; + + SmallVector<Constant*, 8> Operands; + Operands.reserve(I.getNumOperands()); + + for (unsigned i = 0, e = I.getNumOperands(); i != e; ++i) { + LatticeVal State = getValueState(I.getOperand(i)); + if (State.isUnknown()) + return; // Operands are not resolved yet. + + if (State.isOverdefined()) + return markOverdefined(&I); + + assert(State.isConstant() && "Unknown state!"); + Operands.push_back(State.getConstant()); + } + + Constant *Ptr = Operands[0]; + auto Indices = makeArrayRef(Operands.begin() + 1, Operands.end()); + Constant *C = + ConstantExpr::getGetElementPtr(I.getSourceElementType(), Ptr, Indices); + if (isa<UndefValue>(C)) + return; + markConstant(&I, C); +} + +void SCCPSolver::visitStoreInst(StoreInst &SI) { + // If this store is of a struct, ignore it. + if (SI.getOperand(0)->getType()->isStructTy()) + return; + + if (TrackedGlobals.empty() || !isa<GlobalVariable>(SI.getOperand(1))) + return; + + GlobalVariable *GV = cast<GlobalVariable>(SI.getOperand(1)); + DenseMap<GlobalVariable*, LatticeVal>::iterator I = TrackedGlobals.find(GV); + if (I == TrackedGlobals.end() || I->second.isOverdefined()) return; + + // Get the value we are storing into the global, then merge it. + mergeInValue(I->second, GV, getValueState(SI.getOperand(0))); + if (I->second.isOverdefined()) + TrackedGlobals.erase(I); // No need to keep tracking this! +} + +// Handle load instructions. If the operand is a constant pointer to a constant +// global, we can replace the load with the loaded constant value! 
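
// visitCmpInst and visitGetElementPtrInst above both follow the solver's
// generic transfer pattern: stay optimistic while any operand is unknown,
// give up on any overdefined operand, and fold only once every operand is
// a known constant. Compactly, with a toy state enum (hypothetical names):

#include <vector>

enum class ToyState { Unknown, Constant, Overdefined };

ToyState combineOperands(const std::vector<ToyState> &Ops) {
  for (ToyState S : Ops) {
    if (S == ToyState::Unknown)
      return ToyState::Unknown;      // Not resolved yet; revisit later.
    if (S == ToyState::Overdefined)
      return ToyState::Overdefined;  // Can never fold this one.
  }
  return ToyState::Constant;         // All constant: would fold here.
}
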
+void SCCPSolver::visitLoadInst(LoadInst &I) { + // If this load is of a struct, just mark the result overdefined. + if (I.getType()->isStructTy()) + return markOverdefined(&I); + + LatticeVal PtrVal = getValueState(I.getOperand(0)); + if (PtrVal.isUnknown()) return; // The pointer is not resolved yet! + + LatticeVal &IV = ValueState[&I]; + if (IV.isOverdefined()) return; + + if (!PtrVal.isConstant() || I.isVolatile()) + return markOverdefined(IV, &I); + + Constant *Ptr = PtrVal.getConstant(); + + // load null is undefined. + if (isa<ConstantPointerNull>(Ptr) && I.getPointerAddressSpace() == 0) + return; + + // Transform load (constant global) into the value loaded. + if (auto *GV = dyn_cast<GlobalVariable>(Ptr)) { + if (!TrackedGlobals.empty()) { + // If we are tracking this global, merge in the known value for it. + DenseMap<GlobalVariable*, LatticeVal>::iterator It = + TrackedGlobals.find(GV); + if (It != TrackedGlobals.end()) { + mergeInValue(IV, &I, It->second); + return; + } + } + } + + // Transform load from a constant into a constant if possible. + if (Constant *C = ConstantFoldLoadFromConstPtr(Ptr, I.getType(), DL)) { + if (isa<UndefValue>(C)) + return; + return markConstant(IV, &I, C); + } + + // Otherwise we cannot say for certain what value this load will produce. + // Bail out. + markOverdefined(IV, &I); +} + +void SCCPSolver::visitCallSite(CallSite CS) { + Function *F = CS.getCalledFunction(); + Instruction *I = CS.getInstruction(); + + // The common case is that we aren't tracking the callee, either because we + // are not doing interprocedural analysis or the callee is indirect, or is + // external. Handle these cases first. + if (!F || F->isDeclaration()) { +CallOverdefined: + // Void return and not tracking callee, just bail. + if (I->getType()->isVoidTy()) return; + + // Otherwise, if we have a single return value case, and if the function is + // a declaration, maybe we can constant fold it. + if (F && F->isDeclaration() && !I->getType()->isStructTy() && + canConstantFoldCallTo(CS, F)) { + SmallVector<Constant*, 8> Operands; + for (CallSite::arg_iterator AI = CS.arg_begin(), E = CS.arg_end(); + AI != E; ++AI) { + LatticeVal State = getValueState(*AI); + + if (State.isUnknown()) + return; // Operands are not resolved yet. + if (State.isOverdefined()) + return markOverdefined(I); + assert(State.isConstant() && "Unknown state!"); + Operands.push_back(State.getConstant()); + } + + if (getValueState(I).isOverdefined()) + return; + + // If we can constant fold this, mark the result of the call as a + // constant. + if (Constant *C = ConstantFoldCall(CS, F, Operands, TLI)) { + // call -> undef. + if (isa<UndefValue>(C)) + return; + return markConstant(I, C); + } + } + + // Otherwise, we don't know anything about this call, mark it overdefined. + return markOverdefined(I); + } + + // If this is a local function that doesn't have its address taken, mark its + // entry block executable and merge in the actual arguments to the call into + // the formal arguments of the function. + if (!TrackingIncomingArguments.empty() && TrackingIncomingArguments.count(F)){ + MarkBlockExecutable(&F->front()); + + // Propagate information from this call site into the callee. + CallSite::arg_iterator CAI = CS.arg_begin(); + for (Function::arg_iterator AI = F->arg_begin(), E = F->arg_end(); + AI != E; ++AI, ++CAI) { + // If this argument is byval, and if the function is not readonly, there + // will be an implicit copy formed of the input aggregate. 
+      if (AI->hasByValAttr() && !F->onlyReadsMemory()) {
+        markOverdefined(&*AI);
+        continue;
+      }
+
+      if (auto *STy = dyn_cast<StructType>(AI->getType())) {
+        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
+          LatticeVal CallArg = getStructValueState(*CAI, i);
+          mergeInValue(getStructValueState(&*AI, i), &*AI, CallArg);
+        }
+      } else {
+        // Most other parts of the Solver still only use the simpler value
+        // lattice, so we propagate changes for parameters to both lattices.
+        getParamState(&*AI).mergeIn(getValueState(*CAI).toValueLattice(), DL);
+        mergeInValue(&*AI, getValueState(*CAI));
+      }
+    }
+  }
+
+  // If this is a single/zero retval case, see if we're tracking the function.
+  if (auto *STy = dyn_cast<StructType>(F->getReturnType())) {
+    if (!MRVFunctionsTracked.count(F))
+      goto CallOverdefined;  // Not tracking this callee.
+
+    // If we are tracking this callee, propagate the result of the function
+    // into this call site.
+    for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i)
+      mergeInValue(getStructValueState(I, i), I,
+                   TrackedMultipleRetVals[std::make_pair(F, i)]);
+  } else {
+    DenseMap<Function*, LatticeVal>::iterator TFRVI = TrackedRetVals.find(F);
+    if (TFRVI == TrackedRetVals.end())
+      goto CallOverdefined;  // Not tracking this callee.
+
+    // If so, propagate the return value of the callee into this call result.
+    mergeInValue(I, TFRVI->second);
+  }
+}
+
+void SCCPSolver::Solve() {
+  // Process the work lists until they are empty!
+  while (!BBWorkList.empty() || !InstWorkList.empty() ||
+         !OverdefinedInstWorkList.empty()) {
+    // Process the overdefined instructions' work list first, which drives
+    // other things to overdefined more quickly.
+    while (!OverdefinedInstWorkList.empty()) {
+      Value *I = OverdefinedInstWorkList.pop_back_val();
+
+      DEBUG(dbgs() << "\nPopped off OI-WL: " << *I << '\n');
+
+      // "I" got into the work list because it either made the transition from
+      // bottom to constant, or to overdefined.
+      //
+      // Anything on this worklist that is overdefined need not be visited,
+      // since all of its users will have already been marked as overdefined.
+      // Update all of the users of this instruction's value.
+      //
+      for (User *U : I->users())
+        if (auto *UI = dyn_cast<Instruction>(U))
+          OperandChangedState(UI);
+    }
+
+    // Process the instruction work list.
+    while (!InstWorkList.empty()) {
+      Value *I = InstWorkList.pop_back_val();
+
+      DEBUG(dbgs() << "\nPopped off I-WL: " << *I << '\n');
+
+      // "I" got into the work list because it made the transition from undef
+      // to constant.
+      //
+      // Anything on this worklist that is overdefined need not be visited,
+      // since all of its users will have already been marked as overdefined.
+      // Update all of the users of this instruction's value.
+      //
+      if (I->getType()->isStructTy() || !getValueState(I).isOverdefined())
+        for (User *U : I->users())
+          if (auto *UI = dyn_cast<Instruction>(U))
+            OperandChangedState(UI);
+    }
+
+    // Process the basic block work list.
+    while (!BBWorkList.empty()) {
+      BasicBlock *BB = BBWorkList.back();
+      BBWorkList.pop_back();
+
+      DEBUG(dbgs() << "\nPopped off BBWL: " << *BB << '\n');
+
+      // Notify all instructions in this basic block that they are newly
+      // executable.
+      visit(BB);
+    }
+  }
+}
+
+/// ResolvedUndefsIn - While solving the dataflow for a function, we assume
+/// that branches on undef values cannot reach any of their successors.
+/// However, this is not a safe assumption. After we solve dataflow, this
+/// method should be used to handle this. If this returns true, the solver
+/// should be rerun.
+///
+/// This method handles this by finding an unresolved branch and marking one
+/// of the edges from the block as feasible, even though the condition
+/// doesn't say it would otherwise be. This allows SCCP to find the rest of the
+/// CFG and only slightly pessimizes the analysis results (by marking one,
+/// potentially infeasible, edge feasible). This cannot usefully modify the
+/// constraints on the condition of the branch, as that would impact other users
+/// of the value.
+///
+/// This scan also checks for values that use undefs, whose results are actually
+/// defined. For example, 'zext i8 undef to i32' should produce all zeros
+/// conservatively, as "(zext i8 X -> i32) & 0xFF00" must always return zero,
+/// even if X isn't defined.
+bool SCCPSolver::ResolvedUndefsIn(Function &F) {
+  for (BasicBlock &BB : F) {
+    if (!BBExecutable.count(&BB))
+      continue;
+
+    for (Instruction &I : BB) {
+      // Look for instructions which produce undef values.
+      if (I.getType()->isVoidTy()) continue;
+
+      if (auto *STy = dyn_cast<StructType>(I.getType())) {
+        // Only a few things that can be structs matter for undef.
+
+        // Tracked calls must never be marked overdefined in ResolvedUndefsIn.
+        if (CallSite CS = CallSite(&I))
+          if (Function *F = CS.getCalledFunction())
+            if (MRVFunctionsTracked.count(F))
+              continue;
+
+        // extractvalue and insertvalue don't need to be marked; they are
+        // tracked as precisely as their operands.
+        if (isa<ExtractValueInst>(I) || isa<InsertValueInst>(I))
+          continue;
+
+        // Send the results of everything else to overdefined. We could be
+        // more precise than this but it isn't worth bothering.
+        for (unsigned i = 0, e = STy->getNumElements(); i != e; ++i) {
+          LatticeVal &LV = getStructValueState(&I, i);
+          if (LV.isUnknown())
+            markOverdefined(LV, &I);
+        }
+        continue;
+      }
+
+      LatticeVal &LV = getValueState(&I);
+      if (!LV.isUnknown()) continue;
+
+      // extractvalue is safe; check here because the argument is a struct.
+      if (isa<ExtractValueInst>(I))
+        continue;
+
+      // Compute the operand LatticeVals, for convenience below.
+      // Anything taking a struct is conservatively assumed to require
+      // overdefined markings.
+      if (I.getOperand(0)->getType()->isStructTy()) {
+        markOverdefined(&I);
+        return true;
+      }
+      LatticeVal Op0LV = getValueState(I.getOperand(0));
+      LatticeVal Op1LV;
+      if (I.getNumOperands() == 2) {
+        if (I.getOperand(1)->getType()->isStructTy()) {
+          markOverdefined(&I);
+          return true;
+        }
+
+        Op1LV = getValueState(I.getOperand(1));
+      }
+      // If this is an instruction whose result is defined even if the input is
+      // not fully defined, propagate the information.
+      Type *ITy = I.getType();
+      switch (I.getOpcode()) {
+      case Instruction::Add:
+      case Instruction::Sub:
+      case Instruction::Trunc:
+      case Instruction::FPTrunc:
+      case Instruction::BitCast:
+        break; // Any undef -> undef
+      case Instruction::FSub:
+      case Instruction::FAdd:
+      case Instruction::FMul:
+      case Instruction::FDiv:
+      case Instruction::FRem:
+        // Floating-point binary operation: be conservative.
+ if (Op0LV.isUnknown() && Op1LV.isUnknown()) + markForcedConstant(&I, Constant::getNullValue(ITy)); + else + markOverdefined(&I); + return true; + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::FPExt: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::SIToFP: + case Instruction::UIToFP: + // undef -> 0; some outputs are impossible + markForcedConstant(&I, Constant::getNullValue(ITy)); + return true; + case Instruction::Mul: + case Instruction::And: + // Both operands undef -> undef + if (Op0LV.isUnknown() && Op1LV.isUnknown()) + break; + // undef * X -> 0. X could be zero. + // undef & X -> 0. X could be zero. + markForcedConstant(&I, Constant::getNullValue(ITy)); + return true; + case Instruction::Or: + // Both operands undef -> undef + if (Op0LV.isUnknown() && Op1LV.isUnknown()) + break; + // undef | X -> -1. X could be -1. + markForcedConstant(&I, Constant::getAllOnesValue(ITy)); + return true; + case Instruction::Xor: + // undef ^ undef -> 0; strictly speaking, this is not strictly + // necessary, but we try to be nice to people who expect this + // behavior in simple cases + if (Op0LV.isUnknown() && Op1LV.isUnknown()) { + markForcedConstant(&I, Constant::getNullValue(ITy)); + return true; + } + // undef ^ X -> undef + break; + case Instruction::SDiv: + case Instruction::UDiv: + case Instruction::SRem: + case Instruction::URem: + // X / undef -> undef. No change. + // X % undef -> undef. No change. + if (Op1LV.isUnknown()) break; + + // X / 0 -> undef. No change. + // X % 0 -> undef. No change. + if (Op1LV.isConstant() && Op1LV.getConstant()->isZeroValue()) + break; + + // undef / X -> 0. X could be maxint. + // undef % X -> 0. X could be 1. + markForcedConstant(&I, Constant::getNullValue(ITy)); + return true; + case Instruction::AShr: + // X >>a undef -> undef. + if (Op1LV.isUnknown()) break; + + // Shifting by the bitwidth or more is undefined. + if (Op1LV.isConstant()) { + if (auto *ShiftAmt = Op1LV.getConstantInt()) + if (ShiftAmt->getLimitedValue() >= + ShiftAmt->getType()->getScalarSizeInBits()) + break; + } + + // undef >>a X -> 0 + markForcedConstant(&I, Constant::getNullValue(ITy)); + return true; + case Instruction::LShr: + case Instruction::Shl: + // X << undef -> undef. + // X >> undef -> undef. + if (Op1LV.isUnknown()) break; + + // Shifting by the bitwidth or more is undefined. + if (Op1LV.isConstant()) { + if (auto *ShiftAmt = Op1LV.getConstantInt()) + if (ShiftAmt->getLimitedValue() >= + ShiftAmt->getType()->getScalarSizeInBits()) + break; + } + + // undef << X -> 0 + // undef >> X -> 0 + markForcedConstant(&I, Constant::getNullValue(ITy)); + return true; + case Instruction::Select: + Op1LV = getValueState(I.getOperand(1)); + // undef ? X : Y -> X or Y. There could be commonality between X/Y. + if (Op0LV.isUnknown()) { + if (!Op1LV.isConstant()) // Pick the constant one if there is any. + Op1LV = getValueState(I.getOperand(2)); + } else if (Op1LV.isUnknown()) { + // c ? undef : undef -> undef. No change. + Op1LV = getValueState(I.getOperand(2)); + if (Op1LV.isUnknown()) + break; + // Otherwise, c ? undef : x -> x. + } else { + // Leave Op1LV as Operand(1)'s LatticeValue. + } + + if (Op1LV.isConstant()) + markForcedConstant(&I, Op1LV.getConstant()); + else + markOverdefined(&I); + return true; + case Instruction::Load: + // A load here means one of two things: a load of undef from a global, + // a load from an unknown pointer. 
Either way, having it return undef + // is okay. + break; + case Instruction::ICmp: + // X == undef -> undef. Other comparisons get more complicated. + if (cast<ICmpInst>(&I)->isEquality()) + break; + markOverdefined(&I); + return true; + case Instruction::Call: + case Instruction::Invoke: + // There are two reasons a call can have an undef result + // 1. It could be tracked. + // 2. It could be constant-foldable. + // Because of the way we solve return values, tracked calls must + // never be marked overdefined in ResolvedUndefsIn. + if (Function *F = CallSite(&I).getCalledFunction()) + if (TrackedRetVals.count(F)) + break; + + // If the call is constant-foldable, we mark it overdefined because + // we do not know what return values are valid. + markOverdefined(&I); + return true; + default: + // If we don't know what should happen here, conservatively mark it + // overdefined. + markOverdefined(&I); + return true; + } + } + + // Check to see if we have a branch or switch on an undefined value. If so + // we force the branch to go one way or the other to make the successor + // values live. It doesn't really matter which way we force it. + TerminatorInst *TI = BB.getTerminator(); + if (auto *BI = dyn_cast<BranchInst>(TI)) { + if (!BI->isConditional()) continue; + if (!getValueState(BI->getCondition()).isUnknown()) + continue; + + // If the input to SCCP is actually branch on undef, fix the undef to + // false. + if (isa<UndefValue>(BI->getCondition())) { + BI->setCondition(ConstantInt::getFalse(BI->getContext())); + markEdgeExecutable(&BB, TI->getSuccessor(1)); + return true; + } + + // Otherwise, it is a branch on a symbolic value which is currently + // considered to be undef. Handle this by forcing the input value to the + // branch to false. + markForcedConstant(BI->getCondition(), + ConstantInt::getFalse(TI->getContext())); + return true; + } + + if (auto *IBR = dyn_cast<IndirectBrInst>(TI)) { + // Indirect branch with no successor ?. Its ok to assume it branches + // to no target. + if (IBR->getNumSuccessors() < 1) + continue; + + if (!getValueState(IBR->getAddress()).isUnknown()) + continue; + + // If the input to SCCP is actually branch on undef, fix the undef to + // the first successor of the indirect branch. + if (isa<UndefValue>(IBR->getAddress())) { + IBR->setAddress(BlockAddress::get(IBR->getSuccessor(0))); + markEdgeExecutable(&BB, IBR->getSuccessor(0)); + return true; + } + + // Otherwise, it is a branch on a symbolic value which is currently + // considered to be undef. Handle this by forcing the input value to the + // branch to the first successor. + markForcedConstant(IBR->getAddress(), + BlockAddress::get(IBR->getSuccessor(0))); + return true; + } + + if (auto *SI = dyn_cast<SwitchInst>(TI)) { + if (!SI->getNumCases() || !getValueState(SI->getCondition()).isUnknown()) + continue; + + // If the input to SCCP is actually switch on undef, fix the undef to + // the first constant. + if (isa<UndefValue>(SI->getCondition())) { + SI->setCondition(SI->case_begin()->getCaseValue()); + markEdgeExecutable(&BB, SI->case_begin()->getCaseSuccessor()); + return true; + } + + markForcedConstant(SI->getCondition(), SI->case_begin()->getCaseValue()); + return true; + } + } + + return false; +} + +static bool tryToReplaceWithConstantRange(SCCPSolver &Solver, Value *V) { + bool Changed = false; + + // Currently we only use range information for integer values. 
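
// The range reasoning used below can be pictured with plain intervals
// (this is a toy sketch, not LLVM's ConstantRange): an icmp folds to a
// constant exactly when the relation holds, or fails, for every pair of
// values drawn from the two ranges.

#include <optional>

struct ToyRange {
  long Lo, Hi; // Inclusive bounds of a toy signed interval.
};

// Fold "A slt B" when it is decided for all pairs; nullopt when the ranges
// overlap and the comparison could go either way.
std::optional<bool> foldSLT(ToyRange A, ToyRange B) {
  if (A.Hi < B.Lo)
    return true;   // Every value of A is below every value of B.
  if (A.Lo >= B.Hi)
    return false;  // No value of A is below any value of B.
  return std::nullopt;
}
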
+ if (!V->getType()->isIntegerTy()) + return false; + + const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); + if (!IV.isConstantRange()) + return false; + + for (auto UI = V->uses().begin(), E = V->uses().end(); UI != E;) { + const Use &U = *UI++; + auto *Icmp = dyn_cast<ICmpInst>(U.getUser()); + if (!Icmp || !Solver.isBlockExecutable(Icmp->getParent())) + continue; + + auto getIcmpLatticeValue = [&](Value *Op) { + if (auto *C = dyn_cast<Constant>(Op)) + return ValueLatticeElement::get(C); + return Solver.getLatticeValueFor(Op); + }; + + ValueLatticeElement A = getIcmpLatticeValue(Icmp->getOperand(0)); + ValueLatticeElement B = getIcmpLatticeValue(Icmp->getOperand(1)); + + Constant *C = nullptr; + if (A.satisfiesPredicate(Icmp->getPredicate(), B)) + C = ConstantInt::getTrue(Icmp->getType()); + else if (A.satisfiesPredicate(Icmp->getInversePredicate(), B)) + C = ConstantInt::getFalse(Icmp->getType()); + + if (C) { + Icmp->replaceAllUsesWith(C); + DEBUG(dbgs() << "Replacing " << *Icmp << " with " << *C + << ", because of range information " << A << " " << B + << "\n"); + Icmp->eraseFromParent(); + Changed = true; + } + } + return Changed; +} + +static bool tryToReplaceWithConstant(SCCPSolver &Solver, Value *V) { + Constant *Const = nullptr; + if (V->getType()->isStructTy()) { + std::vector<LatticeVal> IVs = Solver.getStructLatticeValueFor(V); + if (llvm::any_of(IVs, + [](const LatticeVal &LV) { return LV.isOverdefined(); })) + return false; + std::vector<Constant *> ConstVals; + auto *ST = dyn_cast<StructType>(V->getType()); + for (unsigned i = 0, e = ST->getNumElements(); i != e; ++i) { + LatticeVal V = IVs[i]; + ConstVals.push_back(V.isConstant() + ? V.getConstant() + : UndefValue::get(ST->getElementType(i))); + } + Const = ConstantStruct::get(ST, ConstVals); + } else { + const ValueLatticeElement &IV = Solver.getLatticeValueFor(V); + if (IV.isOverdefined()) + return false; + + if (IV.isConstantRange()) { + if (IV.getConstantRange().isSingleElement()) + Const = + ConstantInt::get(V->getType(), IV.asConstantInteger().getValue()); + else + return false; + } else + Const = + IV.isConstant() ? IV.getConstant() : UndefValue::get(V->getType()); + } + assert(Const && "Constant is nullptr here!"); + DEBUG(dbgs() << " Constant: " << *Const << " = " << *V << '\n'); + + // Replaces all of the uses of a variable with uses of the constant. + V->replaceAllUsesWith(Const); + return true; +} + +// runSCCP() - Run the Sparse Conditional Constant Propagation algorithm, +// and return true if the function was modified. +static bool runSCCP(Function &F, const DataLayout &DL, + const TargetLibraryInfo *TLI) { + DEBUG(dbgs() << "SCCP on function '" << F.getName() << "'\n"); + SCCPSolver Solver(DL, TLI); + + // Mark the first block of the function as being executable. + Solver.MarkBlockExecutable(&F.front()); + + // Mark all arguments to the function as being overdefined. + for (Argument &AI : F.args()) + Solver.markOverdefined(&AI); + + // Solve for constants. + bool ResolvedUndefs = true; + while (ResolvedUndefs) { + Solver.Solve(); + DEBUG(dbgs() << "RESOLVING UNDEFs\n"); + ResolvedUndefs = Solver.ResolvedUndefsIn(F); + } + + bool MadeChanges = false; + + // If we decided that there are basic blocks that are dead in this function, + // delete their contents now. Note that we cannot actually delete the blocks, + // as we cannot modify the CFG of the function. 
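
// The driver in runSCCP above alternates two phases until neither makes
// progress: drain the work lists, then pin down an undef-dependent value
// and re-solve. The same shape in miniature (ToySolver is hypothetical):

struct ToySolver {
  void solve() { /* Propagate until all work lists drain. */ }
  bool resolveOneUndef() { return false; /* True if something was pinned. */ }
};

void solveToFixedPoint(ToySolver &S) {
  bool ResolvedUndefs = true;
  while (ResolvedUndefs) {  // Each resolution can unlock more propagation.
    S.solve();
    ResolvedUndefs = S.resolveOneUndef();
  }
}
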
+ + for (BasicBlock &BB : F) { + if (!Solver.isBlockExecutable(&BB)) { + DEBUG(dbgs() << " BasicBlock Dead:" << BB); + + ++NumDeadBlocks; + NumInstRemoved += removeAllNonTerminatorAndEHPadInstructions(&BB); + + MadeChanges = true; + continue; + } + + // Iterate over all of the instructions in a function, replacing them with + // constants if we have found them to be of constant values. + for (BasicBlock::iterator BI = BB.begin(), E = BB.end(); BI != E;) { + Instruction *Inst = &*BI++; + if (Inst->getType()->isVoidTy() || isa<TerminatorInst>(Inst)) + continue; + + if (tryToReplaceWithConstant(Solver, Inst)) { + if (isInstructionTriviallyDead(Inst)) + Inst->eraseFromParent(); + // Hey, we just changed something! + MadeChanges = true; + ++NumInstRemoved; + } + } + } + + return MadeChanges; +} + +PreservedAnalyses SCCPPass::run(Function &F, FunctionAnalysisManager &AM) { + const DataLayout &DL = F.getParent()->getDataLayout(); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + if (!runSCCP(F, DL, &TLI)) + return PreservedAnalyses::all(); + + auto PA = PreservedAnalyses(); + PA.preserve<GlobalsAA>(); + return PA; +} + +namespace { + +//===--------------------------------------------------------------------===// +// +/// SCCP Class - This class uses the SCCPSolver to implement a per-function +/// Sparse Conditional Constant Propagator. +/// +class SCCPLegacyPass : public FunctionPass { +public: + // Pass identification, replacement for typeid + static char ID; + + SCCPLegacyPass() : FunctionPass(ID) { + initializeSCCPLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } + + // runOnFunction - Run the Sparse Conditional Constant Propagation + // algorithm, and return true if the function was modified. + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + const DataLayout &DL = F.getParent()->getDataLayout(); + const TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + return runSCCP(F, DL, TLI); + } +}; + +} // end anonymous namespace + +char SCCPLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(SCCPLegacyPass, "sccp", + "Sparse Conditional Constant Propagation", false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(SCCPLegacyPass, "sccp", + "Sparse Conditional Constant Propagation", false, false) + +// createSCCPPass - This is the public interface to this file. +FunctionPass *llvm::createSCCPPass() { return new SCCPLegacyPass(); } + +static void findReturnsToZap(Function &F, + SmallVector<ReturnInst *, 8> &ReturnsToZap, + SCCPSolver &Solver) { + // We can only do this if we know that nothing else can call the function. + if (!Solver.isArgumentTrackedFunction(&F)) + return; + + for (BasicBlock &BB : F) + if (auto *RI = dyn_cast<ReturnInst>(BB.getTerminator())) + if (!isa<UndefValue>(RI->getOperand(0))) + ReturnsToZap.push_back(RI); +} + +static bool runIPSCCP(Module &M, const DataLayout &DL, + const TargetLibraryInfo *TLI) { + SCCPSolver Solver(DL, TLI); + + // Loop over all functions, marking arguments to those with their addresses + // taken or that are external as overdefined. + for (Function &F : M) { + if (F.isDeclaration()) + continue; + + // Determine if we can track the function's return values. If so, add the + // function to the solver's set of return-tracked functions. 
+ if (canTrackReturnsInterprocedurally(&F)) + Solver.AddTrackedFunction(&F); + + // Determine if we can track the function's arguments. If so, add the + // function to the solver's set of argument-tracked functions. + if (canTrackArgumentsInterprocedurally(&F)) { + Solver.AddArgumentTrackedFunction(&F); + continue; + } + + // Assume the function is called. + Solver.MarkBlockExecutable(&F.front()); + + // Assume nothing about the incoming arguments. + for (Argument &AI : F.args()) + Solver.markOverdefined(&AI); + } + + // Determine if we can track any of the module's global variables. If so, add + // the global variables we can track to the solver's set of tracked global + // variables. + for (GlobalVariable &G : M.globals()) { + G.removeDeadConstantUsers(); + if (canTrackGlobalVariableInterprocedurally(&G)) + Solver.TrackValueOfGlobalVariable(&G); + } + + // Solve for constants. + bool ResolvedUndefs = true; + while (ResolvedUndefs) { + Solver.Solve(); + + DEBUG(dbgs() << "RESOLVING UNDEFS\n"); + ResolvedUndefs = false; + for (Function &F : M) + ResolvedUndefs |= Solver.ResolvedUndefsIn(F); + } + + bool MadeChanges = false; + + // Iterate over all of the instructions in the module, replacing them with + // constants if we have found them to be of constant values. + SmallVector<BasicBlock*, 512> BlocksToErase; + + for (Function &F : M) { + if (F.isDeclaration()) + continue; + + if (Solver.isBlockExecutable(&F.front())) + for (Function::arg_iterator AI = F.arg_begin(), E = F.arg_end(); AI != E; + ++AI) { + if (!AI->use_empty() && tryToReplaceWithConstant(Solver, &*AI)) { + ++IPNumArgsElimed; + continue; + } + + if (!AI->use_empty() && tryToReplaceWithConstantRange(Solver, &*AI)) + ++IPNumRangeInfoUsed; + } + + for (Function::iterator BB = F.begin(), E = F.end(); BB != E; ++BB) { + if (!Solver.isBlockExecutable(&*BB)) { + DEBUG(dbgs() << " BasicBlock Dead:" << *BB); + + ++NumDeadBlocks; + NumInstRemoved += + changeToUnreachable(BB->getFirstNonPHI(), /*UseLLVMTrap=*/false); + + MadeChanges = true; + + if (&*BB != &F.front()) + BlocksToErase.push_back(&*BB); + continue; + } + + for (BasicBlock::iterator BI = BB->begin(), E = BB->end(); BI != E; ) { + Instruction *Inst = &*BI++; + if (Inst->getType()->isVoidTy()) + continue; + if (tryToReplaceWithConstant(Solver, Inst)) { + if (!isa<CallInst>(Inst) && !isa<TerminatorInst>(Inst)) + Inst->eraseFromParent(); + // Hey, we just changed something! + MadeChanges = true; + ++IPNumInstRemoved; + } + } + } + + // Now that all instructions in the function are constant folded, erase dead + // blocks, because we can now use ConstantFoldTerminator to get rid of + // in-edges. + for (unsigned i = 0, e = BlocksToErase.size(); i != e; ++i) { + // If there are any PHI nodes in this successor, drop entries for BB now. + BasicBlock *DeadBB = BlocksToErase[i]; + for (Value::user_iterator UI = DeadBB->user_begin(), + UE = DeadBB->user_end(); + UI != UE;) { + // Grab the user and then increment the iterator early, as the user + // will be deleted. Step past all adjacent uses from the same user. + auto *I = dyn_cast<Instruction>(*UI); + do { ++UI; } while (UI != UE && *UI == I); + + // Ignore blockaddress users; BasicBlock's dtor will handle them. + if (!I) continue; + + bool Folded = ConstantFoldTerminator(I->getParent()); + if (!Folded) { + // The constant folder may not have been able to fold the terminator + // if this is a branch or switch on undef. Fold it manually as a + // branch to the first successor. 
+#ifndef NDEBUG
+          if (auto *BI = dyn_cast<BranchInst>(I)) {
+            assert(BI->isConditional() && isa<UndefValue>(BI->getCondition()) &&
+                   "Branch should be foldable!");
+          } else if (auto *SI = dyn_cast<SwitchInst>(I)) {
+            assert(isa<UndefValue>(SI->getCondition()) && "Switch should fold");
+          } else {
+            llvm_unreachable("Didn't fold away reference to block!");
+          }
+#endif
+
+          // Make this an uncond branch to the first successor.
+          TerminatorInst *TI = I->getParent()->getTerminator();
+          BranchInst::Create(TI->getSuccessor(0), TI);
+
+          // Remove entries in successor phi nodes to remove edges.
+          for (unsigned i = 1, e = TI->getNumSuccessors(); i != e; ++i)
+            TI->getSuccessor(i)->removePredecessor(TI->getParent());
+
+          // Remove the old terminator.
+          TI->eraseFromParent();
+        }
+      }
+
+      // Finally, delete the basic block.
+      F.getBasicBlockList().erase(DeadBB);
+    }
+    BlocksToErase.clear();
+  }
+
+  // If we inferred constant or undef return values for a function, we replaced
+  // all call uses with the inferred value. This means we don't need to bother
+  // actually returning anything from the function. Replace all return
+  // instructions with return undef.
+  //
+  // Do this in two stages: first identify the functions we should process, then
+  // actually zap their returns. This is important because we can only do this
+  // if the address of the function isn't taken. In cases where a return is the
+  // last use of a function, the order of processing functions would affect
+  // whether other functions are optimizable.
+  SmallVector<ReturnInst*, 8> ReturnsToZap;
+
+  const DenseMap<Function*, LatticeVal> &RV = Solver.getTrackedRetVals();
+  for (const auto &I : RV) {
+    Function *F = I.first;
+    if (I.second.isOverdefined() || F->getReturnType()->isVoidTy())
+      continue;
+    findReturnsToZap(*F, ReturnsToZap, Solver);
+  }
+
+  for (const auto &F : Solver.getMRVFunctionsTracked()) {
+    assert(F->getReturnType()->isStructTy() &&
+           "The return type should be a struct");
+    StructType *STy = cast<StructType>(F->getReturnType());
+    if (Solver.isStructLatticeConstant(F, STy))
+      findReturnsToZap(*F, ReturnsToZap, Solver);
+  }
+
+  // Zap all returns we've identified as needing to change.
+  for (unsigned i = 0, e = ReturnsToZap.size(); i != e; ++i) {
+    Function *F = ReturnsToZap[i]->getParent()->getParent();
+    ReturnsToZap[i]->setOperand(0, UndefValue::get(F->getReturnType()));
+  }
+
+  // If we inferred constant or undef values for global variables, we can
+  // delete the global and any stores that remain to it.
+ const DenseMap<GlobalVariable*, LatticeVal> &TG = Solver.getTrackedGlobals(); + for (DenseMap<GlobalVariable*, LatticeVal>::const_iterator I = TG.begin(), + E = TG.end(); I != E; ++I) { + GlobalVariable *GV = I->first; + assert(!I->second.isOverdefined() && + "Overdefined values should have been taken out of the map!"); + DEBUG(dbgs() << "Found that GV '" << GV->getName() << "' is constant!\n"); + while (!GV->use_empty()) { + StoreInst *SI = cast<StoreInst>(GV->user_back()); + SI->eraseFromParent(); + } + M.getGlobalList().erase(GV); + ++IPNumGlobalConst; + } + + return MadeChanges; +} + +PreservedAnalyses IPSCCPPass::run(Module &M, ModuleAnalysisManager &AM) { + const DataLayout &DL = M.getDataLayout(); + auto &TLI = AM.getResult<TargetLibraryAnalysis>(M); + if (!runIPSCCP(M, DL, &TLI)) + return PreservedAnalyses::all(); + return PreservedAnalyses::none(); +} + +namespace { + +//===--------------------------------------------------------------------===// +// +/// IPSCCP Class - This class implements interprocedural Sparse Conditional +/// Constant Propagation. +/// +class IPSCCPLegacyPass : public ModulePass { +public: + static char ID; + + IPSCCPLegacyPass() : ModulePass(ID) { + initializeIPSCCPLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnModule(Module &M) override { + if (skipModule(M)) + return false; + const DataLayout &DL = M.getDataLayout(); + const TargetLibraryInfo *TLI = + &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + return runIPSCCP(M, DL, TLI); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<TargetLibraryInfoWrapperPass>(); + } +}; + +} // end anonymous namespace + +char IPSCCPLegacyPass::ID = 0; + +INITIALIZE_PASS_BEGIN(IPSCCPLegacyPass, "ipsccp", + "Interprocedural Sparse Conditional Constant Propagation", + false, false) +INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) +INITIALIZE_PASS_END(IPSCCPLegacyPass, "ipsccp", + "Interprocedural Sparse Conditional Constant Propagation", + false, false) + +// createIPSCCPPass - This is the public interface to this file. +ModulePass *llvm::createIPSCCPPass() { return new IPSCCPLegacyPass(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/SROA.cpp b/contrib/llvm/lib/Transforms/Scalar/SROA.cpp new file mode 100644 index 000000000000..bfe3754f0769 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/SROA.cpp @@ -0,0 +1,4447 @@ +//===- SROA.cpp - Scalar Replacement Of Aggregates ------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +/// \file +/// This transformation implements the well known scalar replacement of +/// aggregates transformation. It tries to identify promotable elements of an +/// aggregate alloca, and promote them to registers. It will also try to +/// convert uses of an element (or set of elements) of an alloca into a vector +/// or bitfield-style integer scalar if appropriate. +/// +/// It works to do this with minimal slicing of the alloca so that regions +/// which are merely transferred in and out of external memory remain unchanged +/// and are not decomposed to scalar code. +/// +/// Because this also performs alloca promotion, it can be thought of as also +/// serving the purpose of SSA formation. The algorithm iterates on the +/// function until all opportunities for promotion have been realized. 
+/// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/SROA.h" +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/PointerIntPair.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/ADT/iterator.h" +#include "llvm/ADT/iterator_range.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/PtrUseVisitor.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/ConstantFolder.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DIBuilder.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DebugInfoMetadata.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/GlobalAlias.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstVisitor.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PassManager.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Compiler.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/MathExtras.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include "llvm/Transforms/Utils/PromoteMemToReg.h" +#include <algorithm> +#include <cassert> +#include <chrono> +#include <cstddef> +#include <cstdint> +#include <cstring> +#include <iterator> +#include <string> +#include <tuple> +#include <utility> +#include <vector> + +#ifndef NDEBUG +// We only use this for a debug check. +#include <random> +#endif + +using namespace llvm; +using namespace llvm::sroa; + +#define DEBUG_TYPE "sroa" + +STATISTIC(NumAllocasAnalyzed, "Number of allocas analyzed for replacement"); +STATISTIC(NumAllocaPartitions, "Number of alloca partitions formed"); +STATISTIC(MaxPartitionsPerAlloca, "Maximum number of partitions per alloca"); +STATISTIC(NumAllocaPartitionUses, "Number of alloca partition uses rewritten"); +STATISTIC(MaxUsesPerAllocaPartition, "Maximum number of uses of a partition"); +STATISTIC(NumNewAllocas, "Number of new, smaller allocas introduced"); +STATISTIC(NumPromoted, "Number of allocas promoted to SSA values"); +STATISTIC(NumLoadsSpeculated, "Number of loads speculated to allow promotion"); +STATISTIC(NumDeleted, "Number of instructions deleted"); +STATISTIC(NumVectorized, "Number of vectorized aggregates"); + +/// Hidden option to enable randomly shuffling the slices to help uncover +/// instability in their order. +static cl::opt<bool> SROARandomShuffleSlices("sroa-random-shuffle-slices", + cl::init(false), cl::Hidden); + +/// Hidden option to experiment with completely strict handling of inbounds +/// GEPs. 
+static cl::opt<bool> SROAStrictInbounds("sroa-strict-inbounds", cl::init(false), + cl::Hidden); + +/// Hidden option to allow more aggressive splitting. +static cl::opt<bool> +SROASplitNonWholeAllocaSlices("sroa-split-nonwhole-alloca-slices", + cl::init(false), cl::Hidden); + +namespace { + +/// \brief A custom IRBuilder inserter which prefixes all names, but only in +/// Assert builds. +class IRBuilderPrefixedInserter : public IRBuilderDefaultInserter { + std::string Prefix; + + const Twine getNameWithPrefix(const Twine &Name) const { + return Name.isTriviallyEmpty() ? Name : Prefix + Name; + } + +public: + void SetNamePrefix(const Twine &P) { Prefix = P.str(); } + +protected: + void InsertHelper(Instruction *I, const Twine &Name, BasicBlock *BB, + BasicBlock::iterator InsertPt) const { + IRBuilderDefaultInserter::InsertHelper(I, getNameWithPrefix(Name), BB, + InsertPt); + } +}; + +/// \brief Provide a type for IRBuilder that drops names in release builds. +using IRBuilderTy = IRBuilder<ConstantFolder, IRBuilderPrefixedInserter>; + +/// \brief A used slice of an alloca. +/// +/// This structure represents a slice of an alloca used by some instruction. It +/// stores both the begin and end offsets of this use, a pointer to the use +/// itself, and a flag indicating whether we can classify the use as splittable +/// or not when forming partitions of the alloca. +class Slice { + /// \brief The beginning offset of the range. + uint64_t BeginOffset = 0; + + /// \brief The ending offset, not included in the range. + uint64_t EndOffset = 0; + + /// \brief Storage for both the use of this slice and whether it can be + /// split. + PointerIntPair<Use *, 1, bool> UseAndIsSplittable; + +public: + Slice() = default; + + Slice(uint64_t BeginOffset, uint64_t EndOffset, Use *U, bool IsSplittable) + : BeginOffset(BeginOffset), EndOffset(EndOffset), + UseAndIsSplittable(U, IsSplittable) {} + + uint64_t beginOffset() const { return BeginOffset; } + uint64_t endOffset() const { return EndOffset; } + + bool isSplittable() const { return UseAndIsSplittable.getInt(); } + void makeUnsplittable() { UseAndIsSplittable.setInt(false); } + + Use *getUse() const { return UseAndIsSplittable.getPointer(); } + + bool isDead() const { return getUse() == nullptr; } + void kill() { UseAndIsSplittable.setPointer(nullptr); } + + /// \brief Support for ordering ranges. + /// + /// This provides an ordering over ranges such that start offsets are + /// always increasing, and within equal start offsets, the end offsets are + /// decreasing. Thus the spanning range comes first in a cluster with the + /// same start position. + bool operator<(const Slice &RHS) const { + if (beginOffset() < RHS.beginOffset()) + return true; + if (beginOffset() > RHS.beginOffset()) + return false; + if (isSplittable() != RHS.isSplittable()) + return !isSplittable(); + if (endOffset() > RHS.endOffset()) + return true; + return false; + } + + /// \brief Support comparison with a single offset to allow binary searches. 
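
// The ordering contract above (ascending begin offset; unsplittable before
// splittable at equal begins; within that, the longer spanning slice first)
// can be checked with a standalone toy, not the real Slice type:

#include <algorithm>
#include <cassert>
#include <vector>

struct ToySlice {
  unsigned Begin, End;
  bool Splittable;

  bool operator<(const ToySlice &RHS) const {
    if (Begin != RHS.Begin)
      return Begin < RHS.Begin;
    if (Splittable != RHS.Splittable)
      return !Splittable;      // Unsplittable slices sort first.
    return End > RHS.End;      // The spanning slice leads its cluster.
  }
};

int main() {
  std::vector<ToySlice> S = {{0, 4, true}, {4, 8, false}, {0, 8, true}};
  std::sort(S.begin(), S.end());
  assert(S[0].End == 8 && S[1].End == 4 && S[2].Begin == 4);
}
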
+ friend LLVM_ATTRIBUTE_UNUSED bool operator<(const Slice &LHS,
+ uint64_t RHSOffset) {
+ return LHS.beginOffset() < RHSOffset;
+ }
+ friend LLVM_ATTRIBUTE_UNUSED bool operator<(uint64_t LHSOffset,
+ const Slice &RHS) {
+ return LHSOffset < RHS.beginOffset();
+ }
+
+ bool operator==(const Slice &RHS) const {
+ return isSplittable() == RHS.isSplittable() &&
+ beginOffset() == RHS.beginOffset() && endOffset() == RHS.endOffset();
+ }
+ bool operator!=(const Slice &RHS) const { return !operator==(RHS); }
+};
+
+} // end anonymous namespace
+
+namespace llvm {
+
+template <typename T> struct isPodLike;
+template <> struct isPodLike<Slice> { static const bool value = true; };
+
+} // end namespace llvm
+
+/// \brief Representation of the alloca slices.
+///
+/// This class represents the slices of an alloca which are formed by its
+/// various uses. If a pointer escapes, we can't fully build a representation
+/// for the slices used and we reflect that in this structure. The uses are
+/// stored, sorted by increasing beginning offset and with unsplittable slices
+/// starting at a particular offset before splittable slices.
+class llvm::sroa::AllocaSlices {
+public:
+ /// \brief Construct the slices of a particular alloca.
+ AllocaSlices(const DataLayout &DL, AllocaInst &AI);
+
+ /// \brief Test whether a pointer to the allocation escapes our analysis.
+ ///
+ /// If this is true, the slices are never fully built and should be
+ /// ignored.
+ bool isEscaped() const { return PointerEscapingInstr; }
+
+ /// \brief Support for iterating over the slices.
+ /// @{
+ using iterator = SmallVectorImpl<Slice>::iterator;
+ using range = iterator_range<iterator>;
+
+ iterator begin() { return Slices.begin(); }
+ iterator end() { return Slices.end(); }
+
+ using const_iterator = SmallVectorImpl<Slice>::const_iterator;
+ using const_range = iterator_range<const_iterator>;
+
+ const_iterator begin() const { return Slices.begin(); }
+ const_iterator end() const { return Slices.end(); }
+ /// @}
+
+ /// \brief Erase a range of slices.
+ void erase(iterator Start, iterator Stop) { Slices.erase(Start, Stop); }
+
+ /// \brief Insert new slices for this alloca.
+ ///
+ /// This moves the slices into the alloca's slices collection, and re-sorts
+ /// everything so that the usual ordering properties of the alloca's slices
+ /// hold.
+ void insert(ArrayRef<Slice> NewSlices) {
+ int OldSize = Slices.size();
+ Slices.append(NewSlices.begin(), NewSlices.end());
+ auto SliceI = Slices.begin() + OldSize;
+ std::sort(SliceI, Slices.end());
+ std::inplace_merge(Slices.begin(), SliceI, Slices.end());
+ }
+
+ // Forward declare the iterator and range accessor for walking the
+ // partitions.
+ class partition_iterator;
+ iterator_range<partition_iterator> partitions();
+
+ /// \brief Access the dead users for this alloca.
+ ArrayRef<Instruction *> getDeadUsers() const { return DeadUsers; }
+
+ /// \brief Access the dead operands referring to this alloca.
+ ///
+ /// These are operands which cannot actually be used to refer to the
+ /// alloca as they are outside its range and the user doesn't correct for
+ /// that. These mostly consist of PHI node inputs and the like which we just
+ /// need to replace with undef.
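+ ///
+ /// A sketch of the kind of IR involved (illustrative, not from a test):
+ /// here %a spans 8 bytes, so the %oob input lies entirely outside the
+ /// allocation and its Use in the PHI lands in this list:
+ /// %a = alloca i64
+ /// %p = bitcast i64* %a to i8*
+ /// %oob = getelementptr i8, i8* %p, i64 100
+ /// %phi = phi i8* [ %p, %bb1 ], [ %oob, %bb2 ]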
+ ArrayRef<Use *> getDeadOperands() const { return DeadOperands; }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ void print(raw_ostream &OS, const_iterator I, StringRef Indent = " ") const;
+ void printSlice(raw_ostream &OS, const_iterator I,
+ StringRef Indent = " ") const;
+ void printUse(raw_ostream &OS, const_iterator I,
+ StringRef Indent = " ") const;
+ void print(raw_ostream &OS) const;
+ void dump(const_iterator I) const;
+ void dump() const;
+#endif
+
+private:
+ template <typename DerivedT, typename RetT = void> class BuilderBase;
+ class SliceBuilder;
+
+ friend class AllocaSlices::SliceBuilder;
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ /// \brief Handle to alloca instruction to simplify method interfaces.
+ AllocaInst &AI;
+#endif
+
+ /// \brief The instruction responsible for this alloca not having a known set
+ /// of slices.
+ ///
+ /// When an instruction (potentially) escapes the pointer to the alloca, we
+ /// store a pointer to that here and abort trying to form slices of the
+ /// alloca. This will be null if the alloca slices are analyzed successfully.
+ Instruction *PointerEscapingInstr;
+
+ /// \brief The slices of the alloca.
+ ///
+ /// We store a vector of the slices formed by uses of the alloca here. This
+ /// vector is sorted by increasing begin offset, and then the unsplittable
+ /// slices before the splittable ones. See the Slice inner class for more
+ /// details.
+ SmallVector<Slice, 8> Slices;
+
+ /// \brief Instructions which will become dead if we rewrite the alloca.
+ ///
+ /// Note that these are not separated by slice. This is because we expect an
+ /// alloca to be completely rewritten or not rewritten at all. If rewritten,
+ /// all these instructions can simply be removed and replaced with undef as
+ /// they come from outside of the allocated space.
+ SmallVector<Instruction *, 8> DeadUsers;
+
+ /// \brief Operands which will become dead if we rewrite the alloca.
+ ///
+ /// These are operands that in their particular use can be replaced with
+ /// undef when we rewrite the alloca. These show up in out-of-bounds inputs
+ /// to PHI nodes and the like. They aren't entirely dead (there might be
+ /// a GEP back into the bounds using it elsewhere), nor is the PHI, but we
+ /// want to swap this particular input for undef to simplify the use lists of
+ /// the alloca.
+ SmallVector<Use *, 8> DeadOperands;
+};
+
+/// \brief A partition of the slices.
+///
+/// An ephemeral representation for a range of slices which can be viewed as
+/// a partition of the alloca. This range represents a span of the alloca's
+/// memory which cannot be split, and provides access to all of the slices
+/// overlapping some part of the partition.
+///
+/// Objects of this type are produced by traversing the alloca's slices, but
+/// are only ephemeral and not persistent.
+class llvm::sroa::Partition {
+private:
+ friend class AllocaSlices;
+ friend class AllocaSlices::partition_iterator;
+
+ using iterator = AllocaSlices::iterator;
+
+ /// \brief The beginning and ending offsets of the alloca for this
+ /// partition.
+ uint64_t BeginOffset, EndOffset;
+
+ /// \brief The start and end iterators of this partition.
+ iterator SI, SJ;
+
+ /// \brief A collection of split slice tails overlapping the partition.
+ SmallVector<Slice *, 4> SplitTails;
+
+ /// \brief Raw constructor builds an empty partition starting and ending at
+ /// the given iterator.
+ Partition(iterator SI) : SI(SI), SJ(SI) {}
+
+public:
+ /// \brief The start offset of this partition.
+ ///
+ /// All of the contained slices start at or after this offset.
+ uint64_t beginOffset() const { return BeginOffset; }
+
+ /// \brief The end offset of this partition.
+ ///
+ /// All of the contained slices end at or before this offset.
+ uint64_t endOffset() const { return EndOffset; }
+
+ /// \brief The size of the partition.
+ ///
+ /// Note that this can never be zero.
+ uint64_t size() const {
+ assert(BeginOffset < EndOffset && "Partitions must span some bytes!");
+ return EndOffset - BeginOffset;
+ }
+
+ /// \brief Test whether this partition contains no slices, and merely spans
+ /// a region occupied by split slices.
+ bool empty() const { return SI == SJ; }
+
+ /// \name Iterate slices that start within the partition.
+ /// These may be splittable or unsplittable. They have a begin offset >= the
+ /// partition begin offset.
+ /// @{
+ // FIXME: We should probably define a "concat_iterator" helper and use that
+ // to stitch together pointee_iterators over the split tails and the
+ // contiguous iterators of the partition. That would give a much nicer
+ // interface here. We could then additionally expose filtered iterators for
+ // split, unsplit, and unsplittable slices based on the usage patterns.
+ iterator begin() const { return SI; }
+ iterator end() const { return SJ; }
+ /// @}
+
+ /// \brief Get the sequence of split slice tails.
+ ///
+ /// These tails are of slices which start before this partition but are
+ /// split and overlap into the partition. We accumulate these while forming
+ /// partitions.
+ ArrayRef<Slice *> splitSliceTails() const { return SplitTails; }
+};
+
+/// \brief An iterator over partitions of the alloca's slices.
+///
+/// This iterator implements the core algorithm for partitioning the alloca's
+/// slices. It is a forward iterator as we don't support backtracking for
+/// efficiency reasons, and re-use a single storage area to maintain the
+/// current set of split slices.
+///
+/// It is templated on the slice iterator type to use so that it can operate
+/// with either const or non-const slice iterators.
+class AllocaSlices::partition_iterator
+ : public iterator_facade_base<partition_iterator, std::forward_iterator_tag,
+ Partition> {
+ friend class AllocaSlices;
+
+ /// \brief Most of the state for walking the partitions is held in a class
+ /// with a nice interface for examining them.
+ Partition P;
+
+ /// \brief We need to keep the end of the slices to know when to stop.
+ AllocaSlices::iterator SE;
+
+ /// \brief We also need to keep track of the maximum split end offset seen.
+ /// FIXME: Do we really?
+ uint64_t MaxSplitSliceEndOffset = 0;
+
+ /// \brief Sets the partition to be empty at the given iterator, and sets the
+ /// end iterator.
+ partition_iterator(AllocaSlices::iterator SI, AllocaSlices::iterator SE)
+ : P(SI), SE(SE) {
+ // If not already at the end, advance our state to form the initial
+ // partition.
+ if (SI != SE)
+ advance();
+ }
+
+ /// \brief Advance the iterator to the next partition.
+ ///
+ /// Requires that the iterator not be at the end of the slices.
+ void advance() {
+ assert((P.SI != SE || !P.SplitTails.empty()) &&
+ "Cannot advance past the end of the slices!");
+
+ // Clear out any split uses which have ended.
+ if (!P.SplitTails.empty()) {
+ if (P.EndOffset >= MaxSplitSliceEndOffset) {
+ // If we've finished all splits, this is easy.
+ P.SplitTails.clear();
+ MaxSplitSliceEndOffset = 0;
+ } else {
+ // Remove the uses which have ended in the prior partition. This
+ // cannot change the max split slice end because we just checked that
+ // the prior partition ended prior to that max.
+ P.SplitTails.erase(llvm::remove_if(P.SplitTails,
+ [&](Slice *S) {
+ return S->endOffset() <=
+ P.EndOffset;
+ }),
+ P.SplitTails.end());
+ assert(llvm::any_of(P.SplitTails,
+ [&](Slice *S) {
+ return S->endOffset() == MaxSplitSliceEndOffset;
+ }) &&
+ "Could not find the current max split slice offset!");
+ assert(llvm::all_of(P.SplitTails,
+ [&](Slice *S) {
+ return S->endOffset() <= MaxSplitSliceEndOffset;
+ }) &&
+ "Max split slice end offset is not actually the max!");
+ }
+ }
+
+ // If P.SI is already at the end, then we've cleared the split tail and
+ // now have an end iterator.
+ if (P.SI == SE) {
+ assert(P.SplitTails.empty() && "Failed to clear the split slices!");
+ return;
+ }
+
+ // If we had a non-empty partition previously, set up the state for
+ // subsequent partitions.
+ if (P.SI != P.SJ) {
+ // Accumulate all the splittable slices which started in the old
+ // partition into the split list.
+ for (Slice &S : P)
+ if (S.isSplittable() && S.endOffset() > P.EndOffset) {
+ P.SplitTails.push_back(&S);
+ MaxSplitSliceEndOffset =
+ std::max(S.endOffset(), MaxSplitSliceEndOffset);
+ }
+
+ // Start from the end of the previous partition.
+ P.SI = P.SJ;
+
+ // If P.SI is now at the end, we at most have a tail of split slices.
+ if (P.SI == SE) {
+ P.BeginOffset = P.EndOffset;
+ P.EndOffset = MaxSplitSliceEndOffset;
+ return;
+ }
+
+ // If we have split slices and the next slice is after a gap and is
+ // not splittable, immediately form an empty partition for the split
+ // slices up until the next slice begins.
+ if (!P.SplitTails.empty() && P.SI->beginOffset() != P.EndOffset &&
+ !P.SI->isSplittable()) {
+ P.BeginOffset = P.EndOffset;
+ P.EndOffset = P.SI->beginOffset();
+ return;
+ }
+ }
+
+ // OK, we need to consume new slices. Set the end offset based on the
+ // current slice, and step SJ past it. The beginning offset of the
+ // partition is the beginning offset of the next slice unless we have
+ // pre-existing split slices that are continuing, in which case we begin
+ // at the prior end offset.
+ P.BeginOffset = P.SplitTails.empty() ? P.SI->beginOffset() : P.EndOffset;
+ P.EndOffset = P.SI->endOffset();
+ ++P.SJ;
+
+ // There are two strategies to form a partition based on whether the
+ // partition starts with an unsplittable slice or a splittable slice.
+ if (!P.SI->isSplittable()) {
+ // When we're forming an unsplittable region, it must always start at
+ // the first slice and will extend through its end.
+ assert(P.BeginOffset == P.SI->beginOffset());
+
+ // Form a partition including all of the overlapping slices with this
+ // unsplittable slice.
+ while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset) {
+ if (!P.SJ->isSplittable())
+ P.EndOffset = std::max(P.EndOffset, P.SJ->endOffset());
+ ++P.SJ;
+ }
+
+ // We have a partition across a set of overlapping unsplittable
+ // slices.
+ return;
+ }
+
+ // If we're starting with a splittable slice, then we need to form
+ // a synthetic partition spanning it and any other overlapping splittable
+ // slices.
+ assert(P.SI->isSplittable() && "Forming a splittable partition!");
+
+ // Collect all of the overlapping splittable slices.
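+ // (A worked sketch of the loop below and the back-up step after it:
+ // given splittable slices [0,8) and [4,12) followed by an unsplittable
+ // [10,16), the loop extends EndOffset to 12 and stops at the
+ // unsplittable slice; the back-up step then trims EndOffset to 10,
+ // where the unsplittable slice -- and thus the next partition --
+ // begins.)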
+ while (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset &&
+ P.SJ->isSplittable()) {
+ P.EndOffset = std::max(P.EndOffset, P.SJ->endOffset());
+ ++P.SJ;
+ }
+
+ // Back up P.EndOffset if we ended the span early when encountering an
+ // unsplittable slice. This synthesizes the early end offset of
+ // a partition spanning only splittable slices.
+ if (P.SJ != SE && P.SJ->beginOffset() < P.EndOffset) {
+ assert(!P.SJ->isSplittable());
+ P.EndOffset = P.SJ->beginOffset();
+ }
+ }
+
+public:
+ bool operator==(const partition_iterator &RHS) const {
+ assert(SE == RHS.SE &&
+ "End iterators don't match between compared partition iterators!");
+
+ // The observed positions of partitions are marked by the P.SI iterator and
+ // the emptiness of the split slices. The latter is only relevant when
+ // P.SI == SE, as the end iterator will additionally have an empty split
+ // slices list, but the prior may have the same P.SI and a tail of split
+ // slices.
+ if (P.SI == RHS.P.SI && P.SplitTails.empty() == RHS.P.SplitTails.empty()) {
+ assert(P.SJ == RHS.P.SJ &&
+ "Same set of slices formed two different sized partitions!");
+ assert(P.SplitTails.size() == RHS.P.SplitTails.size() &&
+ "Same slice position with differently sized non-empty split "
+ "slice tails!");
+ return true;
+ }
+ return false;
+ }
+
+ partition_iterator &operator++() {
+ advance();
+ return *this;
+ }
+
+ Partition &operator*() { return P; }
+};
+
+/// \brief A forward range over the partitions of the alloca's slices.
+///
+/// This accesses an iterator range over the partitions of the alloca's
+/// slices. It computes these partitions on the fly based on the overlapping
+/// offsets of the slices and the ability to split them. It will visit "empty"
+/// partitions to cover regions of the alloca only accessed via split
+/// slices.
+iterator_range<AllocaSlices::partition_iterator> AllocaSlices::partitions() {
+ return make_range(partition_iterator(begin(), end()),
+ partition_iterator(end(), end()));
+}
+
+static Value *foldSelectInst(SelectInst &SI) {
+ // If the condition being selected on is a constant or the same value is
+ // being selected between, fold the select. Yes, this does (rarely) happen
+ // early on.
+ if (ConstantInt *CI = dyn_cast<ConstantInt>(SI.getCondition()))
+ return SI.getOperand(1 + CI->isZero());
+ if (SI.getOperand(1) == SI.getOperand(2))
+ return SI.getOperand(1);
+
+ return nullptr;
+}
+
+/// \brief A helper that folds a PHI node or a select.
+static Value *foldPHINodeOrSelectInst(Instruction &I) {
+ if (PHINode *PN = dyn_cast<PHINode>(&I)) {
+ // If PN merges together the same value, return that value.
+ return PN->hasConstantValue();
+ }
+ return foldSelectInst(cast<SelectInst>(I));
+}
+
+/// \brief Builder for the alloca slices.
+///
+/// This class builds a set of alloca slices by recursively visiting the uses
+/// of an alloca and making a slice for each load and store at each offset.
+class AllocaSlices::SliceBuilder : public PtrUseVisitor<SliceBuilder> {
+ friend class PtrUseVisitor<SliceBuilder>;
+ friend class InstVisitor<SliceBuilder>;
+
+ using Base = PtrUseVisitor<SliceBuilder>;
+
+ const uint64_t AllocSize;
+ AllocaSlices &AS;
+
+ SmallDenseMap<Instruction *, unsigned> MemTransferSliceMap;
+ SmallDenseMap<Instruction *, uint64_t> PHIOrSelectSizes;
+
+ /// \brief Set to de-duplicate dead instructions found in the use walk.
+ SmallPtrSet<Instruction *, 4> VisitedDeadInsts; + +public: + SliceBuilder(const DataLayout &DL, AllocaInst &AI, AllocaSlices &AS) + : PtrUseVisitor<SliceBuilder>(DL), + AllocSize(DL.getTypeAllocSize(AI.getAllocatedType())), AS(AS) {} + +private: + void markAsDead(Instruction &I) { + if (VisitedDeadInsts.insert(&I).second) + AS.DeadUsers.push_back(&I); + } + + void insertUse(Instruction &I, const APInt &Offset, uint64_t Size, + bool IsSplittable = false) { + // Completely skip uses which have a zero size or start either before or + // past the end of the allocation. + if (Size == 0 || Offset.uge(AllocSize)) { + DEBUG(dbgs() << "WARNING: Ignoring " << Size << " byte use @" << Offset + << " which has zero size or starts outside of the " + << AllocSize << " byte alloca:\n" + << " alloca: " << AS.AI << "\n" + << " use: " << I << "\n"); + return markAsDead(I); + } + + uint64_t BeginOffset = Offset.getZExtValue(); + uint64_t EndOffset = BeginOffset + Size; + + // Clamp the end offset to the end of the allocation. Note that this is + // formulated to handle even the case where "BeginOffset + Size" overflows. + // This may appear superficially to be something we could ignore entirely, + // but that is not so! There may be widened loads or PHI-node uses where + // some instructions are dead but not others. We can't completely ignore + // them, and so have to record at least the information here. + assert(AllocSize >= BeginOffset); // Established above. + if (Size > AllocSize - BeginOffset) { + DEBUG(dbgs() << "WARNING: Clamping a " << Size << " byte use @" << Offset + << " to remain within the " << AllocSize << " byte alloca:\n" + << " alloca: " << AS.AI << "\n" + << " use: " << I << "\n"); + EndOffset = AllocSize; + } + + AS.Slices.push_back(Slice(BeginOffset, EndOffset, U, IsSplittable)); + } + + void visitBitCastInst(BitCastInst &BC) { + if (BC.use_empty()) + return markAsDead(BC); + + return Base::visitBitCastInst(BC); + } + + void visitGetElementPtrInst(GetElementPtrInst &GEPI) { + if (GEPI.use_empty()) + return markAsDead(GEPI); + + if (SROAStrictInbounds && GEPI.isInBounds()) { + // FIXME: This is a manually un-factored variant of the basic code inside + // of GEPs with checking of the inbounds invariant specified in the + // langref in a very strict sense. If we ever want to enable + // SROAStrictInbounds, this code should be factored cleanly into + // PtrUseVisitor, but it is easier to experiment with SROAStrictInbounds + // by writing out the code here where we have the underlying allocation + // size readily available. + APInt GEPOffset = Offset; + const DataLayout &DL = GEPI.getModule()->getDataLayout(); + for (gep_type_iterator GTI = gep_type_begin(GEPI), + GTE = gep_type_end(GEPI); + GTI != GTE; ++GTI) { + ConstantInt *OpC = dyn_cast<ConstantInt>(GTI.getOperand()); + if (!OpC) + break; + + // Handle a struct index, which adds its field offset to the pointer. + if (StructType *STy = GTI.getStructTypeOrNull()) { + unsigned ElementIdx = OpC->getZExtValue(); + const StructLayout *SL = DL.getStructLayout(STy); + GEPOffset += + APInt(Offset.getBitWidth(), SL->getElementOffset(ElementIdx)); + } else { + // For array or vector indices, scale the index by the size of the + // type. 
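+ // (For instance: indexing element 2 of a [4 x i32] array contributes
+ // 2 * DL.getTypeAllocSize(i32) = 8 bytes to GEPOffset.)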
+ APInt Index = OpC->getValue().sextOrTrunc(Offset.getBitWidth());
+ GEPOffset += Index * APInt(Offset.getBitWidth(),
+ DL.getTypeAllocSize(GTI.getIndexedType()));
+ }
+
+ // If this index has computed an intermediate pointer which is not
+ // inbounds, then the result of the GEP is a poison value and we can
+ // delete it and all uses.
+ if (GEPOffset.ugt(AllocSize))
+ return markAsDead(GEPI);
+ }
+ }
+
+ return Base::visitGetElementPtrInst(GEPI);
+ }
+
+ void handleLoadOrStore(Type *Ty, Instruction &I, const APInt &Offset,
+ uint64_t Size, bool IsVolatile) {
+ // We allow splitting of non-volatile loads and stores where the type is an
+ // integer type. These may be used to implement 'memcpy' or other "transfer
+ // of bits" patterns.
+ bool IsSplittable = Ty->isIntegerTy() && !IsVolatile;
+
+ insertUse(I, Offset, Size, IsSplittable);
+ }
+
+ void visitLoadInst(LoadInst &LI) {
+ assert((!LI.isSimple() || LI.getType()->isSingleValueType()) &&
+ "All simple FCA loads should have been pre-split");
+
+ if (!IsOffsetKnown)
+ return PI.setAborted(&LI);
+
+ const DataLayout &DL = LI.getModule()->getDataLayout();
+ uint64_t Size = DL.getTypeStoreSize(LI.getType());
+ return handleLoadOrStore(LI.getType(), LI, Offset, Size, LI.isVolatile());
+ }
+
+ void visitStoreInst(StoreInst &SI) {
+ Value *ValOp = SI.getValueOperand();
+ if (ValOp == *U)
+ return PI.setEscapedAndAborted(&SI);
+ if (!IsOffsetKnown)
+ return PI.setAborted(&SI);
+
+ const DataLayout &DL = SI.getModule()->getDataLayout();
+ uint64_t Size = DL.getTypeStoreSize(ValOp->getType());
+
+ // If this memory access can be shown to *statically* extend outside the
+ // bounds of the allocation, its behavior is undefined, so simply
+ // ignore it. Note that this is more strict than the generic clamping
+ // behavior of insertUse. We also try to handle cases which might run the
+ // risk of overflow.
+ // FIXME: We should instead consider the pointer to have escaped if this
+ // function is being instrumented for addressing bugs or race conditions.
+ if (Size > AllocSize || Offset.ugt(AllocSize - Size)) {
+ DEBUG(dbgs() << "WARNING: Ignoring " << Size << " byte store @" << Offset
+ << " which extends past the end of the " << AllocSize
+ << " byte alloca:\n"
+ << " alloca: " << AS.AI << "\n"
+ << " use: " << SI << "\n");
+ return markAsDead(SI);
+ }
+
+ assert((!SI.isSimple() || ValOp->getType()->isSingleValueType()) &&
+ "All simple FCA stores should have been pre-split");
+ handleLoadOrStore(ValOp->getType(), SI, Offset, Size, SI.isVolatile());
+ }
+
+ void visitMemSetInst(MemSetInst &II) {
+ assert(II.getRawDest() == *U && "Pointer use is not the destination?");
+ ConstantInt *Length = dyn_cast<ConstantInt>(II.getLength());
+ if ((Length && Length->getValue() == 0) ||
+ (IsOffsetKnown && Offset.uge(AllocSize)))
+ // Zero-length mem transfer intrinsics can be ignored entirely.
+ return markAsDead(II);
+
+ if (!IsOffsetKnown)
+ return PI.setAborted(&II);
+
+ insertUse(II, Offset, Length ? Length->getLimitedValue()
+ : AllocSize - Offset.getLimitedValue(),
+ (bool)Length);
+ }
+
+ void visitMemTransferInst(MemTransferInst &II) {
+ ConstantInt *Length = dyn_cast<ConstantInt>(II.getLength());
+ if (Length && Length->getValue() == 0)
+ // Zero-length mem transfer intrinsics can be ignored entirely.
+ return markAsDead(II);
+
+ // Because we can visit these intrinsics twice, also check to see if the
+ // first time marked this instruction as dead. If so, skip it.
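+ // (The double visit happens when both the source and the destination of
+ // the transfer are derived from this same alloca; the use walk then
+ // reaches the intrinsic once through each pointer operand.)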
+ if (VisitedDeadInsts.count(&II)) + return; + + if (!IsOffsetKnown) + return PI.setAborted(&II); + + // This side of the transfer is completely out-of-bounds, and so we can + // nuke the entire transfer. However, we also need to nuke the other side + // if already added to our partitions. + // FIXME: Yet another place we really should bypass this when + // instrumenting for ASan. + if (Offset.uge(AllocSize)) { + SmallDenseMap<Instruction *, unsigned>::iterator MTPI = + MemTransferSliceMap.find(&II); + if (MTPI != MemTransferSliceMap.end()) + AS.Slices[MTPI->second].kill(); + return markAsDead(II); + } + + uint64_t RawOffset = Offset.getLimitedValue(); + uint64_t Size = Length ? Length->getLimitedValue() : AllocSize - RawOffset; + + // Check for the special case where the same exact value is used for both + // source and dest. + if (*U == II.getRawDest() && *U == II.getRawSource()) { + // For non-volatile transfers this is a no-op. + if (!II.isVolatile()) + return markAsDead(II); + + return insertUse(II, Offset, Size, /*IsSplittable=*/false); + } + + // If we have seen both source and destination for a mem transfer, then + // they both point to the same alloca. + bool Inserted; + SmallDenseMap<Instruction *, unsigned>::iterator MTPI; + std::tie(MTPI, Inserted) = + MemTransferSliceMap.insert(std::make_pair(&II, AS.Slices.size())); + unsigned PrevIdx = MTPI->second; + if (!Inserted) { + Slice &PrevP = AS.Slices[PrevIdx]; + + // Check if the begin offsets match and this is a non-volatile transfer. + // In that case, we can completely elide the transfer. + if (!II.isVolatile() && PrevP.beginOffset() == RawOffset) { + PrevP.kill(); + return markAsDead(II); + } + + // Otherwise we have an offset transfer within the same alloca. We can't + // split those. + PrevP.makeUnsplittable(); + } + + // Insert the use now that we've fixed up the splittable nature. + insertUse(II, Offset, Size, /*IsSplittable=*/Inserted && Length); + + // Check that we ended up with a valid index in the map. + assert(AS.Slices[PrevIdx].getUse()->getUser() == &II && + "Map index doesn't point back to a slice with this user."); + } + + // Disable SRoA for any intrinsics except for lifetime invariants. + // FIXME: What about debug intrinsics? This matches old behavior, but + // doesn't make sense. + void visitIntrinsicInst(IntrinsicInst &II) { + if (!IsOffsetKnown) + return PI.setAborted(&II); + + if (II.getIntrinsicID() == Intrinsic::lifetime_start || + II.getIntrinsicID() == Intrinsic::lifetime_end) { + ConstantInt *Length = cast<ConstantInt>(II.getArgOperand(0)); + uint64_t Size = std::min(AllocSize - Offset.getLimitedValue(), + Length->getLimitedValue()); + insertUse(II, Offset, Size, true); + return; + } + + Base::visitIntrinsicInst(II); + } + + Instruction *hasUnsafePHIOrSelectUse(Instruction *Root, uint64_t &Size) { + // We consider any PHI or select that results in a direct load or store of + // the same offset to be a viable use for slicing purposes. These uses + // are considered unsplittable and the size is the maximum loaded or stored + // size. + SmallPtrSet<Instruction *, 4> Visited; + SmallVector<std::pair<Instruction *, Instruction *>, 4> Uses; + Visited.insert(Root); + Uses.push_back(std::make_pair(cast<Instruction>(*U), Root)); + const DataLayout &DL = Root->getModule()->getDataLayout(); + // If there are no loads or stores, the access is dead. We mark that as + // a size zero access. 
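+ // (For example, a PHI whose transitive users are only bitcasts, other
+ // PHIs, or selects, with no load or store anywhere downstream.)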
+ Size = 0;
+ do {
+ Instruction *I, *UsedI;
+ std::tie(UsedI, I) = Uses.pop_back_val();
+
+ if (LoadInst *LI = dyn_cast<LoadInst>(I)) {
+ Size = std::max(Size, DL.getTypeStoreSize(LI->getType()));
+ continue;
+ }
+ if (StoreInst *SI = dyn_cast<StoreInst>(I)) {
+ Value *Op = SI->getOperand(0);
+ if (Op == UsedI)
+ return SI;
+ Size = std::max(Size, DL.getTypeStoreSize(Op->getType()));
+ continue;
+ }
+
+ if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I)) {
+ if (!GEP->hasAllZeroIndices())
+ return GEP;
+ } else if (!isa<BitCastInst>(I) && !isa<PHINode>(I) &&
+ !isa<SelectInst>(I)) {
+ return I;
+ }
+
+ for (User *U : I->users())
+ if (Visited.insert(cast<Instruction>(U)).second)
+ Uses.push_back(std::make_pair(I, cast<Instruction>(U)));
+ } while (!Uses.empty());
+
+ return nullptr;
+ }
+
+ void visitPHINodeOrSelectInst(Instruction &I) {
+ assert(isa<PHINode>(I) || isa<SelectInst>(I));
+ if (I.use_empty())
+ return markAsDead(I);
+
+ // TODO: We could use SimplifyInstruction here to fold PHINodes and
+ // SelectInsts. However, doing so requires changing the current
+ // dead-operand-tracking mechanism. For instance, suppose neither loading
+ // from %U nor %other traps. Then "load (select undef, %U, %other)" does not
+ // trap either. However, if we simply replace %U with undef using the
+ // current dead-operand-tracking mechanism, "load (select undef, undef,
+ // %other)" may trap because the select may return the first operand
+ // "undef".
+ if (Value *Result = foldPHINodeOrSelectInst(I)) {
+ if (Result == *U)
+ // If the result of the constant fold will be the pointer, recurse
+ // through the PHI/select as if we had RAUW'ed it.
+ enqueueUsers(I);
+ else
+ // Otherwise the operand to the PHI/select is dead, and we can replace
+ // it with undef.
+ AS.DeadOperands.push_back(U);
+
+ return;
+ }
+
+ if (!IsOffsetKnown)
+ return PI.setAborted(&I);
+
+ // See if we have already computed info on this node.
+ uint64_t &Size = PHIOrSelectSizes[&I];
+ if (!Size) {
+ // This is a new PHI/Select, check for an unsafe use of it.
+ if (Instruction *UnsafeI = hasUnsafePHIOrSelectUse(&I, Size))
+ return PI.setAborted(UnsafeI);
+ }
+
+ // For PHI and select operands outside the alloca, we can't nuke the entire
+ // phi or select -- the other side might still be relevant, so we special
+ // case them here and use a separate structure to track the operands
+ // themselves which should be replaced with undef.
+ // FIXME: This should instead be escaped in the event we're instrumenting
+ // for address sanitization.
+ if (Offset.uge(AllocSize)) {
+ AS.DeadOperands.push_back(U);
+ return;
+ }
+
+ insertUse(I, Offset, Size);
+ }
+
+ void visitPHINode(PHINode &PN) { visitPHINodeOrSelectInst(PN); }
+
+ void visitSelectInst(SelectInst &SI) { visitPHINodeOrSelectInst(SI); }
+
+ /// \brief Disable SROA entirely if there are unhandled users of the alloca.
+ void visitInstruction(Instruction &I) { PI.setAborted(&I); }
+};
+
+AllocaSlices::AllocaSlices(const DataLayout &DL, AllocaInst &AI)
+ :
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+ AI(AI),
+#endif
+ PointerEscapingInstr(nullptr) {
+ SliceBuilder PB(DL, AI, *this);
+ SliceBuilder::PtrInfo PtrI = PB.visitPtr(AI);
+ if (PtrI.isEscaped() || PtrI.isAborted()) {
+ // FIXME: We should sink the escape vs. abort info into the caller nicely,
+ // possibly by just storing the PtrInfo in the AllocaSlices.
+ PointerEscapingInstr = PtrI.getEscapingInst() ?
PtrI.getEscapingInst() + : PtrI.getAbortingInst(); + assert(PointerEscapingInstr && "Did not track a bad instruction"); + return; + } + + Slices.erase( + llvm::remove_if(Slices, [](const Slice &S) { return S.isDead(); }), + Slices.end()); + +#ifndef NDEBUG + if (SROARandomShuffleSlices) { + std::mt19937 MT(static_cast<unsigned>( + std::chrono::system_clock::now().time_since_epoch().count())); + std::shuffle(Slices.begin(), Slices.end(), MT); + } +#endif + + // Sort the uses. This arranges for the offsets to be in ascending order, + // and the sizes to be in descending order. + std::sort(Slices.begin(), Slices.end()); +} + +#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + +void AllocaSlices::print(raw_ostream &OS, const_iterator I, + StringRef Indent) const { + printSlice(OS, I, Indent); + OS << "\n"; + printUse(OS, I, Indent); +} + +void AllocaSlices::printSlice(raw_ostream &OS, const_iterator I, + StringRef Indent) const { + OS << Indent << "[" << I->beginOffset() << "," << I->endOffset() << ")" + << " slice #" << (I - begin()) + << (I->isSplittable() ? " (splittable)" : ""); +} + +void AllocaSlices::printUse(raw_ostream &OS, const_iterator I, + StringRef Indent) const { + OS << Indent << " used by: " << *I->getUse()->getUser() << "\n"; +} + +void AllocaSlices::print(raw_ostream &OS) const { + if (PointerEscapingInstr) { + OS << "Can't analyze slices for alloca: " << AI << "\n" + << " A pointer to this alloca escaped by:\n" + << " " << *PointerEscapingInstr << "\n"; + return; + } + + OS << "Slices of alloca: " << AI << "\n"; + for (const_iterator I = begin(), E = end(); I != E; ++I) + print(OS, I); +} + +LLVM_DUMP_METHOD void AllocaSlices::dump(const_iterator I) const { + print(dbgs(), I); +} +LLVM_DUMP_METHOD void AllocaSlices::dump() const { print(dbgs()); } + +#endif // !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP) + +/// Walk the range of a partitioning looking for a common type to cover this +/// sequence of slices. +static Type *findCommonType(AllocaSlices::const_iterator B, + AllocaSlices::const_iterator E, + uint64_t EndOffset) { + Type *Ty = nullptr; + bool TyIsCommon = true; + IntegerType *ITy = nullptr; + + // Note that we need to look at *every* alloca slice's Use to ensure we + // always get consistent results regardless of the order of slices. + for (AllocaSlices::const_iterator I = B; I != E; ++I) { + Use *U = I->getUse(); + if (isa<IntrinsicInst>(*U->getUser())) + continue; + if (I->beginOffset() != B->beginOffset() || I->endOffset() != EndOffset) + continue; + + Type *UserTy = nullptr; + if (LoadInst *LI = dyn_cast<LoadInst>(U->getUser())) { + UserTy = LI->getType(); + } else if (StoreInst *SI = dyn_cast<StoreInst>(U->getUser())) { + UserTy = SI->getValueOperand()->getType(); + } + + if (IntegerType *UserITy = dyn_cast_or_null<IntegerType>(UserTy)) { + // If the type is larger than the partition, skip it. We only encounter + // this for split integer operations where we want to use the type of the + // entity causing the split. Also skip if the type is not a byte width + // multiple. + if (UserITy->getBitWidth() % 8 != 0 || + UserITy->getBitWidth() / 8 > (EndOffset - B->beginOffset())) + continue; + + // Track the largest bitwidth integer type used in this way in case there + // is no common type. + if (!ITy || ITy->getBitWidth() < UserITy->getBitWidth()) + ITy = UserITy; + } + + // To avoid depending on the order of slices, Ty and TyIsCommon must not + // depend on types skipped above. 
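+ // (A sketch: if one user loads the partition as i32 and another as
+ // float, TyIsCommon becomes false here and the integer fallback ITy --
+ // i32 in this example -- is what the function ends up returning.)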
+ if (!UserTy || (Ty && Ty != UserTy))
+ TyIsCommon = false; // Give up on anything but an iN type.
+ else
+ Ty = UserTy;
+ }
+
+ return TyIsCommon ? Ty : ITy;
+}
+
+/// PHI instructions that use an alloca and are subsequently loaded can be
+/// rewritten to load both input pointers in the pred blocks and then PHI the
+/// results, allowing the load of the alloca to be promoted.
+/// From this:
+/// %P2 = phi [i32* %Alloca, i32* %Other]
+/// %V = load i32* %P2
+/// to:
+/// %V1 = load i32* %Alloca -> will be mem2reg'd
+/// ...
+/// %V2 = load i32* %Other
+/// ...
+/// %V = phi [i32 %V1, i32 %V2]
+///
+/// We can do this to a select if its only uses are loads and if the operands
+/// to the select can be loaded unconditionally.
+///
+/// FIXME: This should be hoisted into a generic utility, likely in
+/// Transforms/Util/Local.h
+static bool isSafePHIToSpeculate(PHINode &PN) {
+ // For now, we can only do this promotion if the load is in the same block
+ // as the PHI, and if there are no stores between the phi and load.
+ // TODO: Allow recursive phi users.
+ // TODO: Allow stores.
+ BasicBlock *BB = PN.getParent();
+ unsigned MaxAlign = 0;
+ bool HaveLoad = false;
+ for (User *U : PN.users()) {
+ LoadInst *LI = dyn_cast<LoadInst>(U);
+ if (!LI || !LI->isSimple())
+ return false;
+
+ // For now we only allow loads in the same block as the PHI. This is
+ // a common case that happens when instcombine merges two loads through
+ // a PHI.
+ if (LI->getParent() != BB)
+ return false;
+
+ // Ensure that there are no instructions between the PHI and the load that
+ // could store.
+ for (BasicBlock::iterator BBI(PN); &*BBI != LI; ++BBI)
+ if (BBI->mayWriteToMemory())
+ return false;
+
+ MaxAlign = std::max(MaxAlign, LI->getAlignment());
+ HaveLoad = true;
+ }
+
+ if (!HaveLoad)
+ return false;
+
+ const DataLayout &DL = PN.getModule()->getDataLayout();
+
+ // We can only transform this if it is safe to push the loads into the
+ // predecessor blocks. The only thing to watch out for is that we can't put
+ // a possibly trapping load in the predecessor if it is a critical edge.
+ for (unsigned Idx = 0, Num = PN.getNumIncomingValues(); Idx != Num; ++Idx) {
+ TerminatorInst *TI = PN.getIncomingBlock(Idx)->getTerminator();
+ Value *InVal = PN.getIncomingValue(Idx);
+
+ // If the value is produced by the terminator of the predecessor (an
+ // invoke) or it has side-effects, there is no valid place to put a load
+ // in the predecessor.
+ if (TI == InVal || TI->mayHaveSideEffects())
+ return false;
+
+ // If the predecessor has a single successor, then the edge isn't
+ // critical.
+ if (TI->getNumSuccessors() == 1)
+ continue;
+
+ // If this pointer is always safe to load, or if we can prove that there
+ // is already a load in the block, then we can move the load to the pred
+ // block.
+ if (isSafeToLoadUnconditionally(InVal, MaxAlign, DL, TI))
+ continue;
+
+ return false;
+ }
+
+ return true;
+}
+
+static void speculatePHINodeLoads(PHINode &PN) {
+ DEBUG(dbgs() << " original: " << PN << "\n");
+
+ Type *LoadTy = cast<PointerType>(PN.getType())->getElementType();
+ IRBuilderTy PHIBuilder(&PN);
+ PHINode *NewPN = PHIBuilder.CreatePHI(LoadTy, PN.getNumIncomingValues(),
+ PN.getName() + ".sroa.speculated");
+
+ // Get the AA tags and alignment to use from one of the loads. It doesn't
+ // matter which one we get, or whether any of them differ.
+ LoadInst *SomeLoad = cast<LoadInst>(PN.user_back()); + + AAMDNodes AATags; + SomeLoad->getAAMetadata(AATags); + unsigned Align = SomeLoad->getAlignment(); + + // Rewrite all loads of the PN to use the new PHI. + while (!PN.use_empty()) { + LoadInst *LI = cast<LoadInst>(PN.user_back()); + LI->replaceAllUsesWith(NewPN); + LI->eraseFromParent(); + } + + // Inject loads into all of the pred blocks. + for (unsigned Idx = 0, Num = PN.getNumIncomingValues(); Idx != Num; ++Idx) { + BasicBlock *Pred = PN.getIncomingBlock(Idx); + TerminatorInst *TI = Pred->getTerminator(); + Value *InVal = PN.getIncomingValue(Idx); + IRBuilderTy PredBuilder(TI); + + LoadInst *Load = PredBuilder.CreateLoad( + InVal, (PN.getName() + ".sroa.speculate.load." + Pred->getName())); + ++NumLoadsSpeculated; + Load->setAlignment(Align); + if (AATags) + Load->setAAMetadata(AATags); + NewPN->addIncoming(Load, Pred); + } + + DEBUG(dbgs() << " speculated to: " << *NewPN << "\n"); + PN.eraseFromParent(); +} + +/// Select instructions that use an alloca and are subsequently loaded can be +/// rewritten to load both input pointers and then select between the result, +/// allowing the load of the alloca to be promoted. +/// From this: +/// %P2 = select i1 %cond, i32* %Alloca, i32* %Other +/// %V = load i32* %P2 +/// to: +/// %V1 = load i32* %Alloca -> will be mem2reg'd +/// %V2 = load i32* %Other +/// %V = select i1 %cond, i32 %V1, i32 %V2 +/// +/// We can do this to a select if its only uses are loads and if the operand +/// to the select can be loaded unconditionally. +static bool isSafeSelectToSpeculate(SelectInst &SI) { + Value *TValue = SI.getTrueValue(); + Value *FValue = SI.getFalseValue(); + const DataLayout &DL = SI.getModule()->getDataLayout(); + + for (User *U : SI.users()) { + LoadInst *LI = dyn_cast<LoadInst>(U); + if (!LI || !LI->isSimple()) + return false; + + // Both operands to the select need to be dereferenceable, either + // absolutely (e.g. allocas) or at this point because we can see other + // accesses to it. + if (!isSafeToLoadUnconditionally(TValue, LI->getAlignment(), DL, LI)) + return false; + if (!isSafeToLoadUnconditionally(FValue, LI->getAlignment(), DL, LI)) + return false; + } + + return true; +} + +static void speculateSelectInstLoads(SelectInst &SI) { + DEBUG(dbgs() << " original: " << SI << "\n"); + + IRBuilderTy IRB(&SI); + Value *TV = SI.getTrueValue(); + Value *FV = SI.getFalseValue(); + // Replace the loads of the select with a select of two loads. + while (!SI.use_empty()) { + LoadInst *LI = cast<LoadInst>(SI.user_back()); + assert(LI->isSimple() && "We only speculate simple loads"); + + IRB.SetInsertPoint(LI); + LoadInst *TL = + IRB.CreateLoad(TV, LI->getName() + ".sroa.speculate.load.true"); + LoadInst *FL = + IRB.CreateLoad(FV, LI->getName() + ".sroa.speculate.load.false"); + NumLoadsSpeculated += 2; + + // Transfer alignment and AA info if present. + TL->setAlignment(LI->getAlignment()); + FL->setAlignment(LI->getAlignment()); + + AAMDNodes Tags; + LI->getAAMetadata(Tags); + if (Tags) { + TL->setAAMetadata(Tags); + FL->setAAMetadata(Tags); + } + + Value *V = IRB.CreateSelect(SI.getCondition(), TL, FL, + LI->getName() + ".sroa.speculated"); + + DEBUG(dbgs() << " speculated to: " << *V << "\n"); + LI->replaceAllUsesWith(V); + LI->eraseFromParent(); + } + SI.eraseFromParent(); +} + +/// \brief Build a GEP out of a base pointer and indices. +/// +/// This will return the BasePtr if that is valid, or build a new GEP +/// instruction using the IRBuilder if GEP-ing is needed. 
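+///
+/// For example (a sketch, assuming a base pointer %base of type
+/// { i32, i32, i32 }* and Indices = {i32 0, i32 2}), this emits roughly:
+/// %<prefix>sroa_idx = getelementptr inbounds { i32, i32, i32 },
+/// { i32, i32, i32 }* %base, i32 0, i32 2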
+static Value *buildGEP(IRBuilderTy &IRB, Value *BasePtr, + SmallVectorImpl<Value *> &Indices, Twine NamePrefix) { + if (Indices.empty()) + return BasePtr; + + // A single zero index is a no-op, so check for this and avoid building a GEP + // in that case. + if (Indices.size() == 1 && cast<ConstantInt>(Indices.back())->isZero()) + return BasePtr; + + return IRB.CreateInBoundsGEP(nullptr, BasePtr, Indices, + NamePrefix + "sroa_idx"); +} + +/// \brief Get a natural GEP off of the BasePtr walking through Ty toward +/// TargetTy without changing the offset of the pointer. +/// +/// This routine assumes we've already established a properly offset GEP with +/// Indices, and arrived at the Ty type. The goal is to continue to GEP with +/// zero-indices down through type layers until we find one the same as +/// TargetTy. If we can't find one with the same type, we at least try to use +/// one with the same size. If none of that works, we just produce the GEP as +/// indicated by Indices to have the correct offset. +static Value *getNaturalGEPWithType(IRBuilderTy &IRB, const DataLayout &DL, + Value *BasePtr, Type *Ty, Type *TargetTy, + SmallVectorImpl<Value *> &Indices, + Twine NamePrefix) { + if (Ty == TargetTy) + return buildGEP(IRB, BasePtr, Indices, NamePrefix); + + // Pointer size to use for the indices. + unsigned PtrSize = DL.getPointerTypeSizeInBits(BasePtr->getType()); + + // See if we can descend into a struct and locate a field with the correct + // type. + unsigned NumLayers = 0; + Type *ElementTy = Ty; + do { + if (ElementTy->isPointerTy()) + break; + + if (ArrayType *ArrayTy = dyn_cast<ArrayType>(ElementTy)) { + ElementTy = ArrayTy->getElementType(); + Indices.push_back(IRB.getIntN(PtrSize, 0)); + } else if (VectorType *VectorTy = dyn_cast<VectorType>(ElementTy)) { + ElementTy = VectorTy->getElementType(); + Indices.push_back(IRB.getInt32(0)); + } else if (StructType *STy = dyn_cast<StructType>(ElementTy)) { + if (STy->element_begin() == STy->element_end()) + break; // Nothing left to descend into. + ElementTy = *STy->element_begin(); + Indices.push_back(IRB.getInt32(0)); + } else { + break; + } + ++NumLayers; + } while (ElementTy != TargetTy); + if (ElementTy != TargetTy) + Indices.erase(Indices.end() - NumLayers, Indices.end()); + + return buildGEP(IRB, BasePtr, Indices, NamePrefix); +} + +/// \brief Recursively compute indices for a natural GEP. +/// +/// This is the recursive step for getNaturalGEPWithOffset that walks down the +/// element types adding appropriate indices for the GEP. +static Value *getNaturalGEPRecursively(IRBuilderTy &IRB, const DataLayout &DL, + Value *Ptr, Type *Ty, APInt &Offset, + Type *TargetTy, + SmallVectorImpl<Value *> &Indices, + Twine NamePrefix) { + if (Offset == 0) + return getNaturalGEPWithType(IRB, DL, Ptr, Ty, TargetTy, Indices, + NamePrefix); + + // We can't recurse through pointer types. + if (Ty->isPointerTy()) + return nullptr; + + // We try to analyze GEPs over vectors here, but note that these GEPs are + // extremely poorly defined currently. The long-term goal is to remove GEPing + // over a vector from the IR completely. + if (VectorType *VecTy = dyn_cast<VectorType>(Ty)) { + unsigned ElementSizeInBits = DL.getTypeSizeInBits(VecTy->getScalarType()); + if (ElementSizeInBits % 8 != 0) { + // GEPs over non-multiple of 8 size vector elements are invalid. 
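+ // (For example a <8 x i1> vector: its elements are a single bit wide,
+ // so no whole-byte offset can address an individual element.)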
+ return nullptr; + } + APInt ElementSize(Offset.getBitWidth(), ElementSizeInBits / 8); + APInt NumSkippedElements = Offset.sdiv(ElementSize); + if (NumSkippedElements.ugt(VecTy->getNumElements())) + return nullptr; + Offset -= NumSkippedElements * ElementSize; + Indices.push_back(IRB.getInt(NumSkippedElements)); + return getNaturalGEPRecursively(IRB, DL, Ptr, VecTy->getElementType(), + Offset, TargetTy, Indices, NamePrefix); + } + + if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) { + Type *ElementTy = ArrTy->getElementType(); + APInt ElementSize(Offset.getBitWidth(), DL.getTypeAllocSize(ElementTy)); + APInt NumSkippedElements = Offset.sdiv(ElementSize); + if (NumSkippedElements.ugt(ArrTy->getNumElements())) + return nullptr; + + Offset -= NumSkippedElements * ElementSize; + Indices.push_back(IRB.getInt(NumSkippedElements)); + return getNaturalGEPRecursively(IRB, DL, Ptr, ElementTy, Offset, TargetTy, + Indices, NamePrefix); + } + + StructType *STy = dyn_cast<StructType>(Ty); + if (!STy) + return nullptr; + + const StructLayout *SL = DL.getStructLayout(STy); + uint64_t StructOffset = Offset.getZExtValue(); + if (StructOffset >= SL->getSizeInBytes()) + return nullptr; + unsigned Index = SL->getElementContainingOffset(StructOffset); + Offset -= APInt(Offset.getBitWidth(), SL->getElementOffset(Index)); + Type *ElementTy = STy->getElementType(Index); + if (Offset.uge(DL.getTypeAllocSize(ElementTy))) + return nullptr; // The offset points into alignment padding. + + Indices.push_back(IRB.getInt32(Index)); + return getNaturalGEPRecursively(IRB, DL, Ptr, ElementTy, Offset, TargetTy, + Indices, NamePrefix); +} + +/// \brief Get a natural GEP from a base pointer to a particular offset and +/// resulting in a particular type. +/// +/// The goal is to produce a "natural" looking GEP that works with the existing +/// composite types to arrive at the appropriate offset and element type for +/// a pointer. TargetTy is the element type the returned GEP should point-to if +/// possible. We recurse by decreasing Offset, adding the appropriate index to +/// Indices, and setting Ty to the result subtype. +/// +/// If no natural GEP can be constructed, this function returns null. +static Value *getNaturalGEPWithOffset(IRBuilderTy &IRB, const DataLayout &DL, + Value *Ptr, APInt Offset, Type *TargetTy, + SmallVectorImpl<Value *> &Indices, + Twine NamePrefix) { + PointerType *Ty = cast<PointerType>(Ptr->getType()); + + // Don't consider any GEPs through an i8* as natural unless the TargetTy is + // an i8. + if (Ty == IRB.getInt8PtrTy(Ty->getAddressSpace()) && TargetTy->isIntegerTy(8)) + return nullptr; + + Type *ElementTy = Ty->getElementType(); + if (!ElementTy->isSized()) + return nullptr; // We can't GEP through an unsized element. + APInt ElementSize(Offset.getBitWidth(), DL.getTypeAllocSize(ElementTy)); + if (ElementSize == 0) + return nullptr; // Zero-length arrays can't help us build a natural GEP. + APInt NumSkippedElements = Offset.sdiv(ElementSize); + + Offset -= NumSkippedElements * ElementSize; + Indices.push_back(IRB.getInt(NumSkippedElements)); + return getNaturalGEPRecursively(IRB, DL, Ptr, ElementTy, Offset, TargetTy, + Indices, NamePrefix); +} + +/// \brief Compute an adjusted pointer from Ptr by Offset bytes where the +/// resulting pointer has PointerTy. +/// +/// This tries very hard to compute a "natural" GEP which arrives at the offset +/// and produces the pointer type desired. Where it cannot, it will try to use +/// the natural GEP to arrive at the offset and bitcast to the type. 
Where that
+/// fails, it will try to use an existing i8* and GEP to the byte offset and
+/// bitcast to the type.
+///
+/// The strategy for finding the more natural GEPs is to peel off layers of the
+/// pointer, walking back through bit casts and GEPs, searching for a base
+/// pointer from which we can compute a natural GEP with the desired
+/// properties. The algorithm tries to fold as many constant indices into
+/// a single GEP as possible, thus making each GEP more independent of the
+/// surrounding code.
+static Value *getAdjustedPtr(IRBuilderTy &IRB, const DataLayout &DL, Value *Ptr,
+ APInt Offset, Type *PointerTy, Twine NamePrefix) {
+ // Even though we don't look through PHI nodes, we could be called on an
+ // instruction in an unreachable block, which may be on a cycle.
+ SmallPtrSet<Value *, 4> Visited;
+ Visited.insert(Ptr);
+ SmallVector<Value *, 4> Indices;
+
+ // We may end up computing an offset pointer that has the wrong type. If we
+ // are never able to compute one directly that has the correct type, we'll
+ // fall back to it, so keep it and the base it was computed from around here.
+ Value *OffsetPtr = nullptr;
+ Value *OffsetBasePtr;
+
+ // Remember any i8 pointer we come across to re-use if we need to do a raw
+ // byte offset.
+ Value *Int8Ptr = nullptr;
+ APInt Int8PtrOffset(Offset.getBitWidth(), 0);
+
+ Type *TargetTy = PointerTy->getPointerElementType();
+
+ do {
+ // First fold any existing GEPs into the offset.
+ while (GEPOperator *GEP = dyn_cast<GEPOperator>(Ptr)) {
+ APInt GEPOffset(Offset.getBitWidth(), 0);
+ if (!GEP->accumulateConstantOffset(DL, GEPOffset))
+ break;
+ Offset += GEPOffset;
+ Ptr = GEP->getPointerOperand();
+ if (!Visited.insert(Ptr).second)
+ break;
+ }
+
+ // See if we can perform a natural GEP here.
+ Indices.clear();
+ if (Value *P = getNaturalGEPWithOffset(IRB, DL, Ptr, Offset, TargetTy,
+ Indices, NamePrefix)) {
+ // If we have a new natural pointer at the offset, clear out any old
+ // offset pointer we computed. Unless it is the base pointer or
+ // a non-instruction, we built a GEP we don't need. Zap it.
+ if (OffsetPtr && OffsetPtr != OffsetBasePtr)
+ if (Instruction *I = dyn_cast<Instruction>(OffsetPtr)) {
+ assert(I->use_empty() && "Built a GEP with uses somehow!");
+ I->eraseFromParent();
+ }
+ OffsetPtr = P;
+ OffsetBasePtr = Ptr;
+ // If we also found a pointer of the right type, we're done.
+ if (P->getType() == PointerTy)
+ return P;
+ }
+
+ // Stash this pointer if we've found an i8*.
+ if (Ptr->getType()->isIntegerTy(8)) {
+ Int8Ptr = Ptr;
+ Int8PtrOffset = Offset;
+ }
+
+ // Peel off a layer of the pointer and update the offset appropriately.
+ if (Operator::getOpcode(Ptr) == Instruction::BitCast) {
+ Ptr = cast<Operator>(Ptr)->getOperand(0);
+ } else if (GlobalAlias *GA = dyn_cast<GlobalAlias>(Ptr)) {
+ if (GA->isInterposable())
+ break;
+ Ptr = GA->getAliasee();
+ } else {
+ break;
+ }
+ assert(Ptr->getType()->isPointerTy() && "Unexpected operand type!");
+ } while (Visited.insert(Ptr).second);
+
+ if (!OffsetPtr) {
+ if (!Int8Ptr) {
+ Int8Ptr = IRB.CreateBitCast(
+ Ptr, IRB.getInt8PtrTy(PointerTy->getPointerAddressSpace()),
+ NamePrefix + "sroa_raw_cast");
+ Int8PtrOffset = Offset;
+ }
+
+ OffsetPtr = Int8PtrOffset == 0
+ ? Int8Ptr
+ : IRB.CreateInBoundsGEP(IRB.getInt8Ty(), Int8Ptr,
+ IRB.getInt(Int8PtrOffset),
+ NamePrefix + "sroa_raw_idx");
+ }
+ Ptr = OffsetPtr;
+
+ // On the off chance we were targeting i8*, guard the bitcast here.
+ if (Ptr->getType() != PointerTy)
+ Ptr = IRB.CreateBitCast(Ptr, PointerTy, NamePrefix + "sroa_cast");
+
+ return Ptr;
+}
+
+/// \brief Compute the adjusted alignment for a load or store from an offset.
+static unsigned getAdjustedAlignment(Instruction *I, uint64_t Offset,
+ const DataLayout &DL) {
+ unsigned Alignment;
+ Type *Ty;
+ if (auto *LI = dyn_cast<LoadInst>(I)) {
+ Alignment = LI->getAlignment();
+ Ty = LI->getType();
+ } else if (auto *SI = dyn_cast<StoreInst>(I)) {
+ Alignment = SI->getAlignment();
+ Ty = SI->getValueOperand()->getType();
+ } else {
+ llvm_unreachable("Only loads and stores are allowed!");
+ }
+
+ if (!Alignment)
+ Alignment = DL.getABITypeAlignment(Ty);
+
+ return MinAlign(Alignment, Offset);
+}
+
+/// \brief Test whether we can convert a value from the old to the new type.
+///
+/// This predicate should be used to guard calls to convertValue in order to
+/// ensure that we only try to convert viable values. The strategy is that we
+/// will peel off single element struct and array wrappings to get to an
+/// underlying value, and convert that value.
+static bool canConvertValue(const DataLayout &DL, Type *OldTy, Type *NewTy) {
+ if (OldTy == NewTy)
+ return true;
+
+ // For integer types, we can't handle any bit-width differences. This would
+ // break both vector conversions with extension and introduce endianness
+ // issues when used in conjunction with loads and stores.
+ if (isa<IntegerType>(OldTy) && isa<IntegerType>(NewTy)) {
+ assert(cast<IntegerType>(OldTy)->getBitWidth() !=
+ cast<IntegerType>(NewTy)->getBitWidth() &&
+ "We can't have the same bitwidth for different int types");
+ return false;
+ }
+
+ if (DL.getTypeSizeInBits(NewTy) != DL.getTypeSizeInBits(OldTy))
+ return false;
+ if (!NewTy->isSingleValueType() || !OldTy->isSingleValueType())
+ return false;
+
+ // We can convert pointers to integers and vice-versa. Same for vectors
+ // of pointers and integers.
+ OldTy = OldTy->getScalarType();
+ NewTy = NewTy->getScalarType();
+ if (NewTy->isPointerTy() || OldTy->isPointerTy()) {
+ if (NewTy->isPointerTy() && OldTy->isPointerTy()) {
+ return cast<PointerType>(NewTy)->getPointerAddressSpace() ==
+ cast<PointerType>(OldTy)->getPointerAddressSpace();
+ }
+
+ // We can convert integers to integral pointers, but not to non-integral
+ // pointers.
+ if (OldTy->isIntegerTy())
+ return !DL.isNonIntegralPointerType(NewTy);
+
+ // We can convert integral pointers to integers, but non-integral pointers
+ // need to remain pointers.
+ if (!DL.isNonIntegralPointerType(OldTy))
+ return NewTy->isIntegerTy();
+
+ return false;
+ }
+
+ return true;
+}
+
+/// \brief Generic routine to convert an SSA value to a value of a different
+/// type.
+///
+/// This will try various different casting techniques, such as bitcasts,
+/// inttoptr, and ptrtoint casts. Use the \c canConvertValue predicate to test
+/// two types for viability with this routine.
+static Value *convertValue(const DataLayout &DL, IRBuilderTy &IRB, Value *V,
+ Type *NewTy) {
+ Type *OldTy = V->getType();
+ assert(canConvertValue(DL, OldTy, NewTy) && "Value not convertible to type");
+
+ if (OldTy == NewTy)
+ return V;
+
+ assert(!(isa<IntegerType>(OldTy) && isa<IntegerType>(NewTy)) &&
+ "Integer types must be the exact same to convert.");
+
+ // See if we need inttoptr for this type pair. A cast involving both scalars
+ // and vectors requires an additional bitcast.
+ if (OldTy->isIntOrIntVectorTy() && NewTy->isPtrOrPtrVectorTy()) {
+ // Expand <2 x i32> to i8* --> <2 x i32> to i64 to i8*
+ if (OldTy->isVectorTy() && !NewTy->isVectorTy())
+ return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)),
+ NewTy);
+
+ // Expand i128 to <2 x i8*> --> i128 to <2 x i64> to <2 x i8*>
+ if (!OldTy->isVectorTy() && NewTy->isVectorTy())
+ return IRB.CreateIntToPtr(IRB.CreateBitCast(V, DL.getIntPtrType(NewTy)),
+ NewTy);
+
+ return IRB.CreateIntToPtr(V, NewTy);
+ }
+
+ // See if we need ptrtoint for this type pair. A cast involving both scalars
+ // and vectors requires an additional bitcast.
+ if (OldTy->isPtrOrPtrVectorTy() && NewTy->isIntOrIntVectorTy()) {
+ // Expand <2 x i8*> to i128 --> <2 x i8*> to <2 x i64> to i128
+ if (OldTy->isVectorTy() && !NewTy->isVectorTy())
+ return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
+ NewTy);
+
+ // Expand i8* to <2 x i32> --> i8* to i64 to <2 x i32>
+ if (!OldTy->isVectorTy() && NewTy->isVectorTy())
+ return IRB.CreateBitCast(IRB.CreatePtrToInt(V, DL.getIntPtrType(OldTy)),
+ NewTy);
+
+ return IRB.CreatePtrToInt(V, NewTy);
+ }
+
+ return IRB.CreateBitCast(V, NewTy);
+}
+
+/// \brief Test whether the given slice use can be promoted to a vector.
+///
+/// This function is called to test each entry in a partition which is slated
+/// for a single slice.
+static bool isVectorPromotionViableForSlice(Partition &P, const Slice &S,
+ VectorType *Ty,
+ uint64_t ElementSize,
+ const DataLayout &DL) {
+ // First validate the slice offsets.
+ uint64_t BeginOffset =
+ std::max(S.beginOffset(), P.beginOffset()) - P.beginOffset();
+ uint64_t BeginIndex = BeginOffset / ElementSize;
+ if (BeginIndex * ElementSize != BeginOffset ||
+ BeginIndex >= Ty->getNumElements())
+ return false;
+ uint64_t EndOffset =
+ std::min(S.endOffset(), P.endOffset()) - P.beginOffset();
+ uint64_t EndIndex = EndOffset / ElementSize;
+ if (EndIndex * ElementSize != EndOffset || EndIndex > Ty->getNumElements())
+ return false;
+
+ assert(EndIndex > BeginIndex && "Empty vector!");
+ uint64_t NumElements = EndIndex - BeginIndex;
+ Type *SliceTy = (NumElements == 1)
+ ? Ty->getElementType()
+ : VectorType::get(Ty->getElementType(), NumElements);
+
+ Type *SplitIntTy =
+ Type::getIntNTy(Ty->getContext(), NumElements * ElementSize * 8);
+
+ Use *U = S.getUse();
+
+ if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(U->getUser())) {
+ if (MI->isVolatile())
+ return false;
+ if (!S.isSplittable())
+ return false; // Skip any unsplittable intrinsics.
+ } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U->getUser())) {
+ if (II->getIntrinsicID() != Intrinsic::lifetime_start &&
+ II->getIntrinsicID() != Intrinsic::lifetime_end)
+ return false;
+ } else if (U->get()->getType()->getPointerElementType()->isStructTy()) {
+ // Disable vector promotion when there are loads or stores of an FCA.
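+ // (An FCA is a first-class aggregate value, e.g. a load or store of
+ // type { i32, i32 } through a pointer into this alloca; rewriting such
+ // an access as a whole-vector operation isn't supported.)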
+    return false;
+  } else if (LoadInst *LI = dyn_cast<LoadInst>(U->getUser())) {
+    if (LI->isVolatile())
+      return false;
+    Type *LTy = LI->getType();
+    if (P.beginOffset() > S.beginOffset() || P.endOffset() < S.endOffset()) {
+      assert(LTy->isIntegerTy());
+      LTy = SplitIntTy;
+    }
+    if (!canConvertValue(DL, SliceTy, LTy))
+      return false;
+  } else if (StoreInst *SI = dyn_cast<StoreInst>(U->getUser())) {
+    if (SI->isVolatile())
+      return false;
+    Type *STy = SI->getValueOperand()->getType();
+    if (P.beginOffset() > S.beginOffset() || P.endOffset() < S.endOffset()) {
+      assert(STy->isIntegerTy());
+      STy = SplitIntTy;
+    }
+    if (!canConvertValue(DL, STy, SliceTy))
+      return false;
+  } else {
+    return false;
+  }
+
+  return true;
+}
+
+/// \brief Test whether the given alloca partitioning and range of slices can
+/// be promoted to a vector.
+///
+/// This is a quick test to check whether we can rewrite a particular alloca
+/// partition (and its newly formed alloca) into a vector alloca with only
+/// whole-vector loads and stores such that it could be promoted to a vector
+/// SSA value. We only can ensure this for a limited set of operations, and we
+/// don't want to do the rewrites unless we are confident that the result will
+/// be promotable, so we have an early test here.
+static VectorType *isVectorPromotionViable(Partition &P, const DataLayout &DL) {
+  // Collect the candidate types for vector-based promotion. Also track
+  // whether we have different element types.
+  SmallVector<VectorType *, 4> CandidateTys;
+  Type *CommonEltTy = nullptr;
+  bool HaveCommonEltTy = true;
+  auto CheckCandidateType = [&](Type *Ty) {
+    if (auto *VTy = dyn_cast<VectorType>(Ty)) {
+      CandidateTys.push_back(VTy);
+      if (!CommonEltTy)
+        CommonEltTy = VTy->getElementType();
+      else if (CommonEltTy != VTy->getElementType())
+        HaveCommonEltTy = false;
+    }
+  };
+  // Consider any loads or stores that are the exact size of the slice.
+  for (const Slice &S : P)
+    if (S.beginOffset() == P.beginOffset() &&
+        S.endOffset() == P.endOffset()) {
+      if (auto *LI = dyn_cast<LoadInst>(S.getUse()->getUser()))
+        CheckCandidateType(LI->getType());
+      else if (auto *SI = dyn_cast<StoreInst>(S.getUse()->getUser()))
+        CheckCandidateType(SI->getValueOperand()->getType());
+    }
+
+  // If we didn't find a vector type, nothing to do here.
+  if (CandidateTys.empty())
+    return nullptr;
+
+  // Remove non-integer vector types if the candidates did not share a common
+  // element type.
+  // FIXME: It'd be nice to replace them with integer vector types, but we
+  // can't do that until all the backends are known to produce good code for
+  // all integer vector types.
+  if (!HaveCommonEltTy) {
+    CandidateTys.erase(
+        llvm::remove_if(CandidateTys,
+                        [](VectorType *VTy) {
+                          return !VTy->getElementType()->isIntegerTy();
+                        }),
+        CandidateTys.end());
+
+    // If there were no integer vector types, give up.
+    if (CandidateTys.empty())
+      return nullptr;
+
+    // Rank the remaining candidate vector types. This is easy because we know
+    // they're all integer vectors. We sort by ascending number of elements.
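+    // (Illustrative: for a 128-bit partition this orders <2 x i64> ahead of
+    // <4 x i32>, so the candidate with the widest elements is tried first in
+    // the promotion loop below.)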
+    auto RankVectorTypes = [&DL](VectorType *RHSTy, VectorType *LHSTy) {
+      (void)DL;
+      assert(DL.getTypeSizeInBits(RHSTy) == DL.getTypeSizeInBits(LHSTy) &&
+             "Cannot have vector types of different sizes!");
+      assert(RHSTy->getElementType()->isIntegerTy() &&
+             "All non-integer types eliminated!");
+      assert(LHSTy->getElementType()->isIntegerTy() &&
+             "All non-integer types eliminated!");
+      return RHSTy->getNumElements() < LHSTy->getNumElements();
+    };
+    std::sort(CandidateTys.begin(), CandidateTys.end(), RankVectorTypes);
+    CandidateTys.erase(
+        std::unique(CandidateTys.begin(), CandidateTys.end(), RankVectorTypes),
+        CandidateTys.end());
+  } else {
+// The only way to have the same element type in every vector type is to
+// have the same vector type. Check that and remove all but one.
+#ifndef NDEBUG
+    for (VectorType *VTy : CandidateTys) {
+      assert(VTy->getElementType() == CommonEltTy &&
+             "Unaccounted for element type!");
+      assert(VTy == CandidateTys[0] &&
+             "Different vector types with the same element type!");
+    }
+#endif
+    CandidateTys.resize(1);
+  }
+
+  // Try each vector type, and return the one which works.
+  auto CheckVectorTypeForPromotion = [&](VectorType *VTy) {
+    uint64_t ElementSize = DL.getTypeSizeInBits(VTy->getElementType());
+
+    // While LLVM vectors are defined as bit-packed, we don't support element
+    // sizes that aren't byte sized.
+    if (ElementSize % 8)
+      return false;
+    assert((DL.getTypeSizeInBits(VTy) % 8) == 0 &&
+           "vector size not a multiple of element size?");
+    ElementSize /= 8;
+
+    for (const Slice &S : P)
+      if (!isVectorPromotionViableForSlice(P, S, VTy, ElementSize, DL))
+        return false;
+
+    for (const Slice *S : P.splitSliceTails())
+      if (!isVectorPromotionViableForSlice(P, *S, VTy, ElementSize, DL))
+        return false;
+
+    return true;
+  };
+  for (VectorType *VTy : CandidateTys)
+    if (CheckVectorTypeForPromotion(VTy))
+      return VTy;
+
+  return nullptr;
+}
+
+/// \brief Test whether a slice of an alloca is valid for integer widening.
+///
+/// This implements the necessary checking for the \c isIntegerWideningViable
+/// test below on a single slice of the alloca.
+static bool isIntegerWideningViableForSlice(const Slice &S,
+                                            uint64_t AllocBeginOffset,
+                                            Type *AllocaTy,
+                                            const DataLayout &DL,
+                                            bool &WholeAllocaOp) {
+  uint64_t Size = DL.getTypeStoreSize(AllocaTy);
+
+  uint64_t RelBegin = S.beginOffset() - AllocBeginOffset;
+  uint64_t RelEnd = S.endOffset() - AllocBeginOffset;
+
+  // We can't reasonably handle cases where the load or store extends past
+  // the end of the alloca's type and into its padding.
+  if (RelEnd > Size)
+    return false;
+
+  Use *U = S.getUse();
+
+  if (LoadInst *LI = dyn_cast<LoadInst>(U->getUser())) {
+    if (LI->isVolatile())
+      return false;
+    // We can't handle loads that extend past the allocated memory.
+    if (DL.getTypeStoreSize(LI->getType()) > Size)
+      return false;
+    // Note that we don't count vector loads or stores as whole-alloca
+    // operations which enable integer widening because we would prefer to use
+    // vector widening instead.
+    if (!isa<VectorType>(LI->getType()) && RelBegin == 0 && RelEnd == Size)
+      WholeAllocaOp = true;
+    if (IntegerType *ITy = dyn_cast<IntegerType>(LI->getType())) {
+      if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy))
+        return false;
+    } else if (RelBegin != 0 || RelEnd != Size ||
+               !canConvertValue(DL, AllocaTy, LI->getType())) {
+      // Non-integer loads need to be convertible from the alloca type so that
+      // they are promotable.
+ return false; + } + } else if (StoreInst *SI = dyn_cast<StoreInst>(U->getUser())) { + Type *ValueTy = SI->getValueOperand()->getType(); + if (SI->isVolatile()) + return false; + // We can't handle stores that extend past the allocated memory. + if (DL.getTypeStoreSize(ValueTy) > Size) + return false; + // Note that we don't count vector loads or stores as whole-alloca + // operations which enable integer widening because we would prefer to use + // vector widening instead. + if (!isa<VectorType>(ValueTy) && RelBegin == 0 && RelEnd == Size) + WholeAllocaOp = true; + if (IntegerType *ITy = dyn_cast<IntegerType>(ValueTy)) { + if (ITy->getBitWidth() < DL.getTypeStoreSizeInBits(ITy)) + return false; + } else if (RelBegin != 0 || RelEnd != Size || + !canConvertValue(DL, ValueTy, AllocaTy)) { + // Non-integer stores need to be convertible to the alloca type so that + // they are promotable. + return false; + } + } else if (MemIntrinsic *MI = dyn_cast<MemIntrinsic>(U->getUser())) { + if (MI->isVolatile() || !isa<Constant>(MI->getLength())) + return false; + if (!S.isSplittable()) + return false; // Skip any unsplittable intrinsics. + } else if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U->getUser())) { + if (II->getIntrinsicID() != Intrinsic::lifetime_start && + II->getIntrinsicID() != Intrinsic::lifetime_end) + return false; + } else { + return false; + } + + return true; +} + +/// \brief Test whether the given alloca partition's integer operations can be +/// widened to promotable ones. +/// +/// This is a quick test to check whether we can rewrite the integer loads and +/// stores to a particular alloca into wider loads and stores and be able to +/// promote the resulting alloca. +static bool isIntegerWideningViable(Partition &P, Type *AllocaTy, + const DataLayout &DL) { + uint64_t SizeInBits = DL.getTypeSizeInBits(AllocaTy); + // Don't create integer types larger than the maximum bitwidth. + if (SizeInBits > IntegerType::MAX_INT_BITS) + return false; + + // Don't try to handle allocas with bit-padding. + if (SizeInBits != DL.getTypeStoreSizeInBits(AllocaTy)) + return false; + + // We need to ensure that an integer type with the appropriate bitwidth can + // be converted to the alloca type, whatever that is. We don't want to force + // the alloca itself to have an integer type if there is a more suitable one. + Type *IntTy = Type::getIntNTy(AllocaTy->getContext(), SizeInBits); + if (!canConvertValue(DL, AllocaTy, IntTy) || + !canConvertValue(DL, IntTy, AllocaTy)) + return false; + + // While examining uses, we ensure that the alloca has a covering load or + // store. We don't want to widen the integer operations only to fail to + // promote due to some other unsplittable entry (which we may make splittable + // later). However, if there are only splittable uses, go ahead and assume + // that we cover the alloca. + // FIXME: We shouldn't consider split slices that happen to start in the + // partition here... + bool WholeAllocaOp = + P.begin() != P.end() ? 
false : DL.isLegalInteger(SizeInBits); + + for (const Slice &S : P) + if (!isIntegerWideningViableForSlice(S, P.beginOffset(), AllocaTy, DL, + WholeAllocaOp)) + return false; + + for (const Slice *S : P.splitSliceTails()) + if (!isIntegerWideningViableForSlice(*S, P.beginOffset(), AllocaTy, DL, + WholeAllocaOp)) + return false; + + return WholeAllocaOp; +} + +static Value *extractInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *V, + IntegerType *Ty, uint64_t Offset, + const Twine &Name) { + DEBUG(dbgs() << " start: " << *V << "\n"); + IntegerType *IntTy = cast<IntegerType>(V->getType()); + assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) && + "Element extends past full value"); + uint64_t ShAmt = 8 * Offset; + if (DL.isBigEndian()) + ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset); + if (ShAmt) { + V = IRB.CreateLShr(V, ShAmt, Name + ".shift"); + DEBUG(dbgs() << " shifted: " << *V << "\n"); + } + assert(Ty->getBitWidth() <= IntTy->getBitWidth() && + "Cannot extract to a larger integer!"); + if (Ty != IntTy) { + V = IRB.CreateTrunc(V, Ty, Name + ".trunc"); + DEBUG(dbgs() << " trunced: " << *V << "\n"); + } + return V; +} + +static Value *insertInteger(const DataLayout &DL, IRBuilderTy &IRB, Value *Old, + Value *V, uint64_t Offset, const Twine &Name) { + IntegerType *IntTy = cast<IntegerType>(Old->getType()); + IntegerType *Ty = cast<IntegerType>(V->getType()); + assert(Ty->getBitWidth() <= IntTy->getBitWidth() && + "Cannot insert a larger integer!"); + DEBUG(dbgs() << " start: " << *V << "\n"); + if (Ty != IntTy) { + V = IRB.CreateZExt(V, IntTy, Name + ".ext"); + DEBUG(dbgs() << " extended: " << *V << "\n"); + } + assert(DL.getTypeStoreSize(Ty) + Offset <= DL.getTypeStoreSize(IntTy) && + "Element store outside of alloca store"); + uint64_t ShAmt = 8 * Offset; + if (DL.isBigEndian()) + ShAmt = 8 * (DL.getTypeStoreSize(IntTy) - DL.getTypeStoreSize(Ty) - Offset); + if (ShAmt) { + V = IRB.CreateShl(V, ShAmt, Name + ".shift"); + DEBUG(dbgs() << " shifted: " << *V << "\n"); + } + + if (ShAmt || Ty->getBitWidth() < IntTy->getBitWidth()) { + APInt Mask = ~Ty->getMask().zext(IntTy->getBitWidth()).shl(ShAmt); + Old = IRB.CreateAnd(Old, Mask, Name + ".mask"); + DEBUG(dbgs() << " masked: " << *Old << "\n"); + V = IRB.CreateOr(Old, V, Name + ".insert"); + DEBUG(dbgs() << " inserted: " << *V << "\n"); + } + return V; +} + +static Value *extractVector(IRBuilderTy &IRB, Value *V, unsigned BeginIndex, + unsigned EndIndex, const Twine &Name) { + VectorType *VecTy = cast<VectorType>(V->getType()); + unsigned NumElements = EndIndex - BeginIndex; + assert(NumElements <= VecTy->getNumElements() && "Too many elements!"); + + if (NumElements == VecTy->getNumElements()) + return V; + + if (NumElements == 1) { + V = IRB.CreateExtractElement(V, IRB.getInt32(BeginIndex), + Name + ".extract"); + DEBUG(dbgs() << " extract: " << *V << "\n"); + return V; + } + + SmallVector<Constant *, 8> Mask; + Mask.reserve(NumElements); + for (unsigned i = BeginIndex; i != EndIndex; ++i) + Mask.push_back(IRB.getInt32(i)); + V = IRB.CreateShuffleVector(V, UndefValue::get(V->getType()), + ConstantVector::get(Mask), Name + ".extract"); + DEBUG(dbgs() << " shuffle: " << *V << "\n"); + return V; +} + +static Value *insertVector(IRBuilderTy &IRB, Value *Old, Value *V, + unsigned BeginIndex, const Twine &Name) { + VectorType *VecTy = cast<VectorType>(Old->getType()); + assert(VecTy && "Can only insert a vector into a vector"); + + VectorType *Ty = dyn_cast<VectorType>(V->getType()); + if 
(!Ty) {
+    // Single element to insert.
+    V = IRB.CreateInsertElement(Old, V, IRB.getInt32(BeginIndex),
+                                Name + ".insert");
+    DEBUG(dbgs() << " insert: " << *V << "\n");
+    return V;
+  }
+
+  assert(Ty->getNumElements() <= VecTy->getNumElements() &&
+         "Too many elements!");
+  if (Ty->getNumElements() == VecTy->getNumElements()) {
+    assert(V->getType() == VecTy && "Vector type mismatch");
+    return V;
+  }
+  unsigned EndIndex = BeginIndex + Ty->getNumElements();
+
+  // When inserting a smaller vector into the larger to store, we first
+  // use a shuffle vector to widen it with undef elements, and then
+  // a second shuffle vector to select between the loaded vector and the
+  // incoming vector.
+  SmallVector<Constant *, 8> Mask;
+  Mask.reserve(VecTy->getNumElements());
+  for (unsigned i = 0; i != VecTy->getNumElements(); ++i)
+    if (i >= BeginIndex && i < EndIndex)
+      Mask.push_back(IRB.getInt32(i - BeginIndex));
+    else
+      Mask.push_back(UndefValue::get(IRB.getInt32Ty()));
+  V = IRB.CreateShuffleVector(V, UndefValue::get(V->getType()),
+                              ConstantVector::get(Mask), Name + ".expand");
+  DEBUG(dbgs() << " shuffle: " << *V << "\n");
+
+  Mask.clear();
+  for (unsigned i = 0; i != VecTy->getNumElements(); ++i)
+    Mask.push_back(IRB.getInt1(i >= BeginIndex && i < EndIndex));
+
+  V = IRB.CreateSelect(ConstantVector::get(Mask), V, Old, Name + "blend");
+
+  DEBUG(dbgs() << " blend: " << *V << "\n");
+  return V;
+}
+
+/// \brief Visitor to rewrite instructions using a particular slice of an
+/// alloca to use a new alloca.
+///
+/// Also implements the rewriting to vector-based accesses when the partition
+/// passes the isVectorPromotionViable predicate. Most of the rewriting logic
+/// lives here.
+class llvm::sroa::AllocaSliceRewriter
+    : public InstVisitor<AllocaSliceRewriter, bool> {
+  // Befriend the base class so it can delegate to private visit methods.
+  friend class InstVisitor<AllocaSliceRewriter, bool>;
+
+  using Base = InstVisitor<AllocaSliceRewriter, bool>;
+
+  const DataLayout &DL;
+  AllocaSlices &AS;
+  SROA &Pass;
+  AllocaInst &OldAI, &NewAI;
+  const uint64_t NewAllocaBeginOffset, NewAllocaEndOffset;
+  Type *NewAllocaTy;
+
+  // This is a convenience and flag variable that will be null unless the new
+  // alloca's integer operations should be widened to this integer type due to
+  // passing isIntegerWideningViable above. If it is non-null, the desired
+  // integer type will be stored here for easy access during rewriting.
+  IntegerType *IntTy;
+
+  // If we are rewriting an alloca partition which can be written as pure
+  // vector operations, we stash extra information here. When VecTy is
+  // non-null, we have some strict guarantees about the rewritten alloca:
+  //   - The new alloca is exactly the size of the vector type here.
+  //   - The accesses all either map to the entire vector or to a single
+  //     element.
+  //   - The set of accessing instructions is only one of those handled above
+  //     in isVectorPromotionViable. Generally these are the same access kinds
+  //     which are promotable via mem2reg.
+  VectorType *VecTy;
+  Type *ElementTy;
+  uint64_t ElementSize;
+
+  // The original offset of the slice currently being rewritten relative to
+  // the original alloca.
+  uint64_t BeginOffset = 0;
+  uint64_t EndOffset = 0;
+
+  // The new offsets of the slice currently being rewritten relative to the
+  // original alloca.
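+  // (These are the original offsets clamped to the new alloca's range; they
+  // are recomputed for each slice in visit() below.)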
+  uint64_t NewBeginOffset, NewEndOffset;
+
+  uint64_t SliceSize;
+  bool IsSplittable = false;
+  bool IsSplit = false;
+  Use *OldUse = nullptr;
+  Instruction *OldPtr = nullptr;
+
+  // Track post-rewrite users which are PHI nodes and Selects.
+  SmallSetVector<PHINode *, 8> &PHIUsers;
+  SmallSetVector<SelectInst *, 8> &SelectUsers;
+
+  // Utility IR builder, whose name prefix is set up for each visited use, and
+  // the insertion point is set to point to the user.
+  IRBuilderTy IRB;
+
+public:
+  AllocaSliceRewriter(const DataLayout &DL, AllocaSlices &AS, SROA &Pass,
+                      AllocaInst &OldAI, AllocaInst &NewAI,
+                      uint64_t NewAllocaBeginOffset,
+                      uint64_t NewAllocaEndOffset, bool IsIntegerPromotable,
+                      VectorType *PromotableVecTy,
+                      SmallSetVector<PHINode *, 8> &PHIUsers,
+                      SmallSetVector<SelectInst *, 8> &SelectUsers)
+      : DL(DL), AS(AS), Pass(Pass), OldAI(OldAI), NewAI(NewAI),
+        NewAllocaBeginOffset(NewAllocaBeginOffset),
+        NewAllocaEndOffset(NewAllocaEndOffset),
+        NewAllocaTy(NewAI.getAllocatedType()),
+        IntTy(IsIntegerPromotable
+                  ? Type::getIntNTy(
+                        NewAI.getContext(),
+                        DL.getTypeSizeInBits(NewAI.getAllocatedType()))
+                  : nullptr),
+        VecTy(PromotableVecTy),
+        ElementTy(VecTy ? VecTy->getElementType() : nullptr),
+        ElementSize(VecTy ? DL.getTypeSizeInBits(ElementTy) / 8 : 0),
+        PHIUsers(PHIUsers), SelectUsers(SelectUsers),
+        IRB(NewAI.getContext(), ConstantFolder()) {
+    if (VecTy) {
+      assert((DL.getTypeSizeInBits(ElementTy) % 8) == 0 &&
+             "Only multiple-of-8 sized vector elements are viable");
+      ++NumVectorized;
+    }
+    assert((!IntTy && !VecTy) || (IntTy && !VecTy) || (!IntTy && VecTy));
+  }
+
+  bool visit(AllocaSlices::const_iterator I) {
+    bool CanSROA = true;
+    BeginOffset = I->beginOffset();
+    EndOffset = I->endOffset();
+    IsSplittable = I->isSplittable();
+    IsSplit =
+        BeginOffset < NewAllocaBeginOffset || EndOffset > NewAllocaEndOffset;
+    DEBUG(dbgs() << " rewriting " << (IsSplit ? "split " : ""));
+    DEBUG(AS.printSlice(dbgs(), I, ""));
+    DEBUG(dbgs() << "\n");
+
+    // Compute the intersecting offset range.
+    assert(BeginOffset < NewAllocaEndOffset);
+    assert(EndOffset > NewAllocaBeginOffset);
+    NewBeginOffset = std::max(BeginOffset, NewAllocaBeginOffset);
+    NewEndOffset = std::min(EndOffset, NewAllocaEndOffset);
+
+    SliceSize = NewEndOffset - NewBeginOffset;
+
+    OldUse = I->getUse();
+    OldPtr = cast<Instruction>(OldUse->get());
+
+    Instruction *OldUserI = cast<Instruction>(OldUse->getUser());
+    IRB.SetInsertPoint(OldUserI);
+    IRB.SetCurrentDebugLocation(OldUserI->getDebugLoc());
+    IRB.SetNamePrefix(Twine(NewAI.getName()) + "." + Twine(BeginOffset) + ".");
+
+    CanSROA &= visit(cast<Instruction>(OldUse->getUser()));
+    if (VecTy || IntTy)
+      assert(CanSROA);
+    return CanSROA;
+  }
+
+private:
+  // Make sure the other visit overloads are visible.
+  using Base::visit;
+
+  // Every instruction which can end up as a user must have a rewrite rule.
+  bool visitInstruction(Instruction &I) {
+    DEBUG(dbgs() << " !!!! Cannot rewrite: " << I << "\n");
+    llvm_unreachable("No rewrite rule for this instruction!");
+  }
+
+  Value *getNewAllocaSlicePtr(IRBuilderTy &IRB, Type *PointerTy) {
+    // Note that the offset computation can use BeginOffset or NewBeginOffset
+    // interchangeably for unsplit slices.
+    assert(IsSplit || BeginOffset == NewBeginOffset);
+    uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;
+
+#ifndef NDEBUG
+    StringRef OldName = OldPtr->getName();
+    // Skip through the last '.sroa.' component of the name.
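+    // (Illustrative: a pointer named "x.sroa.3.16.copyload" is reduced to
+    // "copyload" by the stripping below.)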
+ size_t LastSROAPrefix = OldName.rfind(".sroa."); + if (LastSROAPrefix != StringRef::npos) { + OldName = OldName.substr(LastSROAPrefix + strlen(".sroa.")); + // Look for an SROA slice index. + size_t IndexEnd = OldName.find_first_not_of("0123456789"); + if (IndexEnd != StringRef::npos && OldName[IndexEnd] == '.') { + // Strip the index and look for the offset. + OldName = OldName.substr(IndexEnd + 1); + size_t OffsetEnd = OldName.find_first_not_of("0123456789"); + if (OffsetEnd != StringRef::npos && OldName[OffsetEnd] == '.') + // Strip the offset. + OldName = OldName.substr(OffsetEnd + 1); + } + } + // Strip any SROA suffixes as well. + OldName = OldName.substr(0, OldName.find(".sroa_")); +#endif + + return getAdjustedPtr(IRB, DL, &NewAI, + APInt(DL.getPointerTypeSizeInBits(PointerTy), Offset), + PointerTy, +#ifndef NDEBUG + Twine(OldName) + "." +#else + Twine() +#endif + ); + } + + /// \brief Compute suitable alignment to access this slice of the *new* + /// alloca. + /// + /// You can optionally pass a type to this routine and if that type's ABI + /// alignment is itself suitable, this will return zero. + unsigned getSliceAlign(Type *Ty = nullptr) { + unsigned NewAIAlign = NewAI.getAlignment(); + if (!NewAIAlign) + NewAIAlign = DL.getABITypeAlignment(NewAI.getAllocatedType()); + unsigned Align = + MinAlign(NewAIAlign, NewBeginOffset - NewAllocaBeginOffset); + return (Ty && Align == DL.getABITypeAlignment(Ty)) ? 0 : Align; + } + + unsigned getIndex(uint64_t Offset) { + assert(VecTy && "Can only call getIndex when rewriting a vector"); + uint64_t RelOffset = Offset - NewAllocaBeginOffset; + assert(RelOffset / ElementSize < UINT32_MAX && "Index out of bounds"); + uint32_t Index = RelOffset / ElementSize; + assert(Index * ElementSize == RelOffset); + return Index; + } + + void deleteIfTriviallyDead(Value *V) { + Instruction *I = cast<Instruction>(V); + if (isInstructionTriviallyDead(I)) + Pass.DeadInsts.insert(I); + } + + Value *rewriteVectorizedLoadInst() { + unsigned BeginIndex = getIndex(NewBeginOffset); + unsigned EndIndex = getIndex(NewEndOffset); + assert(EndIndex > BeginIndex && "Empty vector!"); + + Value *V = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); + return extractVector(IRB, V, BeginIndex, EndIndex, "vec"); + } + + Value *rewriteIntegerLoad(LoadInst &LI) { + assert(IntTy && "We cannot insert an integer to the alloca"); + assert(!LI.isVolatile()); + Value *V = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); + V = convertValue(DL, IRB, V, IntTy); + assert(NewBeginOffset >= NewAllocaBeginOffset && "Out of bounds offset"); + uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; + if (Offset > 0 || NewEndOffset < NewAllocaEndOffset) { + IntegerType *ExtractTy = Type::getIntNTy(LI.getContext(), SliceSize * 8); + V = extractInteger(DL, IRB, V, ExtractTy, Offset, "extract"); + } + // It is possible that the extracted type is not the load type. This + // happens if there is a load past the end of the alloca, and as + // a consequence the slice is narrower but still a candidate for integer + // lowering. To handle this case, we just zero extend the extracted + // integer. 
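+    // (Illustrative: an i32 load of which only two bytes lie inside the
+    // slice extracts an i16 here and zero extends it back to i32.)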
+    assert(cast<IntegerType>(LI.getType())->getBitWidth() >= SliceSize * 8 &&
+           "Can only handle an extract for an overly wide load");
+    if (cast<IntegerType>(LI.getType())->getBitWidth() > SliceSize * 8)
+      V = IRB.CreateZExt(V, LI.getType());
+    return V;
+  }
+
+  bool visitLoadInst(LoadInst &LI) {
+    DEBUG(dbgs() << " original: " << LI << "\n");
+    Value *OldOp = LI.getOperand(0);
+    assert(OldOp == OldPtr);
+
+    unsigned AS = LI.getPointerAddressSpace();
+
+    Type *TargetTy = IsSplit ? Type::getIntNTy(LI.getContext(), SliceSize * 8)
+                             : LI.getType();
+    const bool IsLoadPastEnd = DL.getTypeStoreSize(TargetTy) > SliceSize;
+    bool IsPtrAdjusted = false;
+    Value *V;
+    if (VecTy) {
+      V = rewriteVectorizedLoadInst();
+    } else if (IntTy && LI.getType()->isIntegerTy()) {
+      V = rewriteIntegerLoad(LI);
+    } else if (NewBeginOffset == NewAllocaBeginOffset &&
+               NewEndOffset == NewAllocaEndOffset &&
+               (canConvertValue(DL, NewAllocaTy, TargetTy) ||
+                (IsLoadPastEnd && NewAllocaTy->isIntegerTy() &&
+                 TargetTy->isIntegerTy()))) {
+      LoadInst *NewLI = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(),
+                                              LI.isVolatile(), LI.getName());
+      if (LI.isVolatile())
+        NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID());
+
+      // Any !nonnull metadata or !range metadata on the old load is also
+      // valid on the new load. This is true in some cases even when the loads
+      // are different types, for example by mapping !nonnull metadata to
+      // !range metadata by modeling the null pointer constant converted to
+      // the integer type.
+      // FIXME: Add support for range metadata here. Currently the utilities
+      // for this don't propagate range metadata in trivial cases from one
+      // integer load to another, don't handle non-addrspace-0 null pointers
+      // correctly, and don't have any support for mapping ranges as the
+      // integer type becomes wider or narrower.
+      if (MDNode *N = LI.getMetadata(LLVMContext::MD_nonnull))
+        copyNonnullMetadata(LI, N, *NewLI);
+
+      V = NewLI;
+
+      // If this is an integer load past the end of the slice (which means the
+      // bytes outside the slice are undef or this load is dead) just forcibly
+      // fix the integer size with correct handling of endianness.
+      if (auto *AITy = dyn_cast<IntegerType>(NewAllocaTy))
+        if (auto *TITy = dyn_cast<IntegerType>(TargetTy))
+          if (AITy->getBitWidth() < TITy->getBitWidth()) {
+            V = IRB.CreateZExt(V, TITy, "load.ext");
+            if (DL.isBigEndian())
+              V = IRB.CreateShl(V, TITy->getBitWidth() - AITy->getBitWidth(),
+                                "endian_shift");
+          }
+    } else {
+      Type *LTy = TargetTy->getPointerTo(AS);
+      LoadInst *NewLI = IRB.CreateAlignedLoad(getNewAllocaSlicePtr(IRB, LTy),
+                                              getSliceAlign(TargetTy),
+                                              LI.isVolatile(), LI.getName());
+      if (LI.isVolatile())
+        NewLI->setAtomic(LI.getOrdering(), LI.getSyncScopeID());
+
+      V = NewLI;
+      IsPtrAdjusted = true;
+    }
+    V = convertValue(DL, IRB, V, TargetTy);
+
+    if (IsSplit) {
+      assert(!LI.isVolatile());
+      assert(LI.getType()->isIntegerTy() &&
+             "Only integer type loads and stores are split");
+      assert(SliceSize < DL.getTypeStoreSize(LI.getType()) &&
+             "Split load isn't smaller than original load");
+      assert(LI.getType()->getIntegerBitWidth() ==
+                 DL.getTypeStoreSizeInBits(LI.getType()) &&
+             "Non-byte-multiple bit width");
+      // Move the insertion point just past the load so that we can refer
+      // to it.
+      IRB.SetInsertPoint(&*std::next(BasicBlock::iterator(&LI)));
+      // Create a placeholder value with the same type as LI to use as the
+      // basis for the new value.
This allows us to replace the uses of LI with + // the computed value, and then replace the placeholder with LI, leaving + // LI only used for this computation. + Value *Placeholder = + new LoadInst(UndefValue::get(LI.getType()->getPointerTo(AS))); + V = insertInteger(DL, IRB, Placeholder, V, NewBeginOffset - BeginOffset, + "insert"); + LI.replaceAllUsesWith(V); + Placeholder->replaceAllUsesWith(&LI); + Placeholder->deleteValue(); + } else { + LI.replaceAllUsesWith(V); + } + + Pass.DeadInsts.insert(&LI); + deleteIfTriviallyDead(OldOp); + DEBUG(dbgs() << " to: " << *V << "\n"); + return !LI.isVolatile() && !IsPtrAdjusted; + } + + bool rewriteVectorizedStoreInst(Value *V, StoreInst &SI, Value *OldOp) { + if (V->getType() != VecTy) { + unsigned BeginIndex = getIndex(NewBeginOffset); + unsigned EndIndex = getIndex(NewEndOffset); + assert(EndIndex > BeginIndex && "Empty vector!"); + unsigned NumElements = EndIndex - BeginIndex; + assert(NumElements <= VecTy->getNumElements() && "Too many elements!"); + Type *SliceTy = (NumElements == 1) + ? ElementTy + : VectorType::get(ElementTy, NumElements); + if (V->getType() != SliceTy) + V = convertValue(DL, IRB, V, SliceTy); + + // Mix in the existing elements. + Value *Old = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); + V = insertVector(IRB, Old, V, BeginIndex, "vec"); + } + StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment()); + Pass.DeadInsts.insert(&SI); + + (void)Store; + DEBUG(dbgs() << " to: " << *Store << "\n"); + return true; + } + + bool rewriteIntegerStore(Value *V, StoreInst &SI) { + assert(IntTy && "We cannot extract an integer from the alloca"); + assert(!SI.isVolatile()); + if (DL.getTypeSizeInBits(V->getType()) != IntTy->getBitWidth()) { + Value *Old = + IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload"); + Old = convertValue(DL, IRB, Old, IntTy); + assert(BeginOffset >= NewAllocaBeginOffset && "Out of bounds offset"); + uint64_t Offset = BeginOffset - NewAllocaBeginOffset; + V = insertInteger(DL, IRB, Old, SI.getValueOperand(), Offset, "insert"); + } + V = convertValue(DL, IRB, V, NewAllocaTy); + StoreInst *Store = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment()); + Store->copyMetadata(SI, LLVMContext::MD_mem_parallel_loop_access); + Pass.DeadInsts.insert(&SI); + DEBUG(dbgs() << " to: " << *Store << "\n"); + return true; + } + + bool visitStoreInst(StoreInst &SI) { + DEBUG(dbgs() << " original: " << SI << "\n"); + Value *OldOp = SI.getOperand(1); + assert(OldOp == OldPtr); + + Value *V = SI.getValueOperand(); + + // Strip all inbounds GEPs and pointer casts to try to dig out any root + // alloca that should be re-examined after promoting this alloca. 
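+    // (Illustrative: a store whose value operand is a pointer derived from
+    // another alloca queues that alloca for re-examination once this one has
+    // been promoted.)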
+    if (V->getType()->isPointerTy())
+      if (AllocaInst *AI = dyn_cast<AllocaInst>(V->stripInBoundsOffsets()))
+        Pass.PostPromotionWorklist.insert(AI);
+
+    if (SliceSize < DL.getTypeStoreSize(V->getType())) {
+      assert(!SI.isVolatile());
+      assert(V->getType()->isIntegerTy() &&
+             "Only integer type loads and stores are split");
+      assert(V->getType()->getIntegerBitWidth() ==
+                 DL.getTypeStoreSizeInBits(V->getType()) &&
+             "Non-byte-multiple bit width");
+      IntegerType *NarrowTy = Type::getIntNTy(SI.getContext(), SliceSize * 8);
+      V = extractInteger(DL, IRB, V, NarrowTy, NewBeginOffset - BeginOffset,
+                         "extract");
+    }
+
+    if (VecTy)
+      return rewriteVectorizedStoreInst(V, SI, OldOp);
+    if (IntTy && V->getType()->isIntegerTy())
+      return rewriteIntegerStore(V, SI);
+
+    const bool IsStorePastEnd = DL.getTypeStoreSize(V->getType()) > SliceSize;
+    StoreInst *NewSI;
+    if (NewBeginOffset == NewAllocaBeginOffset &&
+        NewEndOffset == NewAllocaEndOffset &&
+        (canConvertValue(DL, V->getType(), NewAllocaTy) ||
+         (IsStorePastEnd && NewAllocaTy->isIntegerTy() &&
+          V->getType()->isIntegerTy()))) {
+      // If this is an integer store past the end of the slice (and thus the
+      // bytes past that point are irrelevant or this is unreachable), truncate
+      // the value prior to storing.
+      if (auto *VITy = dyn_cast<IntegerType>(V->getType()))
+        if (auto *AITy = dyn_cast<IntegerType>(NewAllocaTy))
+          if (VITy->getBitWidth() > AITy->getBitWidth()) {
+            if (DL.isBigEndian())
+              V = IRB.CreateLShr(V, VITy->getBitWidth() - AITy->getBitWidth(),
+                                 "endian_shift");
+            V = IRB.CreateTrunc(V, AITy, "load.trunc");
+          }
+
+      V = convertValue(DL, IRB, V, NewAllocaTy);
+      NewSI = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(),
+                                     SI.isVolatile());
+    } else {
+      unsigned AS = SI.getPointerAddressSpace();
+      Value *NewPtr = getNewAllocaSlicePtr(IRB, V->getType()->getPointerTo(AS));
+      NewSI = IRB.CreateAlignedStore(V, NewPtr, getSliceAlign(V->getType()),
+                                     SI.isVolatile());
+    }
+    NewSI->copyMetadata(SI, LLVMContext::MD_mem_parallel_loop_access);
+    if (SI.isVolatile())
+      NewSI->setAtomic(SI.getOrdering(), SI.getSyncScopeID());
+    Pass.DeadInsts.insert(&SI);
+    deleteIfTriviallyDead(OldOp);
+
+    DEBUG(dbgs() << " to: " << *NewSI << "\n");
+    return NewSI->getPointerOperand() == &NewAI && !SI.isVolatile();
+  }
+
+  /// \brief Compute an integer value from splatting an i8 across the given
+  /// number of bytes.
+  ///
+  /// Note that this routine assumes an i8 is a byte. If that isn't true, don't
+  /// call this routine.
+  /// FIXME: Heed the advice above.
+  ///
+  /// \param V The i8 value to splat.
+  /// \param Size The number of bytes in the output (assuming i8 is one byte)
+  Value *getIntegerSplat(Value *V, unsigned Size) {
+    assert(Size > 0 && "Expected a positive number of bytes.");
+    IntegerType *VTy = cast<IntegerType>(V->getType());
+    assert(VTy->getBitWidth() == 8 && "Expected an i8 value for the byte");
+    if (Size == 1)
+      return V;
+
+    Type *SplatIntTy = Type::getIntNTy(VTy->getContext(), Size * 8);
+    V = IRB.CreateMul(
+        IRB.CreateZExt(V, SplatIntTy, "zext"),
+        ConstantExpr::getUDiv(
+            Constant::getAllOnesValue(SplatIntTy),
+            ConstantExpr::getZExt(Constant::getAllOnesValue(V->getType()),
+                                  SplatIntTy)),
+        "isplat");
+    return V;
+  }
+
+  /// \brief Compute a vector splat for a given element value.
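+  /// (A thin wrapper over IRBuilder::CreateVectorSplat. Note that the integer
+  /// variant above instead multiplies the zero-extended byte by the constant
+  /// 0x0101...01, computed as allOnes(Size * 8) udiv 0xFF, to replicate it.)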
+  Value *getVectorSplat(Value *V, unsigned NumElements) {
+    V = IRB.CreateVectorSplat(NumElements, V, "vsplat");
+    DEBUG(dbgs() << " splat: " << *V << "\n");
+    return V;
+  }
+
+  bool visitMemSetInst(MemSetInst &II) {
+    DEBUG(dbgs() << " original: " << II << "\n");
+    assert(II.getRawDest() == OldPtr);
+
+    // If the memset has a variable size, it cannot be split, just adjust the
+    // pointer to the new alloca.
+    if (!isa<Constant>(II.getLength())) {
+      assert(!IsSplit);
+      assert(NewBeginOffset == BeginOffset);
+      II.setDest(getNewAllocaSlicePtr(IRB, OldPtr->getType()));
+      Type *CstTy = II.getAlignmentCst()->getType();
+      II.setAlignment(ConstantInt::get(CstTy, getSliceAlign()));
+
+      deleteIfTriviallyDead(OldPtr);
+      return false;
+    }
+
+    // Record this instruction for deletion.
+    Pass.DeadInsts.insert(&II);
+
+    Type *AllocaTy = NewAI.getAllocatedType();
+    Type *ScalarTy = AllocaTy->getScalarType();
+
+    // If this doesn't map cleanly onto the alloca type, and that type isn't
+    // a single value type, just emit a memset.
+    if (!VecTy && !IntTy &&
+        (BeginOffset > NewAllocaBeginOffset || EndOffset < NewAllocaEndOffset ||
+         SliceSize != DL.getTypeStoreSize(AllocaTy) ||
+         !AllocaTy->isSingleValueType() ||
+         !DL.isLegalInteger(DL.getTypeSizeInBits(ScalarTy)) ||
+         DL.getTypeSizeInBits(ScalarTy) % 8 != 0)) {
+      Type *SizeTy = II.getLength()->getType();
+      Constant *Size = ConstantInt::get(SizeTy, NewEndOffset - NewBeginOffset);
+      CallInst *New = IRB.CreateMemSet(
+          getNewAllocaSlicePtr(IRB, OldPtr->getType()), II.getValue(), Size,
+          getSliceAlign(), II.isVolatile());
+      (void)New;
+      DEBUG(dbgs() << " to: " << *New << "\n");
+      return false;
+    }
+
+    // If we can represent this as a simple value, we have to build the actual
+    // value to store, which requires expanding the byte present in memset to
+    // a sensible representation for the alloca type. This is essentially
+    // splatting the byte to a sufficiently wide integer, splatting it across
+    // any desired vector width, and bitcasting to the final type.
+    Value *V;
+
+    if (VecTy) {
+      // If this is a memset of a vectorized alloca, insert it.
+      assert(ElementTy == ScalarTy);
+
+      unsigned BeginIndex = getIndex(NewBeginOffset);
+      unsigned EndIndex = getIndex(NewEndOffset);
+      assert(EndIndex > BeginIndex && "Empty vector!");
+      unsigned NumElements = EndIndex - BeginIndex;
+      assert(NumElements <= VecTy->getNumElements() && "Too many elements!");
+
+      Value *Splat =
+          getIntegerSplat(II.getValue(), DL.getTypeSizeInBits(ElementTy) / 8);
+      Splat = convertValue(DL, IRB, Splat, ElementTy);
+      if (NumElements > 1)
+        Splat = getVectorSplat(Splat, NumElements);
+
+      Value *Old =
+          IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload");
+      V = insertVector(IRB, Old, Splat, BeginIndex, "vec");
+    } else if (IntTy) {
+      // If this is a memset on an alloca where we can widen stores, insert the
+      // set integer.
+      assert(!II.isVolatile());
+
+      uint64_t Size = NewEndOffset - NewBeginOffset;
+      V = getIntegerSplat(II.getValue(), Size);
+
+      if (IntTy && (BeginOffset != NewAllocaBeginOffset ||
+                    EndOffset != NewAllocaEndOffset)) {
+        Value *Old =
+            IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload");
+        Old = convertValue(DL, IRB, Old, IntTy);
+        uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset;
+        V = insertInteger(DL, IRB, Old, V, Offset, "insert");
+      } else {
+        assert(V->getType() == IntTy &&
+               "Wrong type for an alloca wide integer!");
+      }
+      V = convertValue(DL, IRB, V, AllocaTy);
+    } else {
+      // Established these invariants above.
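+      // (That is, the memset covers the new alloca exactly, so we can build a
+      // single alloca-wide value and store it below.)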
+      assert(NewBeginOffset == NewAllocaBeginOffset);
+      assert(NewEndOffset == NewAllocaEndOffset);
+
+      V = getIntegerSplat(II.getValue(), DL.getTypeSizeInBits(ScalarTy) / 8);
+      if (VectorType *AllocaVecTy = dyn_cast<VectorType>(AllocaTy))
+        V = getVectorSplat(V, AllocaVecTy->getNumElements());
+
+      V = convertValue(DL, IRB, V, AllocaTy);
+    }
+
+    Value *New = IRB.CreateAlignedStore(V, &NewAI, NewAI.getAlignment(),
+                                        II.isVolatile());
+    (void)New;
+    DEBUG(dbgs() << " to: " << *New << "\n");
+    return !II.isVolatile();
+  }
+
+  bool visitMemTransferInst(MemTransferInst &II) {
+    // Rewriting of memory transfer instructions can be a bit tricky. We break
+    // them into two categories: split intrinsics and unsplit intrinsics.
+
+    DEBUG(dbgs() << " original: " << II << "\n");
+
+    bool IsDest = &II.getRawDestUse() == OldUse;
+    assert((IsDest && II.getRawDest() == OldPtr) ||
+           (!IsDest && II.getRawSource() == OldPtr));
+
+    unsigned SliceAlign = getSliceAlign();
+
+    // For unsplit intrinsics, we simply modify the source and destination
+    // pointers in place. This isn't just an optimization, it is a matter of
+    // correctness. With unsplit intrinsics we may be dealing with transfers
+    // within a single alloca before SROA ran, or with transfers that have
+    // a variable length. We may also be dealing with memmove instead of
+    // memcpy, and so simply updating the pointers is necessary for us to
+    // update both the source and dest of a single call.
+    if (!IsSplittable) {
+      Value *AdjustedPtr = getNewAllocaSlicePtr(IRB, OldPtr->getType());
+      if (IsDest)
+        II.setDest(AdjustedPtr);
+      else
+        II.setSource(AdjustedPtr);
+
+      if (II.getAlignment() > SliceAlign) {
+        Type *CstTy = II.getAlignmentCst()->getType();
+        II.setAlignment(
+            ConstantInt::get(CstTy, MinAlign(II.getAlignment(), SliceAlign)));
+      }
+
+      DEBUG(dbgs() << " to: " << II << "\n");
+      deleteIfTriviallyDead(OldPtr);
+      return false;
+    }
+    // For split transfer intrinsics we have an incredibly useful assurance:
+    // the source and destination do not reside within the same alloca, and at
+    // least one of them does not escape. This means that we can replace
+    // memmove with memcpy, and we don't need to worry about all manner of
+    // downsides to splitting and transforming the operations.
+
+    // If this doesn't map cleanly onto the alloca type, and that type isn't
+    // a single value type, just emit a memcpy.
+    bool EmitMemCpy =
+        !VecTy && !IntTy &&
+        (BeginOffset > NewAllocaBeginOffset || EndOffset < NewAllocaEndOffset ||
+         SliceSize != DL.getTypeStoreSize(NewAI.getAllocatedType()) ||
+         !NewAI.getAllocatedType()->isSingleValueType());
+
+    // If we're just going to emit a memcpy, the alloca hasn't changed, and the
+    // size hasn't been shrunk based on analysis of the viable range, this is
+    // a no-op.
+    if (EmitMemCpy && &OldAI == &NewAI) {
+      // Ensure the start lines up.
+      assert(NewBeginOffset == BeginOffset);
+
+      // Rewrite the size as needed.
+      if (NewEndOffset != EndOffset)
+        II.setLength(ConstantInt::get(II.getLength()->getType(),
+                                      NewEndOffset - NewBeginOffset));
+      return false;
+    }
+    // Record this instruction for deletion.
+    Pass.DeadInsts.insert(&II);
+
+    // Strip all inbounds GEPs and pointer casts to try to dig out any root
+    // alloca that should be re-examined after rewriting this instruction.
+    Value *OtherPtr = IsDest ?
II.getRawSource() : II.getRawDest(); + if (AllocaInst *AI = + dyn_cast<AllocaInst>(OtherPtr->stripInBoundsOffsets())) { + assert(AI != &OldAI && AI != &NewAI && + "Splittable transfers cannot reach the same alloca on both ends."); + Pass.Worklist.insert(AI); + } + + Type *OtherPtrTy = OtherPtr->getType(); + unsigned OtherAS = OtherPtrTy->getPointerAddressSpace(); + + // Compute the relative offset for the other pointer within the transfer. + unsigned IntPtrWidth = DL.getPointerSizeInBits(OtherAS); + APInt OtherOffset(IntPtrWidth, NewBeginOffset - BeginOffset); + unsigned OtherAlign = MinAlign(II.getAlignment() ? II.getAlignment() : 1, + OtherOffset.zextOrTrunc(64).getZExtValue()); + + if (EmitMemCpy) { + // Compute the other pointer, folding as much as possible to produce + // a single, simple GEP in most cases. + OtherPtr = getAdjustedPtr(IRB, DL, OtherPtr, OtherOffset, OtherPtrTy, + OtherPtr->getName() + "."); + + Value *OurPtr = getNewAllocaSlicePtr(IRB, OldPtr->getType()); + Type *SizeTy = II.getLength()->getType(); + Constant *Size = ConstantInt::get(SizeTy, NewEndOffset - NewBeginOffset); + + CallInst *New = IRB.CreateMemCpy( + IsDest ? OurPtr : OtherPtr, IsDest ? OtherPtr : OurPtr, Size, + MinAlign(SliceAlign, OtherAlign), II.isVolatile()); + (void)New; + DEBUG(dbgs() << " to: " << *New << "\n"); + return false; + } + + bool IsWholeAlloca = NewBeginOffset == NewAllocaBeginOffset && + NewEndOffset == NewAllocaEndOffset; + uint64_t Size = NewEndOffset - NewBeginOffset; + unsigned BeginIndex = VecTy ? getIndex(NewBeginOffset) : 0; + unsigned EndIndex = VecTy ? getIndex(NewEndOffset) : 0; + unsigned NumElements = EndIndex - BeginIndex; + IntegerType *SubIntTy = + IntTy ? Type::getIntNTy(IntTy->getContext(), Size * 8) : nullptr; + + // Reset the other pointer type to match the register type we're going to + // use, but using the address space of the original other pointer. 
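+    // (Illustrative: copying a two-element range out of an <8 x i32> vector
+    // alloca rebuilds OtherPtrTy as <2 x i32>* in the other pointer's address
+    // space, so the loads and stores below operate on the register type.)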
+ if (VecTy && !IsWholeAlloca) { + if (NumElements == 1) + OtherPtrTy = VecTy->getElementType(); + else + OtherPtrTy = VectorType::get(VecTy->getElementType(), NumElements); + + OtherPtrTy = OtherPtrTy->getPointerTo(OtherAS); + } else if (IntTy && !IsWholeAlloca) { + OtherPtrTy = SubIntTy->getPointerTo(OtherAS); + } else { + OtherPtrTy = NewAllocaTy->getPointerTo(OtherAS); + } + + Value *SrcPtr = getAdjustedPtr(IRB, DL, OtherPtr, OtherOffset, OtherPtrTy, + OtherPtr->getName() + "."); + unsigned SrcAlign = OtherAlign; + Value *DstPtr = &NewAI; + unsigned DstAlign = SliceAlign; + if (!IsDest) { + std::swap(SrcPtr, DstPtr); + std::swap(SrcAlign, DstAlign); + } + + Value *Src; + if (VecTy && !IsWholeAlloca && !IsDest) { + Src = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); + Src = extractVector(IRB, Src, BeginIndex, EndIndex, "vec"); + } else if (IntTy && !IsWholeAlloca && !IsDest) { + Src = IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "load"); + Src = convertValue(DL, IRB, Src, IntTy); + uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; + Src = extractInteger(DL, IRB, Src, SubIntTy, Offset, "extract"); + } else { + Src = + IRB.CreateAlignedLoad(SrcPtr, SrcAlign, II.isVolatile(), "copyload"); + } + + if (VecTy && !IsWholeAlloca && IsDest) { + Value *Old = + IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload"); + Src = insertVector(IRB, Old, Src, BeginIndex, "vec"); + } else if (IntTy && !IsWholeAlloca && IsDest) { + Value *Old = + IRB.CreateAlignedLoad(&NewAI, NewAI.getAlignment(), "oldload"); + Old = convertValue(DL, IRB, Old, IntTy); + uint64_t Offset = NewBeginOffset - NewAllocaBeginOffset; + Src = insertInteger(DL, IRB, Old, Src, Offset, "insert"); + Src = convertValue(DL, IRB, Src, NewAllocaTy); + } + + StoreInst *Store = cast<StoreInst>( + IRB.CreateAlignedStore(Src, DstPtr, DstAlign, II.isVolatile())); + (void)Store; + DEBUG(dbgs() << " to: " << *Store << "\n"); + return !II.isVolatile(); + } + + bool visitIntrinsicInst(IntrinsicInst &II) { + assert(II.getIntrinsicID() == Intrinsic::lifetime_start || + II.getIntrinsicID() == Intrinsic::lifetime_end); + DEBUG(dbgs() << " original: " << II << "\n"); + assert(II.getArgOperand(1) == OldPtr); + + // Record this instruction for deletion. + Pass.DeadInsts.insert(&II); + + // Lifetime intrinsics are only promotable if they cover the whole alloca. + // Therefore, we drop lifetime intrinsics which don't cover the whole + // alloca. + // (In theory, intrinsics which partially cover an alloca could be + // promoted, but PromoteMemToReg doesn't handle that case.) + // FIXME: Check whether the alloca is promotable before dropping the + // lifetime intrinsics? 
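+    // (Dropping the markers should be safe: lifetime intrinsics are
+    // optimization hints, and without them the new alloca's lifetime is
+    // simply treated conservatively.)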
+ if (NewBeginOffset != NewAllocaBeginOffset || + NewEndOffset != NewAllocaEndOffset) + return true; + + ConstantInt *Size = + ConstantInt::get(cast<IntegerType>(II.getArgOperand(0)->getType()), + NewEndOffset - NewBeginOffset); + Value *Ptr = getNewAllocaSlicePtr(IRB, OldPtr->getType()); + Value *New; + if (II.getIntrinsicID() == Intrinsic::lifetime_start) + New = IRB.CreateLifetimeStart(Ptr, Size); + else + New = IRB.CreateLifetimeEnd(Ptr, Size); + + (void)New; + DEBUG(dbgs() << " to: " << *New << "\n"); + + return true; + } + + bool visitPHINode(PHINode &PN) { + DEBUG(dbgs() << " original: " << PN << "\n"); + assert(BeginOffset >= NewAllocaBeginOffset && "PHIs are unsplittable"); + assert(EndOffset <= NewAllocaEndOffset && "PHIs are unsplittable"); + + // We would like to compute a new pointer in only one place, but have it be + // as local as possible to the PHI. To do that, we re-use the location of + // the old pointer, which necessarily must be in the right position to + // dominate the PHI. + IRBuilderTy PtrBuilder(IRB); + if (isa<PHINode>(OldPtr)) + PtrBuilder.SetInsertPoint(&*OldPtr->getParent()->getFirstInsertionPt()); + else + PtrBuilder.SetInsertPoint(OldPtr); + PtrBuilder.SetCurrentDebugLocation(OldPtr->getDebugLoc()); + + Value *NewPtr = getNewAllocaSlicePtr(PtrBuilder, OldPtr->getType()); + // Replace the operands which were using the old pointer. + std::replace(PN.op_begin(), PN.op_end(), cast<Value>(OldPtr), NewPtr); + + DEBUG(dbgs() << " to: " << PN << "\n"); + deleteIfTriviallyDead(OldPtr); + + // PHIs can't be promoted on their own, but often can be speculated. We + // check the speculation outside of the rewriter so that we see the + // fully-rewritten alloca. + PHIUsers.insert(&PN); + return true; + } + + bool visitSelectInst(SelectInst &SI) { + DEBUG(dbgs() << " original: " << SI << "\n"); + assert((SI.getTrueValue() == OldPtr || SI.getFalseValue() == OldPtr) && + "Pointer isn't an operand!"); + assert(BeginOffset >= NewAllocaBeginOffset && "Selects are unsplittable"); + assert(EndOffset <= NewAllocaEndOffset && "Selects are unsplittable"); + + Value *NewPtr = getNewAllocaSlicePtr(IRB, OldPtr->getType()); + // Replace the operands which were using the old pointer. + if (SI.getOperand(1) == OldPtr) + SI.setOperand(1, NewPtr); + if (SI.getOperand(2) == OldPtr) + SI.setOperand(2, NewPtr); + + DEBUG(dbgs() << " to: " << SI << "\n"); + deleteIfTriviallyDead(OldPtr); + + // Selects can't be promoted on their own, but often can be speculated. We + // check the speculation outside of the rewriter so that we see the + // fully-rewritten alloca. + SelectUsers.insert(&SI); + return true; + } +}; + +namespace { + +/// \brief Visitor to rewrite aggregate loads and stores as scalar. +/// +/// This pass aggressively rewrites all aggregate loads and stores on +/// a particular pointer (or any pointer derived from it which we can identify) +/// with scalar loads and stores. +class AggLoadStoreRewriter : public InstVisitor<AggLoadStoreRewriter, bool> { + // Befriend the base class so it can delegate to private visit methods. + friend class InstVisitor<AggLoadStoreRewriter, bool>; + + /// Queue of pointer uses to analyze and potentially rewrite. + SmallVector<Use *, 8> Queue; + + /// Set to prevent us from cycling with phi nodes and loops. + SmallPtrSet<User *, 8> Visited; + + /// The current pointer use being rewritten. This is used to dig up the used + /// value (as opposed to the user). 
+  Use *U;
+
+public:
+  /// Rewrite loads and stores through a pointer and all pointers derived from
+  /// it.
+  bool rewrite(Instruction &I) {
+    DEBUG(dbgs() << " Rewriting FCA loads and stores...\n");
+    enqueueUsers(I);
+    bool Changed = false;
+    while (!Queue.empty()) {
+      U = Queue.pop_back_val();
+      Changed |= visit(cast<Instruction>(U->getUser()));
+    }
+    return Changed;
+  }
+
+private:
+  /// Enqueue all the users of the given instruction for further processing.
+  /// This uses a set to de-duplicate users.
+  void enqueueUsers(Instruction &I) {
+    for (Use &U : I.uses())
+      if (Visited.insert(U.getUser()).second)
+        Queue.push_back(&U);
+  }
+
+  // Conservative default is to not rewrite anything.
+  bool visitInstruction(Instruction &I) { return false; }
+
+  /// \brief Generic recursive split emission class.
+  template <typename Derived> class OpSplitter {
+  protected:
+    /// The builder used to form new instructions.
+    IRBuilderTy IRB;
+
+    /// The indices to be used with insert- or extractvalue to select the
+    /// appropriate value within the aggregate.
+    SmallVector<unsigned, 4> Indices;
+
+    /// The indices to a GEP instruction which will move Ptr to the correct
+    /// slot within the aggregate.
+    SmallVector<Value *, 4> GEPIndices;
+
+    /// The base pointer of the original op, used as a base for GEPing the
+    /// split operations.
+    Value *Ptr;
+
+    /// Initialize the splitter with an insertion point and Ptr, and start
+    /// with a single zero GEP index.
+    OpSplitter(Instruction *InsertionPoint, Value *Ptr)
+        : IRB(InsertionPoint), GEPIndices(1, IRB.getInt32(0)), Ptr(Ptr) {}
+
+  public:
+    /// \brief Generic recursive split emission routine.
+    ///
+    /// This method recursively splits an aggregate op (load or store) into
+    /// scalar or vector ops. It splits recursively until it hits a single
+    /// value and emits that single value operation via the template argument.
+    ///
+    /// The logic of this routine relies on GEPs and insertvalue and
+    /// extractvalue all operating with the same fundamental index list, merely
+    /// formatted differently (GEPs need actual values).
+    ///
+    /// \param Ty The type being split recursively into smaller ops.
+    /// \param Agg The aggregate value being built up or stored, depending on
+    /// whether this is splitting a load or a store respectively.
+    void emitSplitOps(Type *Ty, Value *&Agg, const Twine &Name) {
+      if (Ty->isSingleValueType())
+        return static_cast<Derived *>(this)->emitFunc(Ty, Agg, Name);
+
+      if (ArrayType *ATy = dyn_cast<ArrayType>(Ty)) {
+        unsigned OldSize = Indices.size();
+        (void)OldSize;
+        for (unsigned Idx = 0, Size = ATy->getNumElements(); Idx != Size;
+             ++Idx) {
+          assert(Indices.size() == OldSize && "Did not return to the old size");
+          Indices.push_back(Idx);
+          GEPIndices.push_back(IRB.getInt32(Idx));
+          emitSplitOps(ATy->getElementType(), Agg, Name + "." + Twine(Idx));
+          GEPIndices.pop_back();
+          Indices.pop_back();
+        }
+        return;
+      }
+
+      if (StructType *STy = dyn_cast<StructType>(Ty)) {
+        unsigned OldSize = Indices.size();
+        (void)OldSize;
+        for (unsigned Idx = 0, Size = STy->getNumElements(); Idx != Size;
+             ++Idx) {
+          assert(Indices.size() == OldSize && "Did not return to the old size");
+          Indices.push_back(Idx);
+          GEPIndices.push_back(IRB.getInt32(Idx));
+          emitSplitOps(STy->getElementType(Idx), Agg, Name + "." 
+ Twine(Idx)); + GEPIndices.pop_back(); + Indices.pop_back(); + } + return; + } + + llvm_unreachable("Only arrays and structs are aggregate loadable types"); + } + }; + + struct LoadOpSplitter : public OpSplitter<LoadOpSplitter> { + LoadOpSplitter(Instruction *InsertionPoint, Value *Ptr) + : OpSplitter<LoadOpSplitter>(InsertionPoint, Ptr) {} + + /// Emit a leaf load of a single value. This is called at the leaves of the + /// recursive emission to actually load values. + void emitFunc(Type *Ty, Value *&Agg, const Twine &Name) { + assert(Ty->isSingleValueType()); + // Load the single value and insert it using the indices. + Value *GEP = + IRB.CreateInBoundsGEP(nullptr, Ptr, GEPIndices, Name + ".gep"); + Value *Load = IRB.CreateLoad(GEP, Name + ".load"); + Agg = IRB.CreateInsertValue(Agg, Load, Indices, Name + ".insert"); + DEBUG(dbgs() << " to: " << *Load << "\n"); + } + }; + + bool visitLoadInst(LoadInst &LI) { + assert(LI.getPointerOperand() == *U); + if (!LI.isSimple() || LI.getType()->isSingleValueType()) + return false; + + // We have an aggregate being loaded, split it apart. + DEBUG(dbgs() << " original: " << LI << "\n"); + LoadOpSplitter Splitter(&LI, *U); + Value *V = UndefValue::get(LI.getType()); + Splitter.emitSplitOps(LI.getType(), V, LI.getName() + ".fca"); + LI.replaceAllUsesWith(V); + LI.eraseFromParent(); + return true; + } + + struct StoreOpSplitter : public OpSplitter<StoreOpSplitter> { + StoreOpSplitter(Instruction *InsertionPoint, Value *Ptr) + : OpSplitter<StoreOpSplitter>(InsertionPoint, Ptr) {} + + /// Emit a leaf store of a single value. This is called at the leaves of the + /// recursive emission to actually produce stores. + void emitFunc(Type *Ty, Value *&Agg, const Twine &Name) { + assert(Ty->isSingleValueType()); + // Extract the single value and store it using the indices. + // + // The gep and extractvalue values are factored out of the CreateStore + // call to make the output independent of the argument evaluation order. + Value *ExtractValue = + IRB.CreateExtractValue(Agg, Indices, Name + ".extract"); + Value *InBoundsGEP = + IRB.CreateInBoundsGEP(nullptr, Ptr, GEPIndices, Name + ".gep"); + Value *Store = IRB.CreateStore(ExtractValue, InBoundsGEP); + (void)Store; + DEBUG(dbgs() << " to: " << *Store << "\n"); + } + }; + + bool visitStoreInst(StoreInst &SI) { + if (!SI.isSimple() || SI.getPointerOperand() != *U) + return false; + Value *V = SI.getValueOperand(); + if (V->getType()->isSingleValueType()) + return false; + + // We have an aggregate being stored, split it apart. + DEBUG(dbgs() << " original: " << SI << "\n"); + StoreOpSplitter Splitter(&SI, *U); + Splitter.emitSplitOps(V->getType(), V, V->getName() + ".fca"); + SI.eraseFromParent(); + return true; + } + + bool visitBitCastInst(BitCastInst &BC) { + enqueueUsers(BC); + return false; + } + + bool visitGetElementPtrInst(GetElementPtrInst &GEPI) { + enqueueUsers(GEPI); + return false; + } + + bool visitPHINode(PHINode &PN) { + enqueueUsers(PN); + return false; + } + + bool visitSelectInst(SelectInst &SI) { + enqueueUsers(SI); + return false; + } +}; + +} // end anonymous namespace + +/// \brief Strip aggregate type wrapping. +/// +/// This removes no-op aggregate types wrapping an underlying type. It will +/// strip as many layers of types as it can without changing either the type +/// size or the allocated size. 
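+/// For example (illustrative), [1 x { i32 }] is stripped down to plain i32,
+/// since each wrapper layer has the same size and alloc size as i32.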
+static Type *stripAggregateTypeWrapping(const DataLayout &DL, Type *Ty) {
+  if (Ty->isSingleValueType())
+    return Ty;
+
+  uint64_t AllocSize = DL.getTypeAllocSize(Ty);
+  uint64_t TypeSize = DL.getTypeSizeInBits(Ty);
+
+  Type *InnerTy;
+  if (ArrayType *ArrTy = dyn_cast<ArrayType>(Ty)) {
+    InnerTy = ArrTy->getElementType();
+  } else if (StructType *STy = dyn_cast<StructType>(Ty)) {
+    const StructLayout *SL = DL.getStructLayout(STy);
+    unsigned Index = SL->getElementContainingOffset(0);
+    InnerTy = STy->getElementType(Index);
+  } else {
+    return Ty;
+  }
+
+  if (AllocSize > DL.getTypeAllocSize(InnerTy) ||
+      TypeSize > DL.getTypeSizeInBits(InnerTy))
+    return Ty;
+
+  return stripAggregateTypeWrapping(DL, InnerTy);
+}
+
+/// \brief Try to find a partition of the aggregate type passed in for a given
+/// offset and size.
+///
+/// This recurses through the aggregate type and tries to compute a subtype
+/// based on the offset and size. When the offset and size span a sub-section
+/// of an array, it will even compute a new array type for that sub-section,
+/// and the same for structs.
+///
+/// Note that this routine is very strict and tries to find a partition of the
+/// type which produces the *exact* right offset and size. It is not forgiving
+/// when the size or offset cause either end of the type-based partition to be
+/// off. Also, this is a best-effort routine. It is reasonable to give up and
+/// not return a type if necessary.
+static Type *getTypePartition(const DataLayout &DL, Type *Ty, uint64_t Offset,
+                              uint64_t Size) {
+  if (Offset == 0 && DL.getTypeAllocSize(Ty) == Size)
+    return stripAggregateTypeWrapping(DL, Ty);
+  if (Offset > DL.getTypeAllocSize(Ty) ||
+      (DL.getTypeAllocSize(Ty) - Offset) < Size)
+    return nullptr;
+
+  if (SequentialType *SeqTy = dyn_cast<SequentialType>(Ty)) {
+    Type *ElementTy = SeqTy->getElementType();
+    uint64_t ElementSize = DL.getTypeAllocSize(ElementTy);
+    uint64_t NumSkippedElements = Offset / ElementSize;
+    if (NumSkippedElements >= SeqTy->getNumElements())
+      return nullptr;
+    Offset -= NumSkippedElements * ElementSize;
+
+    // First check if we need to recurse.
+    if (Offset > 0 || Size < ElementSize) {
+      // Bail if the partition ends in a different array element.
+      if ((Offset + Size) > ElementSize)
+        return nullptr;
+      // Recurse through the element type trying to peel off offset bytes.
+      return getTypePartition(DL, ElementTy, Offset, Size);
+    }
+    assert(Offset == 0);
+
+    if (Size == ElementSize)
+      return stripAggregateTypeWrapping(DL, ElementTy);
+    assert(Size > ElementSize);
+    uint64_t NumElements = Size / ElementSize;
+    if (NumElements * ElementSize != Size)
+      return nullptr;
+    return ArrayType::get(ElementTy, NumElements);
+  }
+
+  StructType *STy = dyn_cast<StructType>(Ty);
+  if (!STy)
+    return nullptr;
+
+  const StructLayout *SL = DL.getStructLayout(STy);
+  if (Offset >= SL->getSizeInBytes())
+    return nullptr;
+  uint64_t EndOffset = Offset + Size;
+  if (EndOffset > SL->getSizeInBytes())
+    return nullptr;
+
+  unsigned Index = SL->getElementContainingOffset(Offset);
+  Offset -= SL->getElementOffset(Index);
+
+  Type *ElementTy = STy->getElementType(Index);
+  uint64_t ElementSize = DL.getTypeAllocSize(ElementTy);
+  if (Offset >= ElementSize)
+    return nullptr; // The offset points into alignment padding.
+
+  // See if any partition must be contained by the element.
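+  // (Illustrative: for { [2 x i32], i32 } with a 4-byte partition at overall
+  // offset 4, Offset is now 4 within the [2 x i32] element, so we recurse
+  // into it and get back i32.)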
+ if (Offset > 0 || Size < ElementSize) { + if ((Offset + Size) > ElementSize) + return nullptr; + return getTypePartition(DL, ElementTy, Offset, Size); + } + assert(Offset == 0); + + if (Size == ElementSize) + return stripAggregateTypeWrapping(DL, ElementTy); + + StructType::element_iterator EI = STy->element_begin() + Index, + EE = STy->element_end(); + if (EndOffset < SL->getSizeInBytes()) { + unsigned EndIndex = SL->getElementContainingOffset(EndOffset); + if (Index == EndIndex) + return nullptr; // Within a single element and its padding. + + // Don't try to form "natural" types if the elements don't line up with the + // expected size. + // FIXME: We could potentially recurse down through the last element in the + // sub-struct to find a natural end point. + if (SL->getElementOffset(EndIndex) != EndOffset) + return nullptr; + + assert(Index < EndIndex); + EE = STy->element_begin() + EndIndex; + } + + // Try to build up a sub-structure. + StructType *SubTy = + StructType::get(STy->getContext(), makeArrayRef(EI, EE), STy->isPacked()); + const StructLayout *SubSL = DL.getStructLayout(SubTy); + if (Size != SubSL->getSizeInBytes()) + return nullptr; // The sub-struct doesn't have quite the size needed. + + return SubTy; +} + +/// \brief Pre-split loads and stores to simplify rewriting. +/// +/// We want to break up the splittable load+store pairs as much as +/// possible. This is important to do as a preprocessing step, as once we +/// start rewriting the accesses to partitions of the alloca we lose the +/// necessary information to correctly split apart paired loads and stores +/// which both point into this alloca. The case to consider is something like +/// the following: +/// +/// %a = alloca [12 x i8] +/// %gep1 = getelementptr [12 x i8]* %a, i32 0, i32 0 +/// %gep2 = getelementptr [12 x i8]* %a, i32 0, i32 4 +/// %gep3 = getelementptr [12 x i8]* %a, i32 0, i32 8 +/// %iptr1 = bitcast i8* %gep1 to i64* +/// %iptr2 = bitcast i8* %gep2 to i64* +/// %fptr1 = bitcast i8* %gep1 to float* +/// %fptr2 = bitcast i8* %gep2 to float* +/// %fptr3 = bitcast i8* %gep3 to float* +/// store float 0.0, float* %fptr1 +/// store float 1.0, float* %fptr2 +/// %v = load i64* %iptr1 +/// store i64 %v, i64* %iptr2 +/// %f1 = load float* %fptr2 +/// %f2 = load float* %fptr3 +/// +/// Here we want to form 3 partitions of the alloca, each 4 bytes large, and +/// promote everything so we recover the 2 SSA values that should have been +/// there all along. +/// +/// \returns true if any changes are made. +bool SROA::presplitLoadsAndStores(AllocaInst &AI, AllocaSlices &AS) { + DEBUG(dbgs() << "Pre-splitting loads and stores\n"); + + // Track the loads and stores which are candidates for pre-splitting here, in + // the order they first appear during the partition scan. These give stable + // iteration order and a basis for tracking which loads and stores we + // actually split. + SmallVector<LoadInst *, 4> Loads; + SmallVector<StoreInst *, 4> Stores; + + // We need to accumulate the splits required of each load or store where we + // can find them via a direct lookup. This is important to cross-check loads + // and stores against each other. We also track the slice so that we can kill + // all the slices that end up split. + struct SplitOffsets { + Slice *S; + std::vector<uint64_t> Splits; + }; + SmallDenseMap<Instruction *, SplitOffsets, 8> SplitOffsetsMap; + + // Track loads out of this alloca which cannot, for any reason, be pre-split. + // This is important as we also cannot pre-split stores of those loads! 
+  // FIXME: This is all pretty gross. It means that we can be more aggressive
+  // in pre-splitting when the load feeding the store happens to come from
+  // a separate alloca. Put another way, the effectiveness of SROA would be
+  // decreased by a frontend which just concatenated all of its local allocas
+  // into one big flat alloca. But defeating such patterns is exactly the job
+  // SROA is tasked with! Sadly, to not have this discrepancy we would have to
+  // change store pre-splitting to actually force pre-splitting of the load
+  // that feeds it *and all stores*. That makes pre-splitting much harder, but
+  // maybe it would make it more principled?
+  SmallPtrSet<LoadInst *, 8> UnsplittableLoads;
+
+  DEBUG(dbgs() << "  Searching for candidate loads and stores\n");
+  for (auto &P : AS.partitions()) {
+    for (Slice &S : P) {
+      Instruction *I = cast<Instruction>(S.getUse()->getUser());
+      if (!S.isSplittable() || S.endOffset() <= P.endOffset()) {
+        // If this is a load we have to track that it can't participate in any
+        // pre-splitting. If this is a store of a load we have to track that
+        // that load also can't participate in any pre-splitting.
+        if (auto *LI = dyn_cast<LoadInst>(I))
+          UnsplittableLoads.insert(LI);
+        else if (auto *SI = dyn_cast<StoreInst>(I))
+          if (auto *LI = dyn_cast<LoadInst>(SI->getValueOperand()))
+            UnsplittableLoads.insert(LI);
+        continue;
+      }
+      assert(P.endOffset() > S.beginOffset() &&
+             "Empty or backwards partition!");
+
+      // Determine if this is a pre-splittable slice.
+      if (auto *LI = dyn_cast<LoadInst>(I)) {
+        assert(!LI->isVolatile() && "Cannot split volatile loads!");
+
+        // The load must be used exclusively to store into other pointers for
+        // us to be able to arbitrarily pre-split it. The stores must also be
+        // simple to avoid changing semantics.
+        auto IsLoadSimplyStored = [](LoadInst *LI) {
+          for (User *LU : LI->users()) {
+            auto *SI = dyn_cast<StoreInst>(LU);
+            if (!SI || !SI->isSimple())
+              return false;
+          }
+          return true;
+        };
+        if (!IsLoadSimplyStored(LI)) {
+          UnsplittableLoads.insert(LI);
+          continue;
+        }
+
+        Loads.push_back(LI);
+      } else if (auto *SI = dyn_cast<StoreInst>(I)) {
+        if (S.getUse() != &SI->getOperandUse(SI->getPointerOperandIndex()))
+          // Skip stores *of* pointers. FIXME: This shouldn't even be possible!
+          continue;
+        auto *StoredLoad = dyn_cast<LoadInst>(SI->getValueOperand());
+        if (!StoredLoad || !StoredLoad->isSimple())
+          continue;
+        assert(!SI->isVolatile() && "Cannot split volatile stores!");
+
+        Stores.push_back(SI);
+      } else {
+        // Other uses cannot be pre-split.
+        continue;
+      }
+
+      // Record the initial split.
+      DEBUG(dbgs() << "    Candidate: " << *I << "\n");
+      auto &Offsets = SplitOffsetsMap[I];
+      assert(Offsets.Splits.empty() &&
+             "Should not have splits the first time we see an instruction!");
+      Offsets.S = &S;
+      Offsets.Splits.push_back(P.endOffset() - S.beginOffset());
+    }
+
+    // Now scan the already split slices, and add a split for any of them which
+    // we're going to pre-split.
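+    //
+    // For illustration (hypothetical offsets): a splittable slice [0, 12)
+    // scanned against partitions [0, 4), [4, 8), and [8, 12) records
+    // Splits = {4} when it is first visited in [0, 4); the tail scan below
+    // then appends 8 while processing [4, 8), and appends nothing in
+    // [8, 12), where the slice ends exactly at the partition end. The
+    // slice's own size implies the final end offset, so it is never stored.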
+ for (Slice *S : P.splitSliceTails()) { + auto SplitOffsetsMapI = + SplitOffsetsMap.find(cast<Instruction>(S->getUse()->getUser())); + if (SplitOffsetsMapI == SplitOffsetsMap.end()) + continue; + auto &Offsets = SplitOffsetsMapI->second; + + assert(Offsets.S == S && "Found a mismatched slice!"); + assert(!Offsets.Splits.empty() && + "Cannot have an empty set of splits on the second partition!"); + assert(Offsets.Splits.back() == + P.beginOffset() - Offsets.S->beginOffset() && + "Previous split does not end where this one begins!"); + + // Record each split. The last partition's end isn't needed as the size + // of the slice dictates that. + if (S->endOffset() > P.endOffset()) + Offsets.Splits.push_back(P.endOffset() - Offsets.S->beginOffset()); + } + } + + // We may have split loads where some of their stores are split stores. For + // such loads and stores, we can only pre-split them if their splits exactly + // match relative to their starting offset. We have to verify this prior to + // any rewriting. + Stores.erase( + llvm::remove_if(Stores, + [&UnsplittableLoads, &SplitOffsetsMap](StoreInst *SI) { + // Lookup the load we are storing in our map of split + // offsets. + auto *LI = cast<LoadInst>(SI->getValueOperand()); + // If it was completely unsplittable, then we're done, + // and this store can't be pre-split. + if (UnsplittableLoads.count(LI)) + return true; + + auto LoadOffsetsI = SplitOffsetsMap.find(LI); + if (LoadOffsetsI == SplitOffsetsMap.end()) + return false; // Unrelated loads are definitely safe. + auto &LoadOffsets = LoadOffsetsI->second; + + // Now lookup the store's offsets. + auto &StoreOffsets = SplitOffsetsMap[SI]; + + // If the relative offsets of each split in the load and + // store match exactly, then we can split them and we + // don't need to remove them here. + if (LoadOffsets.Splits == StoreOffsets.Splits) + return false; + + DEBUG(dbgs() + << " Mismatched splits for load and store:\n" + << " " << *LI << "\n" + << " " << *SI << "\n"); + + // We've found a store and load that we need to split + // with mismatched relative splits. Just give up on them + // and remove both instructions from our list of + // candidates. + UnsplittableLoads.insert(LI); + return true; + }), + Stores.end()); + // Now we have to go *back* through all the stores, because a later store may + // have caused an earlier store's load to become unsplittable and if it is + // unsplittable for the later store, then we can't rely on it being split in + // the earlier store either. + Stores.erase(llvm::remove_if(Stores, + [&UnsplittableLoads](StoreInst *SI) { + auto *LI = + cast<LoadInst>(SI->getValueOperand()); + return UnsplittableLoads.count(LI); + }), + Stores.end()); + // Once we've established all the loads that can't be split for some reason, + // filter any that made it into our list out. + Loads.erase(llvm::remove_if(Loads, + [&UnsplittableLoads](LoadInst *LI) { + return UnsplittableLoads.count(LI); + }), + Loads.end()); + + // If no loads or stores are left, there is no pre-splitting to be done for + // this alloca. + if (Loads.empty() && Stores.empty()) + return false; + + // From here on, we can't fail and will be building new accesses, so rig up + // an IR builder. + IRBuilderTy IRB(&AI); + + // Collect the new slices which we will merge into the alloca slices. + SmallVector<Slice, 4> NewSlices; + + // Track any allocas we end up splitting loads and stores for so we iterate + // on them. 
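+  //
+  // Illustrative scenario (hypothetical IR): pre-splitting
+  //
+  //   store i64 %v, i64* %other
+  //
+  // into two i32 stores rewrites accesses of %other, a different alloca;
+  // %other must be revisited, and if it was previously deemed promotable it
+  // must be re-analyzed, since its slice set has changed underneath it.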
+  SmallPtrSet<AllocaInst *, 4> ResplitPromotableAllocas;
+
+  // At this point, we have collected all of the loads and stores we can
+  // pre-split, and the specific splits needed for them. We actually do the
+  // splitting in a specific order so that we can handle the case when one of
+  // the loads is the value operand of one of the stores.
+  //
+  // First, we rewrite all of the split loads, and just accumulate each split
+  // load in a parallel structure. We also build the slices for them and append
+  // them to the alloca slices.
+  SmallDenseMap<LoadInst *, std::vector<LoadInst *>, 1> SplitLoadsMap;
+  std::vector<LoadInst *> SplitLoads;
+  const DataLayout &DL = AI.getModule()->getDataLayout();
+  for (LoadInst *LI : Loads) {
+    SplitLoads.clear();
+
+    IntegerType *Ty = cast<IntegerType>(LI->getType());
+    uint64_t LoadSize = Ty->getBitWidth() / 8;
+    assert(LoadSize > 0 && "Cannot have a zero-sized integer load!");
+
+    auto &Offsets = SplitOffsetsMap[LI];
+    assert(LoadSize == Offsets.S->endOffset() - Offsets.S->beginOffset() &&
+           "Slice size should always match load size exactly!");
+    uint64_t BaseOffset = Offsets.S->beginOffset();
+    assert(BaseOffset + LoadSize > BaseOffset &&
+           "Cannot represent alloca access size using 64-bit integers!");
+
+    Instruction *BasePtr = cast<Instruction>(LI->getPointerOperand());
+    IRB.SetInsertPoint(LI);
+
+    DEBUG(dbgs() << "  Splitting load: " << *LI << "\n");
+
+    uint64_t PartOffset = 0, PartSize = Offsets.Splits.front();
+    int Idx = 0, Size = Offsets.Splits.size();
+    for (;;) {
+      auto *PartTy = Type::getIntNTy(Ty->getContext(), PartSize * 8);
+      auto AS = LI->getPointerAddressSpace();
+      auto *PartPtrTy = PartTy->getPointerTo(AS);
+      LoadInst *PLoad = IRB.CreateAlignedLoad(
+          getAdjustedPtr(IRB, DL, BasePtr,
+                         APInt(DL.getPointerSizeInBits(AS), PartOffset),
+                         PartPtrTy, BasePtr->getName() + "."),
+          getAdjustedAlignment(LI, PartOffset, DL), /*IsVolatile*/ false,
+          LI->getName());
+      PLoad->copyMetadata(*LI, LLVMContext::MD_mem_parallel_loop_access);
+
+      // Append this load onto the list of split loads so we can find it later
+      // to rewrite the stores.
+      SplitLoads.push_back(PLoad);
+
+      // Now build a new slice for the alloca.
+      NewSlices.push_back(
+          Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize,
+                &PLoad->getOperandUse(PLoad->getPointerOperandIndex()),
+                /*IsSplittable*/ false));
+      DEBUG(dbgs() << "    new slice [" << NewSlices.back().beginOffset()
+                   << ", " << NewSlices.back().endOffset() << "): " << *PLoad
+                   << "\n");
+
+      // See if we've handled all the splits.
+      if (Idx >= Size)
+        break;
+
+      // Set up the next partition.
+      PartOffset = Offsets.Splits[Idx];
+      ++Idx;
+      PartSize = (Idx < Size ? Offsets.Splits[Idx] : LoadSize) - PartOffset;
+    }
+
+    // Now that we have the split loads, do the slow walk over all uses of the
+    // load and rewrite them as split stores, or save the split loads to use
+    // below if the store is going to be split there anyway.
+    bool DeferredStores = false;
+    for (User *LU : LI->users()) {
+      StoreInst *SI = cast<StoreInst>(LU);
+      if (!Stores.empty() && SplitOffsetsMap.count(SI)) {
+        DeferredStores = true;
+        DEBUG(dbgs() << "    Deferred splitting of store: " << *SI << "\n");
+        continue;
+      }
+
+      Value *StoreBasePtr = SI->getPointerOperand();
+      IRB.SetInsertPoint(SI);
+
+      DEBUG(dbgs() << "    Splitting store of load: " << *SI << "\n");
+
+      for (int Idx = 0, Size = SplitLoads.size(); Idx < Size; ++Idx) {
+        LoadInst *PLoad = SplitLoads[Idx];
+        uint64_t PartOffset = Idx == 0 ?
0 : Offsets.Splits[Idx - 1]; + auto *PartPtrTy = + PLoad->getType()->getPointerTo(SI->getPointerAddressSpace()); + + auto AS = SI->getPointerAddressSpace(); + StoreInst *PStore = IRB.CreateAlignedStore( + PLoad, + getAdjustedPtr(IRB, DL, StoreBasePtr, + APInt(DL.getPointerSizeInBits(AS), PartOffset), + PartPtrTy, StoreBasePtr->getName() + "."), + getAdjustedAlignment(SI, PartOffset, DL), /*IsVolatile*/ false); + PStore->copyMetadata(*LI, LLVMContext::MD_mem_parallel_loop_access); + DEBUG(dbgs() << " +" << PartOffset << ":" << *PStore << "\n"); + } + + // We want to immediately iterate on any allocas impacted by splitting + // this store, and we have to track any promotable alloca (indicated by + // a direct store) as needing to be resplit because it is no longer + // promotable. + if (AllocaInst *OtherAI = dyn_cast<AllocaInst>(StoreBasePtr)) { + ResplitPromotableAllocas.insert(OtherAI); + Worklist.insert(OtherAI); + } else if (AllocaInst *OtherAI = dyn_cast<AllocaInst>( + StoreBasePtr->stripInBoundsOffsets())) { + Worklist.insert(OtherAI); + } + + // Mark the original store as dead. + DeadInsts.insert(SI); + } + + // Save the split loads if there are deferred stores among the users. + if (DeferredStores) + SplitLoadsMap.insert(std::make_pair(LI, std::move(SplitLoads))); + + // Mark the original load as dead and kill the original slice. + DeadInsts.insert(LI); + Offsets.S->kill(); + } + + // Second, we rewrite all of the split stores. At this point, we know that + // all loads from this alloca have been split already. For stores of such + // loads, we can simply look up the pre-existing split loads. For stores of + // other loads, we split those loads first and then write split stores of + // them. + for (StoreInst *SI : Stores) { + auto *LI = cast<LoadInst>(SI->getValueOperand()); + IntegerType *Ty = cast<IntegerType>(LI->getType()); + uint64_t StoreSize = Ty->getBitWidth() / 8; + assert(StoreSize > 0 && "Cannot have a zero-sized integer store!"); + + auto &Offsets = SplitOffsetsMap[SI]; + assert(StoreSize == Offsets.S->endOffset() - Offsets.S->beginOffset() && + "Slice size should always match load size exactly!"); + uint64_t BaseOffset = Offsets.S->beginOffset(); + assert(BaseOffset + StoreSize > BaseOffset && + "Cannot represent alloca access size using 64-bit integers!"); + + Value *LoadBasePtr = LI->getPointerOperand(); + Instruction *StoreBasePtr = cast<Instruction>(SI->getPointerOperand()); + + DEBUG(dbgs() << " Splitting store: " << *SI << "\n"); + + // Check whether we have an already split load. + auto SplitLoadsMapI = SplitLoadsMap.find(LI); + std::vector<LoadInst *> *SplitLoads = nullptr; + if (SplitLoadsMapI != SplitLoadsMap.end()) { + SplitLoads = &SplitLoadsMapI->second; + assert(SplitLoads->size() == Offsets.Splits.size() + 1 && + "Too few split loads for the number of splits in the store!"); + } else { + DEBUG(dbgs() << " of load: " << *LI << "\n"); + } + + uint64_t PartOffset = 0, PartSize = Offsets.Splits.front(); + int Idx = 0, Size = Offsets.Splits.size(); + for (;;) { + auto *PartTy = Type::getIntNTy(Ty->getContext(), PartSize * 8); + auto *LoadPartPtrTy = PartTy->getPointerTo(LI->getPointerAddressSpace()); + auto *StorePartPtrTy = PartTy->getPointerTo(SI->getPointerAddressSpace()); + + // Either lookup a split load or create one. 
+      LoadInst *PLoad;
+      if (SplitLoads) {
+        PLoad = (*SplitLoads)[Idx];
+      } else {
+        IRB.SetInsertPoint(LI);
+        auto AS = LI->getPointerAddressSpace();
+        PLoad = IRB.CreateAlignedLoad(
+            getAdjustedPtr(IRB, DL, LoadBasePtr,
+                           APInt(DL.getPointerSizeInBits(AS), PartOffset),
+                           LoadPartPtrTy, LoadBasePtr->getName() + "."),
+            getAdjustedAlignment(LI, PartOffset, DL), /*IsVolatile*/ false,
+            LI->getName());
+      }
+
+      // And store this partition.
+      IRB.SetInsertPoint(SI);
+      auto AS = SI->getPointerAddressSpace();
+      StoreInst *PStore = IRB.CreateAlignedStore(
+          PLoad,
+          getAdjustedPtr(IRB, DL, StoreBasePtr,
+                         APInt(DL.getPointerSizeInBits(AS), PartOffset),
+                         StorePartPtrTy, StoreBasePtr->getName() + "."),
+          getAdjustedAlignment(SI, PartOffset, DL), /*IsVolatile*/ false);
+
+      // Now build a new slice for the alloca.
+      NewSlices.push_back(
+          Slice(BaseOffset + PartOffset, BaseOffset + PartOffset + PartSize,
+                &PStore->getOperandUse(PStore->getPointerOperandIndex()),
+                /*IsSplittable*/ false));
+      DEBUG(dbgs() << "    new slice [" << NewSlices.back().beginOffset()
+                   << ", " << NewSlices.back().endOffset() << "): " << *PStore
+                   << "\n");
+      if (!SplitLoads) {
+        DEBUG(dbgs() << "      of split load: " << *PLoad << "\n");
+      }
+
+      // See if we've finished all the splits.
+      if (Idx >= Size)
+        break;
+
+      // Set up the next partition.
+      PartOffset = Offsets.Splits[Idx];
+      ++Idx;
+      PartSize = (Idx < Size ? Offsets.Splits[Idx] : StoreSize) - PartOffset;
+    }
+
+    // We want to immediately iterate on any allocas impacted by splitting
+    // this load, which is only relevant if it isn't a load of this alloca and
+    // thus we didn't already split the loads above. We also have to keep track
+    // of any promotable allocas we split loads on as they can no longer be
+    // promoted.
+    if (!SplitLoads) {
+      if (AllocaInst *OtherAI = dyn_cast<AllocaInst>(LoadBasePtr)) {
+        assert(OtherAI != &AI && "We can't re-split our own alloca!");
+        ResplitPromotableAllocas.insert(OtherAI);
+        Worklist.insert(OtherAI);
+      } else if (AllocaInst *OtherAI = dyn_cast<AllocaInst>(
+                     LoadBasePtr->stripInBoundsOffsets())) {
+        assert(OtherAI != &AI && "We can't re-split our own alloca!");
+        Worklist.insert(OtherAI);
+      }
+    }
+
+    // Mark the original store as dead now that we've split it up and kill its
+    // slice. Note that we leave the original load in place unless this store
+    // was its only use. It may in turn be split up if it is an alloca load
+    // for some other alloca, but it may be a normal load. This may introduce
+    // redundant loads, but where those can be merged the rest of the optimizer
+    // should handle the merging, and this uncovers SSA splits, which is more
+    // important. In practice, the original loads will almost always be fully
+    // split and removed eventually, and the splits will be merged by any
+    // trivial CSE, including instcombine.
+    if (LI->hasOneUse()) {
+      assert(*LI->user_begin() == SI && "Single use isn't this store!");
+      DeadInsts.insert(LI);
+    }
+    DeadInsts.insert(SI);
+    Offsets.S->kill();
+  }
+
+  // Remove the killed slices that have been pre-split.
+  AS.erase(llvm::remove_if(AS, [](const Slice &S) { return S.isDead(); }),
+           AS.end());
+
+  // Insert our new slices. This will sort and merge them into the sorted
+  // sequence.
+  AS.insert(NewSlices);
+
+  DEBUG(dbgs() << "  Pre-split slices:\n");
+#ifndef NDEBUG
+  for (auto I = AS.begin(), E = AS.end(); I != E; ++I)
+    DEBUG(AS.print(dbgs(), I, "  "));
+#endif
+
+  // Finally, don't try to promote any allocas that now require re-splitting.
+ // They have already been added to the worklist above. + PromotableAllocas.erase( + llvm::remove_if( + PromotableAllocas, + [&](AllocaInst *AI) { return ResplitPromotableAllocas.count(AI); }), + PromotableAllocas.end()); + + return true; +} + +/// \brief Rewrite an alloca partition's users. +/// +/// This routine drives both of the rewriting goals of the SROA pass. It tries +/// to rewrite uses of an alloca partition to be conducive for SSA value +/// promotion. If the partition needs a new, more refined alloca, this will +/// build that new alloca, preserving as much type information as possible, and +/// rewrite the uses of the old alloca to point at the new one and have the +/// appropriate new offsets. It also evaluates how successful the rewrite was +/// at enabling promotion and if it was successful queues the alloca to be +/// promoted. +AllocaInst *SROA::rewritePartition(AllocaInst &AI, AllocaSlices &AS, + Partition &P) { + // Try to compute a friendly type for this partition of the alloca. This + // won't always succeed, in which case we fall back to a legal integer type + // or an i8 array of an appropriate size. + Type *SliceTy = nullptr; + const DataLayout &DL = AI.getModule()->getDataLayout(); + if (Type *CommonUseTy = findCommonType(P.begin(), P.end(), P.endOffset())) + if (DL.getTypeAllocSize(CommonUseTy) >= P.size()) + SliceTy = CommonUseTy; + if (!SliceTy) + if (Type *TypePartitionTy = getTypePartition(DL, AI.getAllocatedType(), + P.beginOffset(), P.size())) + SliceTy = TypePartitionTy; + if ((!SliceTy || (SliceTy->isArrayTy() && + SliceTy->getArrayElementType()->isIntegerTy())) && + DL.isLegalInteger(P.size() * 8)) + SliceTy = Type::getIntNTy(*C, P.size() * 8); + if (!SliceTy) + SliceTy = ArrayType::get(Type::getInt8Ty(*C), P.size()); + assert(DL.getTypeAllocSize(SliceTy) >= P.size()); + + bool IsIntegerPromotable = isIntegerWideningViable(P, SliceTy, DL); + + VectorType *VecTy = + IsIntegerPromotable ? nullptr : isVectorPromotionViable(P, DL); + if (VecTy) + SliceTy = VecTy; + + // Check for the case where we're going to rewrite to a new alloca of the + // exact same type as the original, and with the same access offsets. In that + // case, re-use the existing alloca, but still run through the rewriter to + // perform phi and select speculation. + AllocaInst *NewAI; + if (SliceTy == AI.getAllocatedType()) { + assert(P.beginOffset() == 0 && + "Non-zero begin offset but same alloca type"); + NewAI = &AI; + // FIXME: We should be able to bail at this point with "nothing changed". + // FIXME: We might want to defer PHI speculation until after here. + // FIXME: return nullptr; + } else { + unsigned Alignment = AI.getAlignment(); + if (!Alignment) { + // The minimum alignment which users can rely on when the explicit + // alignment is omitted or zero is that required by the ABI for this + // type. + Alignment = DL.getABITypeAlignment(AI.getAllocatedType()); + } + Alignment = MinAlign(Alignment, P.beginOffset()); + // If we will get at least this much alignment from the type alone, leave + // the alloca's alignment unconstrained. + if (Alignment <= DL.getABITypeAlignment(SliceTy)) + Alignment = 0; + NewAI = new AllocaInst( + SliceTy, AI.getType()->getAddressSpace(), nullptr, Alignment, + AI.getName() + ".sroa." 
+ Twine(P.begin() - AS.begin()), &AI);
+    ++NumNewAllocas;
+  }
+
+  DEBUG(dbgs() << "Rewriting alloca partition "
+               << "[" << P.beginOffset() << "," << P.endOffset()
+               << ") to: " << *NewAI << "\n");
+
+  // Track the high watermark on the worklist as it is only relevant for
+  // promoted allocas. We will reset it to this point if the alloca is not in
+  // fact scheduled for promotion.
+  unsigned PPWOldSize = PostPromotionWorklist.size();
+  unsigned NumUses = 0;
+  SmallSetVector<PHINode *, 8> PHIUsers;
+  SmallSetVector<SelectInst *, 8> SelectUsers;
+
+  AllocaSliceRewriter Rewriter(DL, AS, *this, AI, *NewAI, P.beginOffset(),
+                               P.endOffset(), IsIntegerPromotable, VecTy,
+                               PHIUsers, SelectUsers);
+  bool Promotable = true;
+  for (Slice *S : P.splitSliceTails()) {
+    Promotable &= Rewriter.visit(S);
+    ++NumUses;
+  }
+  for (Slice &S : P) {
+    Promotable &= Rewriter.visit(&S);
+    ++NumUses;
+  }
+
+  NumAllocaPartitionUses += NumUses;
+  MaxUsesPerAllocaPartition.updateMax(NumUses);
+
+  // Now that we've processed all the slices in the new partition, check if any
+  // PHIs or Selects would block promotion.
+  for (PHINode *PHI : PHIUsers)
+    if (!isSafePHIToSpeculate(*PHI)) {
+      Promotable = false;
+      PHIUsers.clear();
+      SelectUsers.clear();
+      break;
+    }
+
+  for (SelectInst *Sel : SelectUsers)
+    if (!isSafeSelectToSpeculate(*Sel)) {
+      Promotable = false;
+      PHIUsers.clear();
+      SelectUsers.clear();
+      break;
+    }
+
+  if (Promotable) {
+    if (PHIUsers.empty() && SelectUsers.empty()) {
+      // Promote the alloca.
+      PromotableAllocas.push_back(NewAI);
+    } else {
+      // If we have either PHIs or Selects to speculate, add them to those
+      // worklists and re-queue the new alloca so that we promote it on the
+      // next iteration.
+      for (PHINode *PHIUser : PHIUsers)
+        SpeculatablePHIs.insert(PHIUser);
+      for (SelectInst *SelectUser : SelectUsers)
+        SpeculatableSelects.insert(SelectUser);
+      Worklist.insert(NewAI);
+    }
+  } else {
+    // Drop any post-promotion work items if promotion didn't happen.
+    while (PostPromotionWorklist.size() > PPWOldSize)
+      PostPromotionWorklist.pop_back();
+
+    // We couldn't promote and we didn't create a new partition, so nothing
+    // happened.
+    if (NewAI == &AI)
+      return nullptr;
+
+    // If we can't promote the alloca, iterate on it to check for new
+    // refinements exposed by splitting the current alloca. Don't iterate on an
+    // alloca which didn't actually change and didn't get promoted.
+    Worklist.insert(NewAI);
+  }
+
+  return NewAI;
+}
+
+/// \brief Walks the slices of an alloca, forming partitions based on them and
+/// rewriting each of their uses.
+bool SROA::splitAlloca(AllocaInst &AI, AllocaSlices &AS) {
+  if (AS.begin() == AS.end())
+    return false;
+
+  unsigned NumPartitions = 0;
+  bool Changed = false;
+  const DataLayout &DL = AI.getModule()->getDataLayout();
+
+  // First try to pre-split loads and stores.
+  Changed |= presplitLoadsAndStores(AI, AS);
+
+  // Now that we have identified any pre-splitting opportunities,
+  // mark loads and stores unsplittable except for the following case.
+  // We leave a slice splittable if all other slices are disjoint or fully
+  // included in the slice, such as whole-alloca loads and stores.
+  // If we fail to split these during pre-splitting, we want to force them
+  // to be rewritten into a partition.
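+  //
+  // For illustration (hypothetical slices): in an 8-byte alloca with a
+  // splittable whole-alloca load over [0, 8) and a splittable store over
+  // [0, 4), the bit for offset 4 is cleared because it lies strictly inside
+  // the load. The store's end therefore sits at a non-splittable boundary
+  // and the store is marked unsplittable below, while the whole-alloca
+  // load, whose endpoints 0 and 8 stay set, remains splittable.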
+ bool IsSorted = true; + + uint64_t AllocaSize = DL.getTypeAllocSize(AI.getAllocatedType()); + const uint64_t MaxBitVectorSize = 1024; + if (SROASplitNonWholeAllocaSlices && AllocaSize <= MaxBitVectorSize) { + // If a byte boundary is included in any load or store, a slice starting or + // ending at the boundary is not splittable. + SmallBitVector SplittableOffset(AllocaSize + 1, true); + for (Slice &S : AS) + for (unsigned O = S.beginOffset() + 1; + O < S.endOffset() && O < AllocaSize; O++) + SplittableOffset.reset(O); + + for (Slice &S : AS) { + if (!S.isSplittable()) + continue; + + if ((S.beginOffset() > AllocaSize || SplittableOffset[S.beginOffset()]) && + (S.endOffset() > AllocaSize || SplittableOffset[S.endOffset()])) + continue; + + if (isa<LoadInst>(S.getUse()->getUser()) || + isa<StoreInst>(S.getUse()->getUser())) { + S.makeUnsplittable(); + IsSorted = false; + } + } + } + else { + // We only allow whole-alloca splittable loads and stores + // for a large alloca to avoid creating too large BitVector. + for (Slice &S : AS) { + if (!S.isSplittable()) + continue; + + if (S.beginOffset() == 0 && S.endOffset() >= AllocaSize) + continue; + + if (isa<LoadInst>(S.getUse()->getUser()) || + isa<StoreInst>(S.getUse()->getUser())) { + S.makeUnsplittable(); + IsSorted = false; + } + } + } + + if (!IsSorted) + std::sort(AS.begin(), AS.end()); + + /// Describes the allocas introduced by rewritePartition in order to migrate + /// the debug info. + struct Fragment { + AllocaInst *Alloca; + uint64_t Offset; + uint64_t Size; + Fragment(AllocaInst *AI, uint64_t O, uint64_t S) + : Alloca(AI), Offset(O), Size(S) {} + }; + SmallVector<Fragment, 4> Fragments; + + // Rewrite each partition. + for (auto &P : AS.partitions()) { + if (AllocaInst *NewAI = rewritePartition(AI, AS, P)) { + Changed = true; + if (NewAI != &AI) { + uint64_t SizeOfByte = 8; + uint64_t AllocaSize = DL.getTypeSizeInBits(NewAI->getAllocatedType()); + // Don't include any padding. + uint64_t Size = std::min(AllocaSize, P.size() * SizeOfByte); + Fragments.push_back(Fragment(NewAI, P.beginOffset() * SizeOfByte, Size)); + } + } + ++NumPartitions; + } + + NumAllocaPartitions += NumPartitions; + MaxPartitionsPerAlloca.updateMax(NumPartitions); + + // Migrate debug information from the old alloca to the new alloca(s) + // and the individual partitions. + TinyPtrVector<DbgInfoIntrinsic *> DbgDeclares = FindDbgAddrUses(&AI); + if (!DbgDeclares.empty()) { + auto *Var = DbgDeclares.front()->getVariable(); + auto *Expr = DbgDeclares.front()->getExpression(); + auto VarSize = Var->getSizeInBits(); + DIBuilder DIB(*AI.getModule(), /*AllowUnresolved*/ false); + uint64_t AllocaSize = DL.getTypeSizeInBits(AI.getAllocatedType()); + for (auto Fragment : Fragments) { + // Create a fragment expression describing the new partition or reuse AI's + // expression if there is only one partition. + auto *FragmentExpr = Expr; + if (Fragment.Size < AllocaSize || Expr->isFragment()) { + // If this alloca is already a scalar replacement of a larger aggregate, + // Fragment.Offset describes the offset inside the scalar. + auto ExprFragment = Expr->getFragmentInfo(); + uint64_t Offset = ExprFragment ? ExprFragment->OffsetInBits : 0; + uint64_t Start = Offset + Fragment.Offset; + uint64_t Size = Fragment.Size; + if (ExprFragment) { + uint64_t AbsEnd = + ExprFragment->OffsetInBits + ExprFragment->SizeInBits; + if (Start >= AbsEnd) + // No need to describe a SROAed padding. 
+ continue; + Size = std::min(Size, AbsEnd - Start); + } + // The new, smaller fragment is stenciled out from the old fragment. + if (auto OrigFragment = FragmentExpr->getFragmentInfo()) { + assert(Start >= OrigFragment->OffsetInBits && + "new fragment is outside of original fragment"); + Start -= OrigFragment->OffsetInBits; + } + + // The alloca may be larger than the variable. + if (VarSize) { + if (Size > *VarSize) + Size = *VarSize; + if (Size == 0 || Start + Size > *VarSize) + continue; + } + + // Avoid creating a fragment expression that covers the entire variable. + if (!VarSize || *VarSize != Size) { + if (auto E = + DIExpression::createFragmentExpression(Expr, Start, Size)) + FragmentExpr = *E; + else + continue; + } + } + + // Remove any existing intrinsics describing the same alloca. + for (DbgInfoIntrinsic *OldDII : FindDbgAddrUses(Fragment.Alloca)) + OldDII->eraseFromParent(); + + DIB.insertDeclare(Fragment.Alloca, Var, FragmentExpr, + DbgDeclares.front()->getDebugLoc(), &AI); + } + } + return Changed; +} + +/// \brief Clobber a use with undef, deleting the used value if it becomes dead. +void SROA::clobberUse(Use &U) { + Value *OldV = U; + // Replace the use with an undef value. + U = UndefValue::get(OldV->getType()); + + // Check for this making an instruction dead. We have to garbage collect + // all the dead instructions to ensure the uses of any alloca end up being + // minimal. + if (Instruction *OldI = dyn_cast<Instruction>(OldV)) + if (isInstructionTriviallyDead(OldI)) { + DeadInsts.insert(OldI); + } +} + +/// \brief Analyze an alloca for SROA. +/// +/// This analyzes the alloca to ensure we can reason about it, builds +/// the slices of the alloca, and then hands it off to be split and +/// rewritten as needed. +bool SROA::runOnAlloca(AllocaInst &AI) { + DEBUG(dbgs() << "SROA alloca: " << AI << "\n"); + ++NumAllocasAnalyzed; + + // Special case dead allocas, as they're trivial. + if (AI.use_empty()) { + AI.eraseFromParent(); + return true; + } + const DataLayout &DL = AI.getModule()->getDataLayout(); + + // Skip alloca forms that this analysis can't handle. + if (AI.isArrayAllocation() || !AI.getAllocatedType()->isSized() || + DL.getTypeAllocSize(AI.getAllocatedType()) == 0) + return false; + + bool Changed = false; + + // First, split any FCA loads and stores touching this alloca to promote + // better splitting and promotion opportunities. + AggLoadStoreRewriter AggRewriter; + Changed |= AggRewriter.rewrite(AI); + + // Build the slices using a recursive instruction-visiting builder. + AllocaSlices AS(DL, AI); + DEBUG(AS.print(dbgs())); + if (AS.isEscaped()) + return Changed; + + // Delete all the dead users of this alloca before splitting and rewriting it. + for (Instruction *DeadUser : AS.getDeadUsers()) { + // Free up everything used by this instruction. + for (Use &DeadOp : DeadUser->operands()) + clobberUse(DeadOp); + + // Now replace the uses of this instruction. + DeadUser->replaceAllUsesWith(UndefValue::get(DeadUser->getType())); + + // And mark it for deletion. + DeadInsts.insert(DeadUser); + Changed = true; + } + for (Use *DeadOp : AS.getDeadOperands()) { + clobberUse(*DeadOp); + Changed = true; + } + + // No slices to split. Leave the dead alloca for a later pass to clean up. 
+ if (AS.begin() == AS.end()) + return Changed; + + Changed |= splitAlloca(AI, AS); + + DEBUG(dbgs() << " Speculating PHIs\n"); + while (!SpeculatablePHIs.empty()) + speculatePHINodeLoads(*SpeculatablePHIs.pop_back_val()); + + DEBUG(dbgs() << " Speculating Selects\n"); + while (!SpeculatableSelects.empty()) + speculateSelectInstLoads(*SpeculatableSelects.pop_back_val()); + + return Changed; +} + +/// \brief Delete the dead instructions accumulated in this run. +/// +/// Recursively deletes the dead instructions we've accumulated. This is done +/// at the very end to maximize locality of the recursive delete and to +/// minimize the problems of invalidated instruction pointers as such pointers +/// are used heavily in the intermediate stages of the algorithm. +/// +/// We also record the alloca instructions deleted here so that they aren't +/// subsequently handed to mem2reg to promote. +bool SROA::deleteDeadInstructions( + SmallPtrSetImpl<AllocaInst *> &DeletedAllocas) { + bool Changed = false; + while (!DeadInsts.empty()) { + Instruction *I = DeadInsts.pop_back_val(); + DEBUG(dbgs() << "Deleting dead instruction: " << *I << "\n"); + + // If the instruction is an alloca, find the possible dbg.declare connected + // to it, and remove it too. We must do this before calling RAUW or we will + // not be able to find it. + if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) { + DeletedAllocas.insert(AI); + for (DbgInfoIntrinsic *OldDII : FindDbgAddrUses(AI)) + OldDII->eraseFromParent(); + } + + I->replaceAllUsesWith(UndefValue::get(I->getType())); + + for (Use &Operand : I->operands()) + if (Instruction *U = dyn_cast<Instruction>(Operand)) { + // Zero out the operand and see if it becomes trivially dead. + Operand = nullptr; + if (isInstructionTriviallyDead(U)) + DeadInsts.insert(U); + } + + ++NumDeleted; + I->eraseFromParent(); + Changed = true; + } + return Changed; +} + +/// \brief Promote the allocas, using the best available technique. +/// +/// This attempts to promote whatever allocas have been identified as viable in +/// the PromotableAllocas list. If that list is empty, there is nothing to do. +/// This function returns whether any promotion occurred. +bool SROA::promoteAllocas(Function &F) { + if (PromotableAllocas.empty()) + return false; + + NumPromoted += PromotableAllocas.size(); + + DEBUG(dbgs() << "Promoting allocas with mem2reg...\n"); + PromoteMemToReg(PromotableAllocas, *DT, AC); + PromotableAllocas.clear(); + return true; +} + +PreservedAnalyses SROA::runImpl(Function &F, DominatorTree &RunDT, + AssumptionCache &RunAC) { + DEBUG(dbgs() << "SROA function: " << F.getName() << "\n"); + C = &F.getContext(); + DT = &RunDT; + AC = &RunAC; + + BasicBlock &EntryBB = F.getEntryBlock(); + for (BasicBlock::iterator I = EntryBB.begin(), E = std::prev(EntryBB.end()); + I != E; ++I) { + if (AllocaInst *AI = dyn_cast<AllocaInst>(I)) + Worklist.insert(AI); + } + + bool Changed = false; + // A set of deleted alloca instruction pointers which should be removed from + // the list of promotable allocas. + SmallPtrSet<AllocaInst *, 4> DeletedAllocas; + + do { + while (!Worklist.empty()) { + Changed |= runOnAlloca(*Worklist.pop_back_val()); + Changed |= deleteDeadInstructions(DeletedAllocas); + + // Remove the deleted allocas from various lists so that we don't try to + // continue processing them. 
+ if (!DeletedAllocas.empty()) { + auto IsInSet = [&](AllocaInst *AI) { return DeletedAllocas.count(AI); }; + Worklist.remove_if(IsInSet); + PostPromotionWorklist.remove_if(IsInSet); + PromotableAllocas.erase(llvm::remove_if(PromotableAllocas, IsInSet), + PromotableAllocas.end()); + DeletedAllocas.clear(); + } + } + + Changed |= promoteAllocas(F); + + Worklist = PostPromotionWorklist; + PostPromotionWorklist.clear(); + } while (!Worklist.empty()); + + if (!Changed) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + PA.preserve<GlobalsAA>(); + return PA; +} + +PreservedAnalyses SROA::run(Function &F, FunctionAnalysisManager &AM) { + return runImpl(F, AM.getResult<DominatorTreeAnalysis>(F), + AM.getResult<AssumptionAnalysis>(F)); +} + +/// A legacy pass for the legacy pass manager that wraps the \c SROA pass. +/// +/// This is in the llvm namespace purely to allow it to be a friend of the \c +/// SROA pass. +class llvm::sroa::SROALegacyPass : public FunctionPass { + /// The SROA implementation. + SROA Impl; + +public: + static char ID; + + SROALegacyPass() : FunctionPass(ID) { + initializeSROALegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F)) + return false; + + auto PA = Impl.runImpl( + F, getAnalysis<DominatorTreeWrapperPass>().getDomTree(), + getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F)); + return !PA.areAllPreserved(); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.setPreservesCFG(); + } + + StringRef getPassName() const override { return "SROA"; } +}; + +char SROALegacyPass::ID = 0; + +FunctionPass *llvm::createSROAPass() { return new SROALegacyPass(); } + +INITIALIZE_PASS_BEGIN(SROALegacyPass, "sroa", + "Scalar Replacement Of Aggregates", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_END(SROALegacyPass, "sroa", "Scalar Replacement Of Aggregates", + false, false) diff --git a/contrib/llvm/lib/Transforms/Scalar/Scalar.cpp b/contrib/llvm/lib/Transforms/Scalar/Scalar.cpp new file mode 100644 index 000000000000..3b99ddff2e06 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/Scalar.cpp @@ -0,0 +1,282 @@ +//===-- Scalar.cpp --------------------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements common infrastructure for libLLVMScalarOpts.a, which +// implements several scalar transformations over the LLVM intermediate +// representation, including the C bindings for that library. 
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar.h" +#include "llvm-c/Initialization.h" +#include "llvm-c/Transforms/Scalar.h" +#include "llvm/Analysis/BasicAliasAnalysis.h" +#include "llvm/Analysis/Passes.h" +#include "llvm/Analysis/ScopedNoAliasAA.h" +#include "llvm/Analysis/TypeBasedAliasAnalysis.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Verifier.h" +#include "llvm/InitializePasses.h" +#include "llvm/Transforms/Scalar/GVN.h" +#include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" + +using namespace llvm; + +/// initializeScalarOptsPasses - Initialize all passes linked into the +/// ScalarOpts library. +void llvm::initializeScalarOpts(PassRegistry &Registry) { + initializeADCELegacyPassPass(Registry); + initializeBDCELegacyPassPass(Registry); + initializeAlignmentFromAssumptionsPass(Registry); + initializeCallSiteSplittingLegacyPassPass(Registry); + initializeConstantHoistingLegacyPassPass(Registry); + initializeConstantPropagationPass(Registry); + initializeCorrelatedValuePropagationPass(Registry); + initializeDCELegacyPassPass(Registry); + initializeDeadInstEliminationPass(Registry); + initializeDivRemPairsLegacyPassPass(Registry); + initializeScalarizerPass(Registry); + initializeDSELegacyPassPass(Registry); + initializeGuardWideningLegacyPassPass(Registry); + initializeGVNLegacyPassPass(Registry); + initializeNewGVNLegacyPassPass(Registry); + initializeEarlyCSELegacyPassPass(Registry); + initializeEarlyCSEMemSSALegacyPassPass(Registry); + initializeGVNHoistLegacyPassPass(Registry); + initializeGVNSinkLegacyPassPass(Registry); + initializeFlattenCFGPassPass(Registry); + initializeInductiveRangeCheckEliminationPass(Registry); + initializeIndVarSimplifyLegacyPassPass(Registry); + initializeInferAddressSpacesPass(Registry); + initializeJumpThreadingPass(Registry); + initializeLegacyLICMPassPass(Registry); + initializeLegacyLoopSinkPassPass(Registry); + initializeLoopDataPrefetchLegacyPassPass(Registry); + initializeLoopDeletionLegacyPassPass(Registry); + initializeLoopAccessLegacyAnalysisPass(Registry); + initializeLoopInstSimplifyLegacyPassPass(Registry); + initializeLoopInterchangePass(Registry); + initializeLoopPredicationLegacyPassPass(Registry); + initializeLoopRotateLegacyPassPass(Registry); + initializeLoopStrengthReducePass(Registry); + initializeLoopRerollPass(Registry); + initializeLoopUnrollPass(Registry); + initializeLoopUnswitchPass(Registry); + initializeLoopVersioningLICMPass(Registry); + initializeLoopIdiomRecognizeLegacyPassPass(Registry); + initializeLowerAtomicLegacyPassPass(Registry); + initializeLowerExpectIntrinsicPass(Registry); + initializeLowerGuardIntrinsicLegacyPassPass(Registry); + initializeMemCpyOptLegacyPassPass(Registry); + initializeMergeICmpsPass(Registry); + initializeMergedLoadStoreMotionLegacyPassPass(Registry); + initializeNaryReassociateLegacyPassPass(Registry); + initializePartiallyInlineLibCallsLegacyPassPass(Registry); + initializeReassociateLegacyPassPass(Registry); + initializeRegToMemPass(Registry); + initializeRewriteStatepointsForGCLegacyPassPass(Registry); + initializeSCCPLegacyPassPass(Registry); + initializeIPSCCPLegacyPassPass(Registry); + initializeSROALegacyPassPass(Registry); + initializeCFGSimplifyPassPass(Registry); + initializeStructurizeCFGPass(Registry); + initializeSimpleLoopUnswitchLegacyPassPass(Registry); + initializeSinkingLegacyPassPass(Registry); + initializeTailCallElimPass(Registry); + 
initializeSeparateConstOffsetFromGEPPass(Registry); + initializeSpeculativeExecutionLegacyPassPass(Registry); + initializeStraightLineStrengthReducePass(Registry); + initializePlaceBackedgeSafepointsImplPass(Registry); + initializePlaceSafepointsPass(Registry); + initializeFloat2IntLegacyPassPass(Registry); + initializeLoopDistributeLegacyPass(Registry); + initializeLoopLoadEliminationPass(Registry); + initializeLoopSimplifyCFGLegacyPassPass(Registry); + initializeLoopVersioningPassPass(Registry); + initializeEntryExitInstrumenterPass(Registry); + initializePostInlineEntryExitInstrumenterPass(Registry); +} + +void LLVMInitializeScalarOpts(LLVMPassRegistryRef R) { + initializeScalarOpts(*unwrap(R)); +} + +void LLVMAddAggressiveDCEPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createAggressiveDCEPass()); +} + +void LLVMAddBitTrackingDCEPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createBitTrackingDCEPass()); +} + +void LLVMAddAlignmentFromAssumptionsPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createAlignmentFromAssumptionsPass()); +} + +void LLVMAddCFGSimplificationPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createCFGSimplificationPass(1, false, false, true)); +} + +void LLVMAddDeadStoreEliminationPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createDeadStoreEliminationPass()); +} + +void LLVMAddScalarizerPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createScalarizerPass()); +} + +void LLVMAddGVNPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createGVNPass()); +} + +void LLVMAddNewGVNPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createNewGVNPass()); +} + +void LLVMAddMergedLoadStoreMotionPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createMergedLoadStoreMotionPass()); +} + +void LLVMAddIndVarSimplifyPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createIndVarSimplifyPass()); +} + +void LLVMAddInstructionCombiningPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createInstructionCombiningPass()); +} + +void LLVMAddJumpThreadingPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createJumpThreadingPass()); +} + +void LLVMAddLoopSinkPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopSinkPass()); +} + +void LLVMAddLICMPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLICMPass()); +} + +void LLVMAddLoopDeletionPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopDeletionPass()); +} + +void LLVMAddLoopIdiomPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopIdiomPass()); +} + +void LLVMAddLoopRotatePass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopRotatePass()); +} + +void LLVMAddLoopRerollPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopRerollPass()); +} + +void LLVMAddLoopSimplifyCFGPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopSimplifyCFGPass()); +} + +void LLVMAddLoopUnrollPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopUnrollPass()); +} + +void LLVMAddLoopUnswitchPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLoopUnswitchPass()); +} + +void LLVMAddMemCpyOptPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createMemCpyOptPass()); +} + +void LLVMAddPartiallyInlineLibCallsPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createPartiallyInlineLibCallsPass()); +} + +void LLVMAddLowerSwitchPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLowerSwitchPass()); +} + +void LLVMAddPromoteMemoryToRegisterPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createPromoteMemoryToRegisterPass()); +} + +void LLVMAddReassociatePass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createReassociatePass()); +} + +void 
LLVMAddSCCPPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createSCCPPass()); +} + +void LLVMAddScalarReplAggregatesPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createSROAPass()); +} + +void LLVMAddScalarReplAggregatesPassSSA(LLVMPassManagerRef PM) { + unwrap(PM)->add(createSROAPass()); +} + +void LLVMAddScalarReplAggregatesPassWithThreshold(LLVMPassManagerRef PM, + int Threshold) { + unwrap(PM)->add(createSROAPass()); +} + +void LLVMAddSimplifyLibCallsPass(LLVMPassManagerRef PM) { + // NOTE: The simplify-libcalls pass has been removed. +} + +void LLVMAddTailCallEliminationPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createTailCallEliminationPass()); +} + +void LLVMAddConstantPropagationPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createConstantPropagationPass()); +} + +void LLVMAddDemoteMemoryToRegisterPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createDemoteRegisterToMemoryPass()); +} + +void LLVMAddVerifierPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createVerifierPass()); +} + +void LLVMAddCorrelatedValuePropagationPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createCorrelatedValuePropagationPass()); +} + +void LLVMAddEarlyCSEPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createEarlyCSEPass(false/*=UseMemorySSA*/)); +} + +void LLVMAddEarlyCSEMemSSAPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createEarlyCSEPass(true/*=UseMemorySSA*/)); +} + +void LLVMAddGVNHoistLegacyPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createGVNHoistPass()); +} + +void LLVMAddTypeBasedAliasAnalysisPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createTypeBasedAAWrapperPass()); +} + +void LLVMAddScopedNoAliasAAPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createScopedNoAliasAAWrapperPass()); +} + +void LLVMAddBasicAliasAnalysisPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createBasicAAWrapperPass()); +} + +void LLVMAddLowerExpectIntrinsicPass(LLVMPassManagerRef PM) { + unwrap(PM)->add(createLowerExpectIntrinsicPass()); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/contrib/llvm/lib/Transforms/Scalar/Scalarizer.cpp new file mode 100644 index 000000000000..34ed126155be --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/Scalarizer.cpp @@ -0,0 +1,802 @@ +//===- Scalarizer.cpp - Scalarize vector operations -----------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This pass converts vector operations into scalar operations, in order +// to expose optimization opportunities on the individual scalar operations. +// It is mainly intended for targets that do not have vector units, but it +// may also be useful for revectorizing code to different vector widths. 
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/VectorUtils.h"
+#include "llvm/IR/Argument.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/InstVisitor.h"
+#include "llvm/IR/InstrTypes.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/MathExtras.h"
+#include "llvm/Support/Options.h"
+#include "llvm/Transforms/Scalar.h"
+#include <cassert>
+#include <cstdint>
+#include <iterator>
+#include <map>
+#include <utility>
+
+using namespace llvm;
+
+#define DEBUG_TYPE "scalarizer"
+
+namespace {
+
+// Used to store the scattered form of a vector.
+using ValueVector = SmallVector<Value *, 8>;
+
+// Used to map a vector Value to its scattered form. We use std::map
+// because we want iterators to persist across insertion and because the
+// values are relatively large.
+using ScatterMap = std::map<Value *, ValueVector>;
+
+// Lists Instructions that have been replaced with scalar implementations,
+// along with a pointer to their scattered forms.
+using GatherList = SmallVector<std::pair<Instruction *, ValueVector *>, 16>;
+
+// Provides a very limited vector-like interface for lazily accessing one
+// component of a scattered vector or vector pointer.
+class Scatterer {
+public:
+  Scatterer() = default;
+
+  // Scatter V into Size components. If new instructions are needed,
+  // insert them before BBI in BB. If Cache is nonnull, use it to cache
+  // the results.
+  Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
+            ValueVector *cachePtr = nullptr);
+
+  // Return component I, creating a new Value for it if necessary.
+  Value *operator[](unsigned I);
+
+  // Return the number of components.
+  unsigned size() const { return Size; }
+
+private:
+  BasicBlock *BB;
+  BasicBlock::iterator BBI;
+  Value *V;
+  ValueVector *CachePtr;
+  PointerType *PtrTy;
+  ValueVector Tmp;
+  unsigned Size;
+};
+
+// FCmpSplitter(FCI)(Builder, X, Y, Name) uses Builder to create an FCmp
+// called Name that compares X and Y in the same way as FCI.
+struct FCmpSplitter {
+  FCmpSplitter(FCmpInst &fci) : FCI(fci) {}
+
+  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
+                    const Twine &Name) const {
+    return Builder.CreateFCmp(FCI.getPredicate(), Op0, Op1, Name);
+  }
+
+  FCmpInst &FCI;
+};
+
+// ICmpSplitter(ICI)(Builder, X, Y, Name) uses Builder to create an ICmp
+// called Name that compares X and Y in the same way as ICI.
+struct ICmpSplitter {
+  ICmpSplitter(ICmpInst &ici) : ICI(ici) {}
+
+  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
+                    const Twine &Name) const {
+    return Builder.CreateICmp(ICI.getPredicate(), Op0, Op1, Name);
+  }
+
+  ICmpInst &ICI;
+};
+
+// BinarySplitter(BO)(Builder, X, Y, Name) uses Builder to create
+// a binary operator like BO called Name with operands X and Y.
+struct BinarySplitter {
+  BinarySplitter(BinaryOperator &bo) : BO(bo) {}
+
+  Value *operator()(IRBuilder<> &Builder, Value *Op0, Value *Op1,
+                    const Twine &Name) const {
+    return Builder.CreateBinOp(BO.getOpcode(), Op0, Op1, Name);
+  }
+
+  BinaryOperator &BO;
+};
+
+// Information about a load or store that we're scalarizing.
+struct VectorLayout {
+  VectorLayout() = default;
+
+  // Return the alignment of element I.
+  uint64_t getElemAlign(unsigned I) {
+    return MinAlign(VecAlign, I * ElemSize);
+  }
+
+  // The type of the vector.
+  VectorType *VecTy = nullptr;
+
+  // The type of each element.
+  Type *ElemTy = nullptr;
+
+  // The alignment of the vector.
+  uint64_t VecAlign = 0;
+
+  // The size of each element.
+  uint64_t ElemSize = 0;
+};
+
+class Scalarizer : public FunctionPass,
+                   public InstVisitor<Scalarizer, bool> {
+public:
+  static char ID;
+
+  Scalarizer() : FunctionPass(ID) {
+    initializeScalarizerPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool doInitialization(Module &M) override;
+  bool runOnFunction(Function &F) override;
+
+  // InstVisitor methods. They return true if the instruction was scalarized,
+  // false if nothing changed.
+  bool visitInstruction(Instruction &I) { return false; }
+  bool visitSelectInst(SelectInst &SI);
+  bool visitICmpInst(ICmpInst &ICI);
+  bool visitFCmpInst(FCmpInst &FCI);
+  bool visitBinaryOperator(BinaryOperator &BO);
+  bool visitGetElementPtrInst(GetElementPtrInst &GEPI);
+  bool visitCastInst(CastInst &CI);
+  bool visitBitCastInst(BitCastInst &BCI);
+  bool visitShuffleVectorInst(ShuffleVectorInst &SVI);
+  bool visitPHINode(PHINode &PHI);
+  bool visitLoadInst(LoadInst &LI);
+  bool visitStoreInst(StoreInst &SI);
+  bool visitCallInst(CallInst &ICI);
+
+  static void registerOptions() {
+    // This is disabled by default because having separate loads and stores
+    // makes it more likely that the -combiner-alias-analysis limits will be
+    // reached.
+    OptionRegistry::registerOption<bool, Scalarizer,
+                                   &Scalarizer::ScalarizeLoadStore>(
+        "scalarize-load-store",
+        "Allow the scalarizer pass to scalarize loads and stores", false);
+  }
+
+private:
+  Scatterer scatter(Instruction *Point, Value *V);
+  void gather(Instruction *Op, const ValueVector &CV);
+  bool canTransferMetadata(unsigned Kind);
+  void transferMetadata(Instruction *Op, const ValueVector &CV);
+  bool getVectorLayout(Type *Ty, unsigned Alignment, VectorLayout &Layout,
+                       const DataLayout &DL);
+  bool finish();
+
+  template<typename T> bool splitBinary(Instruction &, const T &);
+
+  bool splitCall(CallInst &CI);
+
+  ScatterMap Scattered;
+  GatherList Gathered;
+  unsigned ParallelLoopAccessMDKind;
+  bool ScalarizeLoadStore;
+};
+
+} // end anonymous namespace
+
+char Scalarizer::ID = 0;
+
+INITIALIZE_PASS_WITH_OPTIONS(Scalarizer, "scalarizer",
+                             "Scalarize vector operations", false, false)
+
+Scatterer::Scatterer(BasicBlock *bb, BasicBlock::iterator bbi, Value *v,
+                     ValueVector *cachePtr)
+    : BB(bb), BBI(bbi), V(v), CachePtr(cachePtr) {
+  Type *Ty = V->getType();
+  PtrTy = dyn_cast<PointerType>(Ty);
+  if (PtrTy)
+    Ty = PtrTy->getElementType();
+  Size = Ty->getVectorNumElements();
+  if (!CachePtr)
+    Tmp.resize(Size, nullptr);
+  else if (CachePtr->empty())
+    CachePtr->resize(Size, nullptr);
+  else
+    assert(Size == CachePtr->size() && "Inconsistent vector sizes");
+}
+
+// Return component I, creating a new Value for it if necessary.
+Value *Scatterer::operator[](unsigned I) {
+  ValueVector &CV = (CachePtr ? *CachePtr : Tmp);
+  // Try to reuse a previous value.
+ if (CV[I]) + return CV[I]; + IRBuilder<> Builder(BB, BBI); + if (PtrTy) { + if (!CV[0]) { + Type *Ty = + PointerType::get(PtrTy->getElementType()->getVectorElementType(), + PtrTy->getAddressSpace()); + CV[0] = Builder.CreateBitCast(V, Ty, V->getName() + ".i0"); + } + if (I != 0) + CV[I] = Builder.CreateConstGEP1_32(nullptr, CV[0], I, + V->getName() + ".i" + Twine(I)); + } else { + // Search through a chain of InsertElementInsts looking for element I. + // Record other elements in the cache. The new V is still suitable + // for all uncached indices. + while (true) { + InsertElementInst *Insert = dyn_cast<InsertElementInst>(V); + if (!Insert) + break; + ConstantInt *Idx = dyn_cast<ConstantInt>(Insert->getOperand(2)); + if (!Idx) + break; + unsigned J = Idx->getZExtValue(); + V = Insert->getOperand(0); + if (I == J) { + CV[J] = Insert->getOperand(1); + return CV[J]; + } else if (!CV[J]) { + // Only cache the first entry we find for each index we're not actively + // searching for. This prevents us from going too far up the chain and + // caching incorrect entries. + CV[J] = Insert->getOperand(1); + } + } + CV[I] = Builder.CreateExtractElement(V, Builder.getInt32(I), + V->getName() + ".i" + Twine(I)); + } + return CV[I]; +} + +bool Scalarizer::doInitialization(Module &M) { + ParallelLoopAccessMDKind = + M.getContext().getMDKindID("llvm.mem.parallel_loop_access"); + ScalarizeLoadStore = + M.getContext().getOption<bool, Scalarizer, &Scalarizer::ScalarizeLoadStore>(); + return false; +} + +bool Scalarizer::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + assert(Gathered.empty() && Scattered.empty()); + for (BasicBlock &BB : F) { + for (BasicBlock::iterator II = BB.begin(), IE = BB.end(); II != IE;) { + Instruction *I = &*II; + bool Done = visit(I); + ++II; + if (Done && I->getType()->isVoidTy()) + I->eraseFromParent(); + } + } + return finish(); +} + +// Return a scattered form of V that can be accessed by Point. V must be a +// vector or a pointer to a vector. +Scatterer Scalarizer::scatter(Instruction *Point, Value *V) { + if (Argument *VArg = dyn_cast<Argument>(V)) { + // Put the scattered form of arguments in the entry block, + // so that it can be used everywhere. + Function *F = VArg->getParent(); + BasicBlock *BB = &F->getEntryBlock(); + return Scatterer(BB, BB->begin(), V, &Scattered[V]); + } + if (Instruction *VOp = dyn_cast<Instruction>(V)) { + // Put the scattered form of an instruction directly after the + // instruction. + BasicBlock *BB = VOp->getParent(); + return Scatterer(BB, std::next(BasicBlock::iterator(VOp)), + V, &Scattered[V]); + } + // In the fallback case, just put the scattered before Point and + // keep the result local to Point. + return Scatterer(Point->getParent(), Point->getIterator(), V); +} + +// Replace Op with the gathered form of the components in CV. Defer the +// deletion of Op and creation of the gathered form to the end of the pass, +// so that we can avoid creating the gathered form if all uses of Op are +// replaced with uses of CV. +void Scalarizer::gather(Instruction *Op, const ValueVector &CV) { + // Since we're not deleting Op yet, stub out its operands, so that it + // doesn't make anything live unnecessarily. + for (unsigned I = 0, E = Op->getNumOperands(); I != E; ++I) + Op->setOperand(I, UndefValue::get(Op->getOperand(I)->getType())); + + transferMetadata(Op, CV); + + // If we already have a scattered form of Op (created from ExtractElements + // of Op itself), replace them with the new form. 
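+  //
+  // For illustration (hypothetical IR): if %v.i1 was created earlier as an
+  // extractelement in order to scatter %v, and %v is now being gathered
+  // from freshly built scalar components, that old extractelement is RAUW'd
+  // with the matching component and erased below rather than being left
+  // behind to recompute element 1.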
+ ValueVector &SV = Scattered[Op]; + if (!SV.empty()) { + for (unsigned I = 0, E = SV.size(); I != E; ++I) { + Value *V = SV[I]; + if (V == nullptr) + continue; + + Instruction *Old = cast<Instruction>(V); + CV[I]->takeName(Old); + Old->replaceAllUsesWith(CV[I]); + Old->eraseFromParent(); + } + } + SV = CV; + Gathered.push_back(GatherList::value_type(Op, &SV)); +} + +// Return true if it is safe to transfer the given metadata tag from +// vector to scalar instructions. +bool Scalarizer::canTransferMetadata(unsigned Tag) { + return (Tag == LLVMContext::MD_tbaa + || Tag == LLVMContext::MD_fpmath + || Tag == LLVMContext::MD_tbaa_struct + || Tag == LLVMContext::MD_invariant_load + || Tag == LLVMContext::MD_alias_scope + || Tag == LLVMContext::MD_noalias + || Tag == ParallelLoopAccessMDKind); +} + +// Transfer metadata from Op to the instructions in CV if it is known +// to be safe to do so. +void Scalarizer::transferMetadata(Instruction *Op, const ValueVector &CV) { + SmallVector<std::pair<unsigned, MDNode *>, 4> MDs; + Op->getAllMetadataOtherThanDebugLoc(MDs); + for (unsigned I = 0, E = CV.size(); I != E; ++I) { + if (Instruction *New = dyn_cast<Instruction>(CV[I])) { + for (const auto &MD : MDs) + if (canTransferMetadata(MD.first)) + New->setMetadata(MD.first, MD.second); + if (Op->getDebugLoc() && !New->getDebugLoc()) + New->setDebugLoc(Op->getDebugLoc()); + } + } +} + +// Try to fill in Layout from Ty, returning true on success. Alignment is +// the alignment of the vector, or 0 if the ABI default should be used. +bool Scalarizer::getVectorLayout(Type *Ty, unsigned Alignment, + VectorLayout &Layout, const DataLayout &DL) { + // Make sure we're dealing with a vector. + Layout.VecTy = dyn_cast<VectorType>(Ty); + if (!Layout.VecTy) + return false; + + // Check that we're dealing with full-byte elements. + Layout.ElemTy = Layout.VecTy->getElementType(); + if (DL.getTypeSizeInBits(Layout.ElemTy) != + DL.getTypeStoreSizeInBits(Layout.ElemTy)) + return false; + + if (Alignment) + Layout.VecAlign = Alignment; + else + Layout.VecAlign = DL.getABITypeAlignment(Layout.VecTy); + Layout.ElemSize = DL.getTypeStoreSize(Layout.ElemTy); + return true; +} + +// Scalarize two-operand instruction I, using Split(Builder, X, Y, Name) +// to create an instruction like I with operands X and Y and name Name. +template<typename Splitter> +bool Scalarizer::splitBinary(Instruction &I, const Splitter &Split) { + VectorType *VT = dyn_cast<VectorType>(I.getType()); + if (!VT) + return false; + + unsigned NumElems = VT->getNumElements(); + IRBuilder<> Builder(&I); + Scatterer Op0 = scatter(&I, I.getOperand(0)); + Scatterer Op1 = scatter(&I, I.getOperand(1)); + assert(Op0.size() == NumElems && "Mismatched binary operation"); + assert(Op1.size() == NumElems && "Mismatched binary operation"); + ValueVector Res; + Res.resize(NumElems); + for (unsigned Elem = 0; Elem < NumElems; ++Elem) + Res[Elem] = Split(Builder, Op0[Elem], Op1[Elem], + I.getName() + ".i" + Twine(Elem)); + gather(&I, Res); + return true; +} + +static bool isTriviallyScalariable(Intrinsic::ID ID) { + return isTriviallyVectorizable(ID); +} + +// All of the current scalarizable intrinsics only have one mangled type. +static Function *getScalarIntrinsicDeclaration(Module *M, + Intrinsic::ID ID, + VectorType *Ty) { + return Intrinsic::getDeclaration(M, ID, { Ty->getScalarType() }); +} + +/// If a call to a vector typed intrinsic function, split into a scalar call per +/// element if possible for the intrinsic. 
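+/// Editor's sketch (illustrative IR, not from the original source): a call
+///   %r = call <2 x float> @llvm.fabs.v2f32(<2 x float> %v)
+/// becomes
+///   %v.i0 = extractelement <2 x float> %v, i32 0
+///   %r.i0 = call float @llvm.fabs.f32(float %v.i0)
+///   %v.i1 = extractelement <2 x float> %v, i32 1
+///   %r.i1 = call float @llvm.fabs.f32(float %v.i1)
+/// with the vector result regathered only if it is still used.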
+bool Scalarizer::splitCall(CallInst &CI) { + VectorType *VT = dyn_cast<VectorType>(CI.getType()); + if (!VT) + return false; + + Function *F = CI.getCalledFunction(); + if (!F) + return false; + + Intrinsic::ID ID = F->getIntrinsicID(); + if (ID == Intrinsic::not_intrinsic || !isTriviallyScalariable(ID)) + return false; + + unsigned NumElems = VT->getNumElements(); + unsigned NumArgs = CI.getNumArgOperands(); + + ValueVector ScalarOperands(NumArgs); + SmallVector<Scatterer, 8> Scattered(NumArgs); + + Scattered.resize(NumArgs); + + // Assumes that any vector type has the same number of elements as the return + // vector type, which is true for all current intrinsics. + for (unsigned I = 0; I != NumArgs; ++I) { + Value *OpI = CI.getOperand(I); + if (OpI->getType()->isVectorTy()) { + Scattered[I] = scatter(&CI, OpI); + assert(Scattered[I].size() == NumElems && "mismatched call operands"); + } else { + ScalarOperands[I] = OpI; + } + } + + ValueVector Res(NumElems); + ValueVector ScalarCallOps(NumArgs); + + Function *NewIntrin = getScalarIntrinsicDeclaration(F->getParent(), ID, VT); + IRBuilder<> Builder(&CI); + + // Perform actual scalarization, taking care to preserve any scalar operands. + for (unsigned Elem = 0; Elem < NumElems; ++Elem) { + ScalarCallOps.clear(); + + for (unsigned J = 0; J != NumArgs; ++J) { + if (hasVectorInstrinsicScalarOpd(ID, J)) + ScalarCallOps.push_back(ScalarOperands[J]); + else + ScalarCallOps.push_back(Scattered[J][Elem]); + } + + Res[Elem] = Builder.CreateCall(NewIntrin, ScalarCallOps, + CI.getName() + ".i" + Twine(Elem)); + } + + gather(&CI, Res); + return true; +} + +bool Scalarizer::visitSelectInst(SelectInst &SI) { + VectorType *VT = dyn_cast<VectorType>(SI.getType()); + if (!VT) + return false; + + unsigned NumElems = VT->getNumElements(); + IRBuilder<> Builder(&SI); + Scatterer Op1 = scatter(&SI, SI.getOperand(1)); + Scatterer Op2 = scatter(&SI, SI.getOperand(2)); + assert(Op1.size() == NumElems && "Mismatched select"); + assert(Op2.size() == NumElems && "Mismatched select"); + ValueVector Res; + Res.resize(NumElems); + + if (SI.getOperand(0)->getType()->isVectorTy()) { + Scatterer Op0 = scatter(&SI, SI.getOperand(0)); + assert(Op0.size() == NumElems && "Mismatched select"); + for (unsigned I = 0; I < NumElems; ++I) + Res[I] = Builder.CreateSelect(Op0[I], Op1[I], Op2[I], + SI.getName() + ".i" + Twine(I)); + } else { + Value *Op0 = SI.getOperand(0); + for (unsigned I = 0; I < NumElems; ++I) + Res[I] = Builder.CreateSelect(Op0, Op1[I], Op2[I], + SI.getName() + ".i" + Twine(I)); + } + gather(&SI, Res); + return true; +} + +bool Scalarizer::visitICmpInst(ICmpInst &ICI) { + return splitBinary(ICI, ICmpSplitter(ICI)); +} + +bool Scalarizer::visitFCmpInst(FCmpInst &FCI) { + return splitBinary(FCI, FCmpSplitter(FCI)); +} + +bool Scalarizer::visitBinaryOperator(BinaryOperator &BO) { + return splitBinary(BO, BinarySplitter(BO)); +} + +bool Scalarizer::visitGetElementPtrInst(GetElementPtrInst &GEPI) { + VectorType *VT = dyn_cast<VectorType>(GEPI.getType()); + if (!VT) + return false; + + IRBuilder<> Builder(&GEPI); + unsigned NumElems = VT->getNumElements(); + unsigned NumIndices = GEPI.getNumIndices(); + + // The base pointer might be scalar even if it's a vector GEP. In those cases, + // splat the pointer into a vector value, and scatter that vector. 
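+  // Editor's sketch (illustrative): for
+  //   %p = getelementptr float, float* %base, <2 x i64> %idx
+  // the scalar %base is first splatted to a <2 x float*> vector so that each
+  // lane below has a base component to pair with its index component.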
+ Value *Op0 = GEPI.getOperand(0); + if (!Op0->getType()->isVectorTy()) + Op0 = Builder.CreateVectorSplat(NumElems, Op0); + Scatterer Base = scatter(&GEPI, Op0); + + SmallVector<Scatterer, 8> Ops; + Ops.resize(NumIndices); + for (unsigned I = 0; I < NumIndices; ++I) { + Value *Op = GEPI.getOperand(I + 1); + + // The indices might be scalars even if it's a vector GEP. In those cases, + // splat the scalar into a vector value, and scatter that vector. + if (!Op->getType()->isVectorTy()) + Op = Builder.CreateVectorSplat(NumElems, Op); + + Ops[I] = scatter(&GEPI, Op); + } + + ValueVector Res; + Res.resize(NumElems); + for (unsigned I = 0; I < NumElems; ++I) { + SmallVector<Value *, 8> Indices; + Indices.resize(NumIndices); + for (unsigned J = 0; J < NumIndices; ++J) + Indices[J] = Ops[J][I]; + Res[I] = Builder.CreateGEP(GEPI.getSourceElementType(), Base[I], Indices, + GEPI.getName() + ".i" + Twine(I)); + if (GEPI.isInBounds()) + if (GetElementPtrInst *NewGEPI = dyn_cast<GetElementPtrInst>(Res[I])) + NewGEPI->setIsInBounds(); + } + gather(&GEPI, Res); + return true; +} + +bool Scalarizer::visitCastInst(CastInst &CI) { + VectorType *VT = dyn_cast<VectorType>(CI.getDestTy()); + if (!VT) + return false; + + unsigned NumElems = VT->getNumElements(); + IRBuilder<> Builder(&CI); + Scatterer Op0 = scatter(&CI, CI.getOperand(0)); + assert(Op0.size() == NumElems && "Mismatched cast"); + ValueVector Res; + Res.resize(NumElems); + for (unsigned I = 0; I < NumElems; ++I) + Res[I] = Builder.CreateCast(CI.getOpcode(), Op0[I], VT->getElementType(), + CI.getName() + ".i" + Twine(I)); + gather(&CI, Res); + return true; +} + +bool Scalarizer::visitBitCastInst(BitCastInst &BCI) { + VectorType *DstVT = dyn_cast<VectorType>(BCI.getDestTy()); + VectorType *SrcVT = dyn_cast<VectorType>(BCI.getSrcTy()); + if (!DstVT || !SrcVT) + return false; + + unsigned DstNumElems = DstVT->getNumElements(); + unsigned SrcNumElems = SrcVT->getNumElements(); + IRBuilder<> Builder(&BCI); + Scatterer Op0 = scatter(&BCI, BCI.getOperand(0)); + ValueVector Res; + Res.resize(DstNumElems); + + if (DstNumElems == SrcNumElems) { + for (unsigned I = 0; I < DstNumElems; ++I) + Res[I] = Builder.CreateBitCast(Op0[I], DstVT->getElementType(), + BCI.getName() + ".i" + Twine(I)); + } else if (DstNumElems > SrcNumElems) { + // <M x t1> -> <N*M x t2>. Convert each t1 to <N x t2> and copy the + // individual elements to the destination. + unsigned FanOut = DstNumElems / SrcNumElems; + Type *MidTy = VectorType::get(DstVT->getElementType(), FanOut); + unsigned ResI = 0; + for (unsigned Op0I = 0; Op0I < SrcNumElems; ++Op0I) { + Value *V = Op0[Op0I]; + Instruction *VI; + // Look through any existing bitcasts before converting to <N x t2>. + // In the best case, the resulting conversion might be a no-op. + while ((VI = dyn_cast<Instruction>(V)) && + VI->getOpcode() == Instruction::BitCast) + V = VI->getOperand(0); + V = Builder.CreateBitCast(V, MidTy, V->getName() + ".cast"); + Scatterer Mid = scatter(&BCI, V); + for (unsigned MidI = 0; MidI < FanOut; ++MidI) + Res[ResI++] = Mid[MidI]; + } + } else { + // <N*M x t1> -> <M x t2>. Convert each group of <N x t1> into a t2. 
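+    // Editor's sketch (illustrative): bitcasting <4 x i16> to <2 x i32>
+    // gives FanIn = 2; elements 0 and 1 are reinserted into a <2 x i16>
+    // value and bitcast to the first i32, elements 2 and 3 form the second.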
+ unsigned FanIn = SrcNumElems / DstNumElems; + Type *MidTy = VectorType::get(SrcVT->getElementType(), FanIn); + unsigned Op0I = 0; + for (unsigned ResI = 0; ResI < DstNumElems; ++ResI) { + Value *V = UndefValue::get(MidTy); + for (unsigned MidI = 0; MidI < FanIn; ++MidI) + V = Builder.CreateInsertElement(V, Op0[Op0I++], Builder.getInt32(MidI), + BCI.getName() + ".i" + Twine(ResI) + + ".upto" + Twine(MidI)); + Res[ResI] = Builder.CreateBitCast(V, DstVT->getElementType(), + BCI.getName() + ".i" + Twine(ResI)); + } + } + gather(&BCI, Res); + return true; +} + +bool Scalarizer::visitShuffleVectorInst(ShuffleVectorInst &SVI) { + VectorType *VT = dyn_cast<VectorType>(SVI.getType()); + if (!VT) + return false; + + unsigned NumElems = VT->getNumElements(); + Scatterer Op0 = scatter(&SVI, SVI.getOperand(0)); + Scatterer Op1 = scatter(&SVI, SVI.getOperand(1)); + ValueVector Res; + Res.resize(NumElems); + + for (unsigned I = 0; I < NumElems; ++I) { + int Selector = SVI.getMaskValue(I); + if (Selector < 0) + Res[I] = UndefValue::get(VT->getElementType()); + else if (unsigned(Selector) < Op0.size()) + Res[I] = Op0[Selector]; + else + Res[I] = Op1[Selector - Op0.size()]; + } + gather(&SVI, Res); + return true; +} + +bool Scalarizer::visitPHINode(PHINode &PHI) { + VectorType *VT = dyn_cast<VectorType>(PHI.getType()); + if (!VT) + return false; + + unsigned NumElems = VT->getNumElements(); + IRBuilder<> Builder(&PHI); + ValueVector Res; + Res.resize(NumElems); + + unsigned NumOps = PHI.getNumOperands(); + for (unsigned I = 0; I < NumElems; ++I) + Res[I] = Builder.CreatePHI(VT->getElementType(), NumOps, + PHI.getName() + ".i" + Twine(I)); + + for (unsigned I = 0; I < NumOps; ++I) { + Scatterer Op = scatter(&PHI, PHI.getIncomingValue(I)); + BasicBlock *IncomingBlock = PHI.getIncomingBlock(I); + for (unsigned J = 0; J < NumElems; ++J) + cast<PHINode>(Res[J])->addIncoming(Op[J], IncomingBlock); + } + gather(&PHI, Res); + return true; +} + +bool Scalarizer::visitLoadInst(LoadInst &LI) { + if (!ScalarizeLoadStore) + return false; + if (!LI.isSimple()) + return false; + + VectorLayout Layout; + if (!getVectorLayout(LI.getType(), LI.getAlignment(), Layout, + LI.getModule()->getDataLayout())) + return false; + + unsigned NumElems = Layout.VecTy->getNumElements(); + IRBuilder<> Builder(&LI); + Scatterer Ptr = scatter(&LI, LI.getPointerOperand()); + ValueVector Res; + Res.resize(NumElems); + + for (unsigned I = 0; I < NumElems; ++I) + Res[I] = Builder.CreateAlignedLoad(Ptr[I], Layout.getElemAlign(I), + LI.getName() + ".i" + Twine(I)); + gather(&LI, Res); + return true; +} + +bool Scalarizer::visitStoreInst(StoreInst &SI) { + if (!ScalarizeLoadStore) + return false; + if (!SI.isSimple()) + return false; + + VectorLayout Layout; + Value *FullValue = SI.getValueOperand(); + if (!getVectorLayout(FullValue->getType(), SI.getAlignment(), Layout, + SI.getModule()->getDataLayout())) + return false; + + unsigned NumElems = Layout.VecTy->getNumElements(); + IRBuilder<> Builder(&SI); + Scatterer Ptr = scatter(&SI, SI.getPointerOperand()); + Scatterer Val = scatter(&SI, FullValue); + + ValueVector Stores; + Stores.resize(NumElems); + for (unsigned I = 0; I < NumElems; ++I) { + unsigned Align = Layout.getElemAlign(I); + Stores[I] = Builder.CreateAlignedStore(Val[I], Ptr[I], Align); + } + transferMetadata(&SI, Stores); + return true; +} + +bool Scalarizer::visitCallInst(CallInst &CI) { + return splitCall(CI); +} + +// Delete the instructions that we scalarized. 
If a full vector result +// is still needed, recreate it using InsertElements. +bool Scalarizer::finish() { + // The presence of data in Gathered or Scattered indicates changes + // made to the Function. + if (Gathered.empty() && Scattered.empty()) + return false; + for (const auto &GMI : Gathered) { + Instruction *Op = GMI.first; + ValueVector &CV = *GMI.second; + if (!Op->use_empty()) { + // The value is still needed, so recreate it using a series of + // InsertElements. + Type *Ty = Op->getType(); + Value *Res = UndefValue::get(Ty); + BasicBlock *BB = Op->getParent(); + unsigned Count = Ty->getVectorNumElements(); + IRBuilder<> Builder(Op); + if (isa<PHINode>(Op)) + Builder.SetInsertPoint(BB, BB->getFirstInsertionPt()); + for (unsigned I = 0; I < Count; ++I) + Res = Builder.CreateInsertElement(Res, CV[I], Builder.getInt32(I), + Op->getName() + ".upto" + Twine(I)); + Res->takeName(Op); + Op->replaceAllUsesWith(Res); + } + Op->eraseFromParent(); + } + Gathered.clear(); + Scattered.clear(); + return true; +} + +FunctionPass *llvm::createScalarizerPass() { + return new Scalarizer(); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp b/contrib/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp new file mode 100644 index 000000000000..4a96e0ddca16 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/SeparateConstOffsetFromGEP.cpp @@ -0,0 +1,1310 @@ +//===- SeparateConstOffsetFromGEP.cpp -------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// Loop unrolling may create many similar GEPs for array accesses. +// e.g., a 2-level loop +// +// float a[32][32]; // global variable +// +// for (int i = 0; i < 2; ++i) { +// for (int j = 0; j < 2; ++j) { +// ... +// ... = a[x + i][y + j]; +// ... +// } +// } +// +// will probably be unrolled to: +// +// gep %a, 0, %x, %y; load +// gep %a, 0, %x, %y + 1; load +// gep %a, 0, %x + 1, %y; load +// gep %a, 0, %x + 1, %y + 1; load +// +// LLVM's GVN does not use partial redundancy elimination yet, and is thus +// unable to reuse (gep %a, 0, %x, %y). As a result, this misoptimization incurs +// significant slowdown in targets with limited addressing modes. For instance, +// because the PTX target does not support the reg+reg addressing mode, the +// NVPTX backend emits PTX code that literally computes the pointer address of +// each GEP, wasting tons of registers. It emits the following PTX for the +// first load and similar PTX for other loads. +// +// mov.u32 %r1, %x; +// mov.u32 %r2, %y; +// mul.wide.u32 %rl2, %r1, 128; +// mov.u64 %rl3, a; +// add.s64 %rl4, %rl3, %rl2; +// mul.wide.u32 %rl5, %r2, 4; +// add.s64 %rl6, %rl4, %rl5; +// ld.global.f32 %f1, [%rl6]; +// +// To reduce the register pressure, the optimization implemented in this file +// merges the common part of a group of GEPs, so we can compute each pointer +// address by adding a simple offset to the common part, saving many registers. +// +// It works by splitting each GEP into a variadic base and a constant offset. +// The variadic base can be computed once and reused by multiple GEPs, and the +// constant offsets can be nicely folded into the reg+immediate addressing mode +// (supported by most targets) without using any extra register. 
+//
+// For instance, we transform the four GEPs and four loads in the above example
+// into:
+//
+// base = gep a, 0, x, y
+// load base
+// load base + 1 * sizeof(float)
+// load base + 32 * sizeof(float)
+// load base + 33 * sizeof(float)
+//
+// Given the transformed IR, a backend that supports the reg+immediate
+// addressing mode can easily fold the pointer arithmetics into the loads. For
+// example, the NVPTX backend can easily fold the pointer arithmetics into the
+// ld.global.f32 instructions, and the resultant PTX uses far fewer registers.
+//
+// mov.u32         %r1, %tid.x;
+// mov.u32         %r2, %tid.y;
+// mul.wide.u32    %rl2, %r1, 128;
+// mov.u64         %rl3, a;
+// add.s64         %rl4, %rl3, %rl2;
+// mul.wide.u32    %rl5, %r2, 4;
+// add.s64         %rl6, %rl4, %rl5;
+// ld.global.f32   %f1, [%rl6]; // so far the same as unoptimized PTX
+// ld.global.f32   %f2, [%rl6+4]; // much better
+// ld.global.f32   %f3, [%rl6+128]; // much better
+// ld.global.f32   %f4, [%rl6+132]; // much better
+//
+// Another improvement enabled by the LowerGEP flag is to lower a GEP with
+// multiple indices to either multiple GEPs with a single index or arithmetic
+// operations (depending on whether the target uses alias analysis in codegen).
+// Such a transformation can have the following benefits:
+// (1) It can always extract constants in the indices of structure type.
+// (2) After such lowering, there are more optimization opportunities such as
+//     CSE, LICM and CGP.
+//
+// E.g., the following GEPs have multiple indices:
+//  BB1:
+//    %p = getelementptr [10 x %struct]* %ptr, i64 %i, i64 %j1, i32 3
+//    load %p
+//    ...
+//  BB2:
+//    %p2 = getelementptr [10 x %struct]* %ptr, i64 %i, i64 %j2, i32 2
+//    load %p2
+//    ...
+//
+// We cannot apply CSE to the common part related to index "i64 %i". Lowering
+// GEPs can achieve this goal.
+// If the target does not use alias analysis in codegen, this pass will
+// lower a GEP with multiple indices into arithmetic operations:
+//  BB1:
+//    %1 = ptrtoint [10 x %struct]* %ptr to i64    ; CSE opportunity
+//    %2 = mul i64 %i, length_of_10xstruct         ; CSE opportunity
+//    %3 = add i64 %1, %2                          ; CSE opportunity
+//    %4 = mul i64 %j1, length_of_struct
+//    %5 = add i64 %3, %4
+//    %6 = add i64 %5, struct_field_3              ; Constant offset
+//    %p = inttoptr i64 %6 to i32*
+//    load %p
+//    ...
+//  BB2:
+//    %7 = ptrtoint [10 x %struct]* %ptr to i64    ; CSE opportunity
+//    %8 = mul i64 %i, length_of_10xstruct         ; CSE opportunity
+//    %9 = add i64 %7, %8                          ; CSE opportunity
+//    %10 = mul i64 %j2, length_of_struct
+//    %11 = add i64 %9, %10
+//    %12 = add i64 %11, struct_field_2            ; Constant offset
+//    %p2 = inttoptr i64 %12 to i32*
+//    load %p2
+//    ...
+//
+// If the target uses alias analysis in codegen, this pass will lower a GEP
+// with multiple indices into multiple GEPs with a single index:
+//  BB1:
+//    %1 = bitcast [10 x %struct]* %ptr to i8*     ; CSE opportunity
+//    %2 = mul i64 %i, length_of_10xstruct         ; CSE opportunity
+//    %3 = getelementptr i8* %1, i64 %2            ; CSE opportunity
+//    %4 = mul i64 %j1, length_of_struct
+//    %5 = getelementptr i8* %3, i64 %4
+//    %6 = getelementptr i8* %5, struct_field_3    ; Constant offset
+//    %p = bitcast i8* %6 to i32*
+//    load %p
+//    ...
+//  BB2:
+//    %7 = bitcast [10 x %struct]* %ptr to i8*     ; CSE opportunity
+//    %8 = mul i64 %i, length_of_10xstruct         ; CSE opportunity
+//    %9 = getelementptr i8* %7, i64 %8            ; CSE opportunity
+//    %10 = mul i64 %j2, length_of_struct
+//    %11 = getelementptr i8* %9, i64 %10
+//    %12 = getelementptr i8* %11, struct_field_2  ; Constant offset
+//    %p2 = bitcast i8* %12 to i32*
+//    load %p2
+//    ...
+//
+// Lowering GEPs can also benefit other passes such as LICM and CGP.
+// LICM (Loop Invariant Code Motion) cannot hoist/sink a GEP with multiple
+// indices if one of the indices is variant. If we lower such a GEP into
+// invariant parts and variant parts, LICM can hoist/sink those invariant
+// parts.
+// CGP (CodeGen Prepare) tries to sink address calculations that match the
+// target's addressing modes. A GEP with multiple indices may not match and will
+// not be sunk. If we lower such a GEP into smaller parts, CGP may sink some of
+// them. So we end up with a better addressing mode.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/APInt.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DepthFirstIterator.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/MemoryBuiltins.h"
+#include "llvm/Analysis/ScalarEvolution.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/TargetSubtargetInfo.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constant.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/GetElementPtrTypeIterator.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instruction.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/IR/Type.h"
+#include "llvm/IR/User.h"
+#include "llvm/IR/Value.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/Casting.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <cassert>
+#include <cstdint>
+#include <string>
+
+using namespace llvm;
+using namespace llvm::PatternMatch;
+
+static cl::opt<bool> DisableSeparateConstOffsetFromGEP(
+    "disable-separate-const-offset-from-gep", cl::init(false),
+    cl::desc("Do not separate the constant offset from a GEP instruction"),
+    cl::Hidden);
+
+// Setting this flag may cause false positives when the input module already
+// contains dead instructions. Therefore, we set it only in unit tests that are
+// free of dead code.
+static cl::opt<bool>
+    VerifyNoDeadCode("reassociate-geps-verify-no-dead-code", cl::init(false),
+                     cl::desc("Verify this pass produces no dead code"),
+                     cl::Hidden);
+
+namespace {
+
+/// \brief A helper class for separating a constant offset from a GEP index.
+///
+/// In real programs, a GEP index may be more complicated than a simple addition
+/// of something and a constant integer which can be trivially split. For
+/// example, to split ((a << 3) | 5) + b, we need to search deeper for the
+/// constant offset, so that we can separate the index to (a << 3) + b and 5.
+///
+/// Therefore, this class looks into the expression that computes a given GEP
+/// index, and tries to find a constant integer that can be hoisted to the
+/// outermost level of the expression as an addition. Not every constant in an
+/// expression can jump out. e.g., we cannot transform (b * (a + 5)) to (b * a +
+/// 5); nor can we transform (3 * (a + 5)) to (3 * a + 5); however, in this
+/// case, -instcombine probably already optimized (3 * (a + 5)) to (3 * a + 15).
+class ConstantOffsetExtractor {
+public:
+  /// Extracts a constant offset from the given GEP index. It returns the
+  /// new index representing the remainder (equal to the original index minus
+  /// the constant offset), or nullptr if we cannot extract a constant offset.
+  /// \p Idx           The given GEP index
+  /// \p GEP           The given GEP
+  /// \p UserChainTail Outputs the tail of UserChain so that we can
+  ///                  garbage-collect unused instructions in UserChain.
+  static Value *Extract(Value *Idx, GetElementPtrInst *GEP,
+                        User *&UserChainTail, const DominatorTree *DT);
+
+  /// Looks for a constant offset from the given GEP index without extracting
+  /// it. It returns the numeric value of the extracted constant offset (0 on
+  /// failure). The meanings of the arguments are the same as in Extract.
+  static int64_t Find(Value *Idx, GetElementPtrInst *GEP,
+                      const DominatorTree *DT);
+
+private:
+  ConstantOffsetExtractor(Instruction *InsertionPt, const DominatorTree *DT)
+      : IP(InsertionPt), DL(InsertionPt->getModule()->getDataLayout()), DT(DT) {
+  }
+
+  /// Searches the expression that computes V for a non-zero constant C s.t.
+  /// V can be reassociated into the form V' + C. If the search is
+  /// successful, returns C and updates UserChain as a def-use chain from C to
+  /// V; otherwise, UserChain is empty.
+  ///
+  /// \p V            The given expression
+  /// \p SignExtended Whether V will be sign-extended in the computation of the
+  ///                 GEP index
+  /// \p ZeroExtended Whether V will be zero-extended in the computation of the
+  ///                 GEP index
+  /// \p NonNegative  Whether V is guaranteed to be non-negative. For example,
+  ///                 an index of an inbounds GEP is guaranteed to be
+  ///                 non-negative. Leveraging this, we can better split
+  ///                 inbounds GEPs.
+  APInt find(Value *V, bool SignExtended, bool ZeroExtended, bool NonNegative);
+
+  /// A helper function to look into both operands of a binary operator.
+  APInt findInEitherOperand(BinaryOperator *BO, bool SignExtended,
+                            bool ZeroExtended);
+
+  /// After finding the constant offset C from the GEP index I, we build a new
+  /// index I' s.t. I' + C = I. This function builds and returns the new
+  /// index I' according to UserChain produced by function "find".
+  ///
+  /// The building conceptually takes two steps:
+  /// 1) iteratively distribute s/zext towards the leaves of the expression
+  ///    tree that computes I
+  /// 2) reassociate the expression tree to the form I' + C.
+  ///
+  /// For example, to extract the 5 from sext(a + (b + 5)), we first distribute
+  /// sext to a, b and 5 so that we have
+  ///   sext(a) + (sext(b) + 5).
+  /// Then, we reassociate it to
+  ///   (sext(a) + sext(b)) + 5.
+  /// Given this form, we know I' is sext(a) + sext(b).
+  Value *rebuildWithoutConstOffset();
+
+  /// After the first step of rebuilding the GEP index without the constant
+  /// offset, distribute s/zext to the operands of all operators in UserChain.
+  /// e.g., zext(sext(a + (b + 5))) (assuming no overflow) =>
+  /// zext(sext(a)) + (zext(sext(b)) + zext(sext(5))).
+  ///
+  /// The function also updates UserChain to point to new subexpressions after
+  /// distributing s/zext. e.g., the old UserChain of the above example is
+  /// 5 -> b + 5 -> a + (b + 5) -> sext(...) -> zext(sext(...)),
+  /// and the new UserChain is
+  /// zext(sext(5)) -> zext(sext(b)) + zext(sext(5)) ->
+  ///   zext(sext(a)) + (zext(sext(b)) + zext(sext(5)))
+  ///
+  /// \p ChainIndex The index to UserChain. ChainIndex is initially
+  ///               UserChain.size() - 1, and is decremented during
+  ///               the recursion.
+  Value *distributeExtsAndCloneChain(unsigned ChainIndex);
+
+  /// Reassociates the GEP index to the form I' + C and returns I'.
+  Value *removeConstOffset(unsigned ChainIndex);
+
+  /// A helper function to apply ExtInsts, a list of s/zext, to value V.
+  /// e.g., if ExtInsts = [sext i32 to i64, zext i16 to i32], this function
+  /// returns "sext i32 (zext i16 V to i32) to i64".
+  Value *applyExts(Value *V);
+
+  /// A helper function that returns whether we can trace into the operands
+  /// of binary operator BO for a constant offset.
+  ///
+  /// \p SignExtended Whether BO is surrounded by sext
+  /// \p ZeroExtended Whether BO is surrounded by zext
+  /// \p NonNegative  Whether BO is known to be non-negative, e.g., an inbounds
+  ///                 array index.
+  bool CanTraceInto(bool SignExtended, bool ZeroExtended, BinaryOperator *BO,
+                    bool NonNegative);
+
+  /// The path from the constant offset to the old GEP index. e.g., if the GEP
+  /// index is "a * b + (c + 5)", then after running function find, UserChain[0]
+  /// will be the constant 5, UserChain[1] will be the subexpression "c + 5",
+  /// and UserChain[2] will be the entire expression "a * b + (c + 5)".
+  ///
+  /// This path helps to rebuild the new GEP index.
+  SmallVector<User *, 8> UserChain;
+
+  /// A data structure used in rebuildWithoutConstOffset. Contains all
+  /// sext/zext instructions along UserChain.
+  SmallVector<CastInst *, 16> ExtInsts;
+
+  /// Insertion position of cloned instructions.
+  Instruction *IP;
+
+  const DataLayout &DL;
+  const DominatorTree *DT;
+};
+
+/// \brief A pass that tries to split every GEP in the function into a variadic
+/// base and a constant offset. It is a FunctionPass because searching for the
+/// constant offset may inspect other basic blocks.
+class SeparateConstOffsetFromGEP : public FunctionPass {
+public:
+  static char ID;
+
+  SeparateConstOffsetFromGEP(const TargetMachine *TM = nullptr,
+                             bool LowerGEP = false)
+      : FunctionPass(ID), TM(TM), LowerGEP(LowerGEP) {
+    initializeSeparateConstOffsetFromGEPPass(*PassRegistry::getPassRegistry());
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addRequired<ScalarEvolutionWrapperPass>();
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.setPreservesCFG();
+    AU.addRequired<TargetLibraryInfoWrapperPass>();
+  }
+
+  bool doInitialization(Module &M) override {
+    DL = &M.getDataLayout();
+    return false;
+  }
+
+  bool runOnFunction(Function &F) override;
+
+private:
+  /// Tries to split the given GEP into a variadic base and a constant offset,
+  /// and returns true if the splitting succeeds.
+  bool splitGEP(GetElementPtrInst *GEP);
+
+  /// Lower a GEP with multiple indices into multiple GEPs with a single index.
+  /// Function splitGEP already split the original GEP into a variadic part and
+  /// a constant offset (i.e., AccumulativeByteOffset).
+  /// This function lowers the variadic part into a set of GEPs with a single
+  /// index and applies AccumulativeByteOffset to it.
+  /// \p Variadic               The variadic part of the original GEP.
+  /// \p AccumulativeByteOffset The constant offset.
+  void lowerToSingleIndexGEPs(GetElementPtrInst *Variadic,
+                              int64_t AccumulativeByteOffset);
+
+  /// Lower a GEP with multiple indices into ptrtoint+arithmetic+inttoptr form.
+  /// Function splitGEP already split the original GEP into a variadic part and
+  /// a constant offset (i.e., AccumulativeByteOffset). This function lowers the
+  /// variadic part into a set of arithmetic operations and applies
+  /// AccumulativeByteOffset to it.
+  /// \p Variadic               The variadic part of the original GEP.
+  /// \p AccumulativeByteOffset The constant offset.
+  void lowerToArithmetics(GetElementPtrInst *Variadic,
+                          int64_t AccumulativeByteOffset);
+
+  /// Finds the constant offset within each index and accumulates them. If
+  /// LowerGEP is true, it looks in indices of both sequential and structure
+  /// types; otherwise it only looks in sequential indices. The output
+  /// NeedsExtraction indicates whether we found a non-zero constant offset.
+  int64_t accumulateByteOffset(GetElementPtrInst *GEP, bool &NeedsExtraction);
+
+  /// Canonicalize array indices to pointer-size integers. This helps to
+  /// simplify the logic of splitting a GEP. For example, if a + b is a
+  /// pointer-size integer, we have
+  ///   gep base, a + b = gep (gep base, a), b
+  /// However, this equality may not hold if the size of a + b is smaller than
+  /// the pointer size, because LLVM conceptually sign-extends GEP indices to
+  /// pointer size before computing the address
+  /// (http://llvm.org/docs/LangRef.html#id181).
+  ///
+  /// This canonicalization is very likely already done in clang and
+  /// instcombine. Therefore, the program will probably remain the same.
+  ///
+  /// Returns true if the module changes.
+  ///
+  /// Verified in @i32_add in split-gep.ll
+  bool canonicalizeArrayIndicesToPointerSize(GetElementPtrInst *GEP);
+
+  /// Optimize sext(a)+sext(b) to sext(a+b) when a+b can't sign overflow.
+  /// SeparateConstOffsetFromGEP distributes a sext to leaves before extracting
+  /// the constant offset. After extraction, it becomes desirable to reunite
+  /// the distributed sexts. For example,
+  ///
+  ///                              &a[sext(i +nsw (j +nsw 5))]
+  ///   => distribute              &a[sext(i) +nsw (sext(j) +nsw 5)]
+  ///   => constant extraction     &a[sext(i) + sext(j)] + 5
+  ///   => reunite                 &a[sext(i +nsw j)] + 5
+  bool reuniteExts(Function &F);
+
+  /// A helper that reunites sexts in an instruction.
+  bool reuniteExts(Instruction *I);
+
+  /// Find the closest dominator of <Dominatee> that is equivalent to <Key>.
+  Instruction *findClosestMatchingDominator(const SCEV *Key,
+                                            Instruction *Dominatee);
+
+  /// Verify F is free of dead code.
+  void verifyNoDeadCode(Function &F);
+
+  bool hasMoreThanOneUseInLoop(Value *v, Loop *L);
+
+  // Swap the index operands of two GEPs.
+  void swapGEPOperand(GetElementPtrInst *First, GetElementPtrInst *Second);
+
+  // Check if it is safe to swap the operands of two GEPs.
+  bool isLegalToSwapOperand(GetElementPtrInst *First, GetElementPtrInst *Second,
+                            Loop *CurLoop);
+
+  const DataLayout *DL = nullptr;
+  DominatorTree *DT = nullptr;
+  ScalarEvolution *SE;
+  const TargetMachine *TM;
+
+  LoopInfo *LI;
+  TargetLibraryInfo *TLI;
+
+  /// Whether to lower a GEP with multiple indices into arithmetic operations
+  /// or multiple GEPs with a single index.
+  bool LowerGEP;
+
+  DenseMap<const SCEV *, SmallVector<Instruction *, 2>> DominatingExprs;
+};
+
+} // end anonymous namespace
+
+char SeparateConstOffsetFromGEP::ID = 0;
+
+INITIALIZE_PASS_BEGIN(
+    SeparateConstOffsetFromGEP, "separate-const-offset-from-gep",
+    "Split GEPs to a variadic base and a constant offset for better CSE", false,
+    false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
+INITIALIZE_PASS_END(
+    SeparateConstOffsetFromGEP, "separate-const-offset-from-gep",
+    "Split GEPs to a variadic base and a constant offset for better CSE", false,
+    false)
+
+FunctionPass *
+llvm::createSeparateConstOffsetFromGEPPass(const TargetMachine *TM,
+                                           bool LowerGEP) {
+  return new SeparateConstOffsetFromGEP(TM, LowerGEP);
+}
+
+bool ConstantOffsetExtractor::CanTraceInto(bool SignExtended,
+                                           bool ZeroExtended,
+                                           BinaryOperator *BO,
+                                           bool NonNegative) {
+  // We only consider ADD, SUB and OR, because a non-zero constant found in
+  // expressions composed of these operations can be easily hoisted as a
+  // constant offset by reassociation.
+  if (BO->getOpcode() != Instruction::Add &&
+      BO->getOpcode() != Instruction::Sub &&
+      BO->getOpcode() != Instruction::Or) {
+    return false;
+  }
+
+  Value *LHS = BO->getOperand(0), *RHS = BO->getOperand(1);
+  // Do not trace into "or" unless it is equivalent to "add". If LHS and RHS
+  // don't have common bits, (LHS | RHS) is equivalent to (LHS + RHS).
+  if (BO->getOpcode() == Instruction::Or &&
+      !haveNoCommonBitsSet(LHS, RHS, DL, nullptr, BO, DT))
+    return false;
+
+  // In addition, tracing into BO requires that its surrounding s/zext (if
+  // any) is distributable to both operands.
+  //
+  // Suppose BO = A op B.
+  //  SignExtended | ZeroExtended | Distributable?
+  // --------------+--------------+----------------------------------
+  //       0       |      0       | true because no s/zext exists
+  //       0       |      1       | zext(BO) == zext(A) op zext(B)
+  //       1       |      0       | sext(BO) == sext(A) op sext(B)
+  //       1       |      1       | zext(sext(BO)) ==
+  //               |              |   zext(sext(A)) op zext(sext(B))
+  if (BO->getOpcode() == Instruction::Add && !ZeroExtended && NonNegative) {
+    // If a + b >= 0 and (a >= 0 or b >= 0), then
+    //   sext(a + b) = sext(a) + sext(b)
+    // even if the addition is not marked nsw.
+    //
+    // Leveraging this invariant, we can trace into an sext'ed inbounds GEP
+    // index if the constant offset is non-negative.
+    //
+    // Verified in @sext_add in split-gep.ll.
+    if (ConstantInt *ConstLHS = dyn_cast<ConstantInt>(LHS)) {
+      if (!ConstLHS->isNegative())
+        return true;
+    }
+    if (ConstantInt *ConstRHS = dyn_cast<ConstantInt>(RHS)) {
+      if (!ConstRHS->isNegative())
+        return true;
+    }
+  }
+
+  // sext (add/sub nsw A, B) == add/sub nsw (sext A), (sext B)
+  // zext (add/sub nuw A, B) == add/sub nuw (zext A), (zext B)
+  if (BO->getOpcode() == Instruction::Add ||
+      BO->getOpcode() == Instruction::Sub) {
+    if (SignExtended && !BO->hasNoSignedWrap())
+      return false;
+    if (ZeroExtended && !BO->hasNoUnsignedWrap())
+      return false;
+  }
+
+  return true;
+}
+
+APInt ConstantOffsetExtractor::findInEitherOperand(BinaryOperator *BO,
+                                                   bool SignExtended,
+                                                   bool ZeroExtended) {
+  // BO being non-negative does not shed light on whether its operands are
+  // non-negative. Clear the NonNegative flag here.
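+  // Editor's example (illustrative): with a = -5 and b = 6, a + b = 1 is
+  // non-negative even though a is negative, so NonNegative must not be
+  // propagated to the operands.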
+ APInt ConstantOffset = find(BO->getOperand(0), SignExtended, ZeroExtended, + /* NonNegative */ false); + // If we found a constant offset in the left operand, stop and return that. + // This shortcut might cause us to miss opportunities of combining the + // constant offsets in both operands, e.g., (a + 4) + (b + 5) => (a + b) + 9. + // However, such cases are probably already handled by -instcombine, + // given this pass runs after the standard optimizations. + if (ConstantOffset != 0) return ConstantOffset; + ConstantOffset = find(BO->getOperand(1), SignExtended, ZeroExtended, + /* NonNegative */ false); + // If U is a sub operator, negate the constant offset found in the right + // operand. + if (BO->getOpcode() == Instruction::Sub) + ConstantOffset = -ConstantOffset; + return ConstantOffset; +} + +APInt ConstantOffsetExtractor::find(Value *V, bool SignExtended, + bool ZeroExtended, bool NonNegative) { + // TODO(jingyue): We could trace into integer/pointer casts, such as + // inttoptr, ptrtoint, bitcast, and addrspacecast. We choose to handle only + // integers because it gives good enough results for our benchmarks. + unsigned BitWidth = cast<IntegerType>(V->getType())->getBitWidth(); + + // We cannot do much with Values that are not a User, such as an Argument. + User *U = dyn_cast<User>(V); + if (U == nullptr) return APInt(BitWidth, 0); + + APInt ConstantOffset(BitWidth, 0); + if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) { + // Hooray, we found it! + ConstantOffset = CI->getValue(); + } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(V)) { + // Trace into subexpressions for more hoisting opportunities. + if (CanTraceInto(SignExtended, ZeroExtended, BO, NonNegative)) + ConstantOffset = findInEitherOperand(BO, SignExtended, ZeroExtended); + } else if (isa<SExtInst>(V)) { + ConstantOffset = find(U->getOperand(0), /* SignExtended */ true, + ZeroExtended, NonNegative).sext(BitWidth); + } else if (isa<ZExtInst>(V)) { + // As an optimization, we can clear the SignExtended flag because + // sext(zext(a)) = zext(a). Verified in @sext_zext in split-gep.ll. + // + // Clear the NonNegative flag, because zext(a) >= 0 does not imply a >= 0. + ConstantOffset = + find(U->getOperand(0), /* SignExtended */ false, + /* ZeroExtended */ true, /* NonNegative */ false).zext(BitWidth); + } + + // If we found a non-zero constant offset, add it to the path for + // rebuildWithoutConstOffset. Zero is a valid constant offset, but doesn't + // help this optimization. + if (ConstantOffset != 0) + UserChain.push_back(U); + return ConstantOffset; +} + +Value *ConstantOffsetExtractor::applyExts(Value *V) { + Value *Current = V; + // ExtInsts is built in the use-def order. Therefore, we apply them to V + // in the reversed order. + for (auto I = ExtInsts.rbegin(), E = ExtInsts.rend(); I != E; ++I) { + if (Constant *C = dyn_cast<Constant>(Current)) { + // If Current is a constant, apply s/zext using ConstantExpr::getCast. + // ConstantExpr::getCast emits a ConstantInt if C is a ConstantInt. + Current = ConstantExpr::getCast((*I)->getOpcode(), C, (*I)->getType()); + } else { + Instruction *Ext = (*I)->clone(); + Ext->setOperand(0, Current); + Ext->insertBefore(IP); + Current = Ext; + } + } + return Current; +} + +Value *ConstantOffsetExtractor::rebuildWithoutConstOffset() { + distributeExtsAndCloneChain(UserChain.size() - 1); + // Remove all nullptrs (used to be s/zext) from UserChain. 
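+  // Editor's sketch (hypothetical chain): for zext(sext(b + 5)), UserChain
+  // was [5, b + 5, sext(...), zext(...)]; distribution replaced the two cast
+  // entries with nullptr, and the compaction below leaves the cloned,
+  // extended subexpressions [5', b' + 5'].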
+  unsigned NewSize = 0;
+  for (User *I : UserChain) {
+    if (I != nullptr) {
+      UserChain[NewSize] = I;
+      NewSize++;
+    }
+  }
+  UserChain.resize(NewSize);
+  return removeConstOffset(UserChain.size() - 1);
+}
+
+Value *
+ConstantOffsetExtractor::distributeExtsAndCloneChain(unsigned ChainIndex) {
+  User *U = UserChain[ChainIndex];
+  if (ChainIndex == 0) {
+    assert(isa<ConstantInt>(U));
+    // If U is a ConstantInt, applyExts will return a ConstantInt as well.
+    return UserChain[ChainIndex] = cast<ConstantInt>(applyExts(U));
+  }
+
+  if (CastInst *Cast = dyn_cast<CastInst>(U)) {
+    assert((isa<SExtInst>(Cast) || isa<ZExtInst>(Cast)) &&
+           "We only traced into two types of CastInst: sext and zext");
+    ExtInsts.push_back(Cast);
+    UserChain[ChainIndex] = nullptr;
+    return distributeExtsAndCloneChain(ChainIndex - 1);
+  }
+
+  // Function find only traces into BinaryOperator and CastInst.
+  BinaryOperator *BO = cast<BinaryOperator>(U);
+  // OpNo = which operand of BO is UserChain[ChainIndex - 1]
+  unsigned OpNo = (BO->getOperand(0) == UserChain[ChainIndex - 1] ? 0 : 1);
+  Value *TheOther = applyExts(BO->getOperand(1 - OpNo));
+  Value *NextInChain = distributeExtsAndCloneChain(ChainIndex - 1);
+
+  BinaryOperator *NewBO = nullptr;
+  if (OpNo == 0) {
+    NewBO = BinaryOperator::Create(BO->getOpcode(), NextInChain, TheOther,
+                                   BO->getName(), IP);
+  } else {
+    NewBO = BinaryOperator::Create(BO->getOpcode(), TheOther, NextInChain,
+                                   BO->getName(), IP);
+  }
+  return UserChain[ChainIndex] = NewBO;
+}
+
+Value *ConstantOffsetExtractor::removeConstOffset(unsigned ChainIndex) {
+  if (ChainIndex == 0) {
+    assert(isa<ConstantInt>(UserChain[ChainIndex]));
+    return ConstantInt::getNullValue(UserChain[ChainIndex]->getType());
+  }
+
+  BinaryOperator *BO = cast<BinaryOperator>(UserChain[ChainIndex]);
+  assert(BO->getNumUses() <= 1 &&
+         "distributeExtsAndCloneChain clones each BinaryOperator in "
+         "UserChain, so no one should be used more than "
+         "once");
+
+  unsigned OpNo = (BO->getOperand(0) == UserChain[ChainIndex - 1] ? 0 : 1);
+  assert(BO->getOperand(OpNo) == UserChain[ChainIndex - 1]);
+  Value *NextInChain = removeConstOffset(ChainIndex - 1);
+  Value *TheOther = BO->getOperand(1 - OpNo);
+
+  // If NextInChain is 0 and not the LHS of a sub, we can simplify the
+  // sub-expression to be just TheOther.
+  if (ConstantInt *CI = dyn_cast<ConstantInt>(NextInChain)) {
+    if (CI->isZero() && !(BO->getOpcode() == Instruction::Sub && OpNo == 0))
+      return TheOther;
+  }
+
+  BinaryOperator::BinaryOps NewOp = BO->getOpcode();
+  if (BO->getOpcode() == Instruction::Or) {
+    // Rebuild "or" as "add", because "or" may be invalid for the new
+    // expression.
+    //
+    // For instance, given
+    //   a | (b + 5) where a and b + 5 have no common bits,
+    // we can extract 5 as the constant offset.
+    //
+    // However, reusing the "or" in the new index would give us
+    //   (a | b) + 5
+    // which does not equal a | (b + 5).
+    //
+    // Replacing the "or" with "add" is fine, because
+    //   a | (b + 5) = a + (b + 5) = (a + b) + 5
+    NewOp = Instruction::Add;
+  }
+
+  BinaryOperator *NewBO;
+  if (OpNo == 0) {
+    NewBO = BinaryOperator::Create(NewOp, NextInChain, TheOther, "", IP);
+  } else {
+    NewBO = BinaryOperator::Create(NewOp, TheOther, NextInChain, "", IP);
+  }
+  NewBO->takeName(BO);
+  return NewBO;
+}
+
+Value *ConstantOffsetExtractor::Extract(Value *Idx, GetElementPtrInst *GEP,
+                                        User *&UserChainTail,
+                                        const DominatorTree *DT) {
+  ConstantOffsetExtractor Extractor(GEP, DT);
+  // Find a non-zero constant offset first.
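+  // Editor's example (hypothetical index): for Idx = "add i64 %a, 5", find()
+  // locates the constant 5, rebuildWithoutConstOffset() returns %a as the
+  // new index, and UserChainTail ends up pointing at the leftover add so the
+  // caller can garbage-collect it.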
+ APInt ConstantOffset = + Extractor.find(Idx, /* SignExtended */ false, /* ZeroExtended */ false, + GEP->isInBounds()); + if (ConstantOffset == 0) { + UserChainTail = nullptr; + return nullptr; + } + // Separates the constant offset from the GEP index. + Value *IdxWithoutConstOffset = Extractor.rebuildWithoutConstOffset(); + UserChainTail = Extractor.UserChain.back(); + return IdxWithoutConstOffset; +} + +int64_t ConstantOffsetExtractor::Find(Value *Idx, GetElementPtrInst *GEP, + const DominatorTree *DT) { + // If Idx is an index of an inbound GEP, Idx is guaranteed to be non-negative. + return ConstantOffsetExtractor(GEP, DT) + .find(Idx, /* SignExtended */ false, /* ZeroExtended */ false, + GEP->isInBounds()) + .getSExtValue(); +} + +bool SeparateConstOffsetFromGEP::canonicalizeArrayIndicesToPointerSize( + GetElementPtrInst *GEP) { + bool Changed = false; + Type *IntPtrTy = DL->getIntPtrType(GEP->getType()); + gep_type_iterator GTI = gep_type_begin(*GEP); + for (User::op_iterator I = GEP->op_begin() + 1, E = GEP->op_end(); + I != E; ++I, ++GTI) { + // Skip struct member indices which must be i32. + if (GTI.isSequential()) { + if ((*I)->getType() != IntPtrTy) { + *I = CastInst::CreateIntegerCast(*I, IntPtrTy, true, "idxprom", GEP); + Changed = true; + } + } + } + return Changed; +} + +int64_t +SeparateConstOffsetFromGEP::accumulateByteOffset(GetElementPtrInst *GEP, + bool &NeedsExtraction) { + NeedsExtraction = false; + int64_t AccumulativeByteOffset = 0; + gep_type_iterator GTI = gep_type_begin(*GEP); + for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { + if (GTI.isSequential()) { + // Tries to extract a constant offset from this GEP index. + int64_t ConstantOffset = + ConstantOffsetExtractor::Find(GEP->getOperand(I), GEP, DT); + if (ConstantOffset != 0) { + NeedsExtraction = true; + // A GEP may have multiple indices. We accumulate the extracted + // constant offset to a byte offset, and later offset the remainder of + // the original GEP with this byte offset. + AccumulativeByteOffset += + ConstantOffset * DL->getTypeAllocSize(GTI.getIndexedType()); + } + } else if (LowerGEP) { + StructType *StTy = GTI.getStructType(); + uint64_t Field = cast<ConstantInt>(GEP->getOperand(I))->getZExtValue(); + // Skip field 0 as the offset is always 0. + if (Field != 0) { + NeedsExtraction = true; + AccumulativeByteOffset += + DL->getStructLayout(StTy)->getElementOffset(Field); + } + } + } + return AccumulativeByteOffset; +} + +void SeparateConstOffsetFromGEP::lowerToSingleIndexGEPs( + GetElementPtrInst *Variadic, int64_t AccumulativeByteOffset) { + IRBuilder<> Builder(Variadic); + Type *IntPtrTy = DL->getIntPtrType(Variadic->getType()); + + Type *I8PtrTy = + Builder.getInt8PtrTy(Variadic->getType()->getPointerAddressSpace()); + Value *ResultPtr = Variadic->getOperand(0); + Loop *L = LI->getLoopFor(Variadic->getParent()); + // Check if the base is not loop invariant or used more than once. + bool isSwapCandidate = + L && L->isLoopInvariant(ResultPtr) && + !hasMoreThanOneUseInLoop(ResultPtr, L); + Value *FirstResult = nullptr; + + if (ResultPtr->getType() != I8PtrTy) + ResultPtr = Builder.CreateBitCast(ResultPtr, I8PtrTy); + + gep_type_iterator GTI = gep_type_begin(*Variadic); + // Create an ugly GEP for each sequential index. We don't create GEPs for + // structure indices, as they are accumulated in the constant offset index. 
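+  // Editor's sketch (illustrative, hypothetical IR): for a variadic part
+  //   getelementptr [10 x float], [10 x float]* %a, i64 %i, i64 %j
+  // the loop below emits roughly (with %base = bitcast of %a to i8*)
+  //   %0        = mul i64 %i, 40       ; scale by sizeof([10 x float])
+  //   %uglygep  = getelementptr i8, i8* %base, i64 %0
+  //   %1        = shl i64 %j, 2        ; scale by sizeof(float)
+  //   %uglygep1 = getelementptr i8, i8* %uglygep, i64 %1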
+  for (unsigned I = 1, E = Variadic->getNumOperands(); I != E; ++I, ++GTI) {
+    if (GTI.isSequential()) {
+      Value *Idx = Variadic->getOperand(I);
+      // Skip zero indices.
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Idx))
+        if (CI->isZero())
+          continue;
+
+      APInt ElementSize = APInt(IntPtrTy->getIntegerBitWidth(),
+                                DL->getTypeAllocSize(GTI.getIndexedType()));
+      // Scale the index by element size.
+      if (ElementSize != 1) {
+        if (ElementSize.isPowerOf2()) {
+          Idx = Builder.CreateShl(
+              Idx, ConstantInt::get(IntPtrTy, ElementSize.logBase2()));
+        } else {
+          Idx = Builder.CreateMul(Idx, ConstantInt::get(IntPtrTy, ElementSize));
+        }
+      }
+      // Create an ugly GEP with a single index for each index.
+      ResultPtr =
+          Builder.CreateGEP(Builder.getInt8Ty(), ResultPtr, Idx, "uglygep");
+      if (FirstResult == nullptr)
+        FirstResult = ResultPtr;
+    }
+  }
+
+  // Create a GEP with the constant offset index.
+  if (AccumulativeByteOffset != 0) {
+    Value *Offset = ConstantInt::get(IntPtrTy, AccumulativeByteOffset);
+    ResultPtr =
+        Builder.CreateGEP(Builder.getInt8Ty(), ResultPtr, Offset, "uglygep");
+  } else
+    isSwapCandidate = false;
+
+  // If we created a GEP with a constant index, and the base is loop invariant,
+  // then we swap the first one with it, so LICM can move the constant GEP out
+  // later.
+  GetElementPtrInst *FirstGEP = dyn_cast_or_null<GetElementPtrInst>(FirstResult);
+  GetElementPtrInst *SecondGEP = dyn_cast_or_null<GetElementPtrInst>(ResultPtr);
+  if (isSwapCandidate && isLegalToSwapOperand(FirstGEP, SecondGEP, L))
+    swapGEPOperand(FirstGEP, SecondGEP);
+
+  if (ResultPtr->getType() != Variadic->getType())
+    ResultPtr = Builder.CreateBitCast(ResultPtr, Variadic->getType());
+
+  Variadic->replaceAllUsesWith(ResultPtr);
+  Variadic->eraseFromParent();
+}
+
+void
+SeparateConstOffsetFromGEP::lowerToArithmetics(GetElementPtrInst *Variadic,
+                                               int64_t AccumulativeByteOffset) {
+  IRBuilder<> Builder(Variadic);
+  Type *IntPtrTy = DL->getIntPtrType(Variadic->getType());
+
+  Value *ResultPtr = Builder.CreatePtrToInt(Variadic->getOperand(0), IntPtrTy);
+  gep_type_iterator GTI = gep_type_begin(*Variadic);
+  // Create ADD/SHL/MUL arithmetic operations for each sequential index. We
+  // don't create arithmetic for structure indices, as they are accumulated
+  // in the constant offset index.
+  for (unsigned I = 1, E = Variadic->getNumOperands(); I != E; ++I, ++GTI) {
+    if (GTI.isSequential()) {
+      Value *Idx = Variadic->getOperand(I);
+      // Skip zero indices.
+      if (ConstantInt *CI = dyn_cast<ConstantInt>(Idx))
+        if (CI->isZero())
+          continue;
+
+      APInt ElementSize = APInt(IntPtrTy->getIntegerBitWidth(),
+                                DL->getTypeAllocSize(GTI.getIndexedType()));
+      // Scale the index by element size.
+      if (ElementSize != 1) {
+        if (ElementSize.isPowerOf2()) {
+          Idx = Builder.CreateShl(
+              Idx, ConstantInt::get(IntPtrTy, ElementSize.logBase2()));
+        } else {
+          Idx = Builder.CreateMul(Idx, ConstantInt::get(IntPtrTy, ElementSize));
+        }
+      }
+      // Create an ADD for each index.
+      ResultPtr = Builder.CreateAdd(ResultPtr, Idx);
+    }
+  }
+
+  // Create an ADD for the constant offset index.
+  if (AccumulativeByteOffset != 0) {
+    ResultPtr = Builder.CreateAdd(
+        ResultPtr, ConstantInt::get(IntPtrTy, AccumulativeByteOffset));
+  }
+
+  ResultPtr = Builder.CreateIntToPtr(ResultPtr, Variadic->getType());
+  Variadic->replaceAllUsesWith(ResultPtr);
+  Variadic->eraseFromParent();
+}
+
+bool SeparateConstOffsetFromGEP::splitGEP(GetElementPtrInst *GEP) {
+  // Skip vector GEPs.
+  if (GEP->getType()->isVectorTy())
+    return false;
+
+  // The backend can already nicely handle the case where all indices are
+  // constant.
+  if (GEP->hasAllConstantIndices())
+    return false;
+
+  bool Changed = canonicalizeArrayIndicesToPointerSize(GEP);
+
+  bool NeedsExtraction;
+  int64_t AccumulativeByteOffset = accumulateByteOffset(GEP, NeedsExtraction);
+
+  if (!NeedsExtraction)
+    return Changed;
+
+  // If LowerGEP is disabled, before really splitting the GEP, check whether the
+  // backend supports the addressing mode we are about to produce. If not, this
+  // splitting probably won't be beneficial.
+  // If LowerGEP is enabled, even if the extracted constant offset cannot match
+  // the addressing mode, we can still do optimizations to other lowered parts
+  // of variable indices. Therefore, we don't check for addressing modes in that
+  // case.
+  if (!LowerGEP) {
+    TargetTransformInfo &TTI =
+        getAnalysis<TargetTransformInfoWrapperPass>().getTTI(
+            *GEP->getParent()->getParent());
+    unsigned AddrSpace = GEP->getPointerAddressSpace();
+    if (!TTI.isLegalAddressingMode(GEP->getResultElementType(),
+                                   /*BaseGV=*/nullptr, AccumulativeByteOffset,
+                                   /*HasBaseReg=*/true, /*Scale=*/0,
+                                   AddrSpace)) {
+      return Changed;
+    }
+  }
+
+  // Remove the constant offset in each sequential index. The resultant GEP
+  // computes the variadic base.
+  // Notice that we don't remove struct field indices here. If LowerGEP is
+  // disabled, a structure index is not accumulated and we still use the old
+  // one. If LowerGEP is enabled, a structure index is accumulated in the
+  // constant offset. lowerToSingleIndexGEPs or lowerToArithmetics will later
+  // handle the constant offset and won't need a new structure index.
+  gep_type_iterator GTI = gep_type_begin(*GEP);
+  for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) {
+    if (GTI.isSequential()) {
+      // Splits this GEP index into a variadic part and a constant offset, and
+      // uses the variadic part as the new index.
+      Value *OldIdx = GEP->getOperand(I);
+      User *UserChainTail;
+      Value *NewIdx =
+          ConstantOffsetExtractor::Extract(OldIdx, GEP, UserChainTail, DT);
+      if (NewIdx != nullptr) {
+        // Switches to the index with the constant offset removed.
+        GEP->setOperand(I, NewIdx);
+        // After switching to the new index, we can garbage-collect UserChain
+        // and the old index if they are not used.
+        RecursivelyDeleteTriviallyDeadInstructions(UserChainTail);
+        RecursivelyDeleteTriviallyDeadInstructions(OldIdx);
+      }
+    }
+  }
+
+  // Clear the inbounds attribute because the new index may be out of bounds.
+  // e.g.,
+  //
+  //   b     = add i64 a, 5
+  //   addr  = gep inbounds float, float* p, i64 b
+  //
+  // is transformed to:
+  //
+  //   addr2 = gep float, float* p, i64 a ; inbounds removed
+  //   addr  = gep inbounds float, float* addr2, i64 5
+  //
+  // If a is -4, although the old index b is in bounds, the new index a is
+  // out of bounds. http://llvm.org/docs/LangRef.html#id181 says "if the
+  // inbounds keyword is not present, the offsets are added to the base
+  // address with silently-wrapping two's complement arithmetic".
+  // Therefore, the final code will be semantically equivalent.
+  //
+  // TODO(jingyue): do some range analysis to keep as many inbounds as
+  // possible. GEPs with inbounds are more friendly to alias analysis.
+  bool GEPWasInBounds = GEP->isInBounds();
+  GEP->setIsInBounds(false);
+
+  // Lowers a GEP to either GEPs with a single index or arithmetic operations.
+  if (LowerGEP) {
+    // As currently BasicAA does not analyze ptrtoint/inttoptr, do not lower to
+    // arithmetic operations if the target uses alias analysis in codegen.
+    if (TM && TM->getSubtargetImpl(*GEP->getParent()->getParent())->useAA())
+      lowerToSingleIndexGEPs(GEP, AccumulativeByteOffset);
+    else
+      lowerToArithmetics(GEP, AccumulativeByteOffset);
+    return true;
+  }
+
+  // No need to create another GEP if the accumulative byte offset is 0.
+  if (AccumulativeByteOffset == 0)
+    return true;
+
+  // Offsets the base with the accumulative byte offset.
+  //
+  //   %gep                        ; the base
+  //   ... %gep ...
+  //
+  // => add the offset
+  //
+  //   %gep2                       ; clone of %gep
+  //   %new.gep = gep %gep2, <offset / sizeof(*%gep)>
+  //   %gep                        ; will be removed
+  //   ... %gep ...
+  //
+  // => replace all uses of %gep with %new.gep and remove %gep
+  //
+  //   %gep2                       ; clone of %gep
+  //   %new.gep = gep %gep2, <offset / sizeof(*%gep)>
+  //   ... %new.gep ...
+  //
+  // If AccumulativeByteOffset is not a multiple of sizeof(*%gep), we emit an
+  // uglygep (http://llvm.org/docs/GetElementPtr.html#what-s-an-uglygep):
+  // bitcast %gep2 to i8*, add the offset, and bitcast the result back to the
+  // type of %gep.
+  //
+  //   %gep2                       ; clone of %gep
+  //   %0       = bitcast %gep2 to i8*
+  //   %uglygep = gep %0, <offset>
+  //   %new.gep = bitcast %uglygep to <type of %gep>
+  //   ... %new.gep ...
+  Instruction *NewGEP = GEP->clone();
+  NewGEP->insertBefore(GEP);
+
+  // Per ANSI C standard, signed / unsigned = unsigned and signed % unsigned =
+  // unsigned. Therefore, we cast ElementTypeSizeOfGEP to signed because it is
+  // used with unsigned integers later.
+  int64_t ElementTypeSizeOfGEP = static_cast<int64_t>(
+      DL->getTypeAllocSize(GEP->getResultElementType()));
+  Type *IntPtrTy = DL->getIntPtrType(GEP->getType());
+  if (AccumulativeByteOffset % ElementTypeSizeOfGEP == 0) {
+    // Very likely. As long as %gep is naturally aligned, the byte offset we
+    // extracted should be a multiple of sizeof(*%gep).
+    int64_t Index = AccumulativeByteOffset / ElementTypeSizeOfGEP;
+    NewGEP = GetElementPtrInst::Create(GEP->getResultElementType(), NewGEP,
+                                       ConstantInt::get(IntPtrTy, Index, true),
+                                       GEP->getName(), GEP);
+    NewGEP->copyMetadata(*GEP);
+    // Inherit the inbounds attribute of the original GEP.
+    cast<GetElementPtrInst>(NewGEP)->setIsInBounds(GEPWasInBounds);
+  } else {
+    // Unlikely but possible. For example,
+    //  #pragma pack(1)
+    //  struct S {
+    //    int a[3];
+    //    int64 b[8];
+    //  };
+    //  #pragma pack()
+    //
+    // Suppose the gep before extraction is &s[i + 1].b[j + 3]. After
+    // extraction, it becomes &s[i].b[j] and AccumulativeByteOffset is
+    // sizeof(S) + 3 * sizeof(int64) = 100, which is not a multiple of
+    // sizeof(int64).
+    //
+    // Emit an uglygep in this case.
+    Type *I8PtrTy = Type::getInt8PtrTy(GEP->getContext(),
+                                       GEP->getPointerAddressSpace());
+    NewGEP = new BitCastInst(NewGEP, I8PtrTy, "", GEP);
+    NewGEP = GetElementPtrInst::Create(
+        Type::getInt8Ty(GEP->getContext()), NewGEP,
+        ConstantInt::get(IntPtrTy, AccumulativeByteOffset, true), "uglygep",
+        GEP);
+    NewGEP->copyMetadata(*GEP);
+    // Inherit the inbounds attribute of the original GEP.
+ cast<GetElementPtrInst>(NewGEP)->setIsInBounds(GEPWasInBounds); + if (GEP->getType() != I8PtrTy) + NewGEP = new BitCastInst(NewGEP, GEP->getType(), GEP->getName(), GEP); + } + + GEP->replaceAllUsesWith(NewGEP); + GEP->eraseFromParent(); + + return true; +} + +bool SeparateConstOffsetFromGEP::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + + if (DisableSeparateConstOffsetFromGEP) + return false; + + DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); + bool Changed = false; + for (BasicBlock &B : F) { + for (BasicBlock::iterator I = B.begin(), IE = B.end(); I != IE;) + if (GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(I++)) + Changed |= splitGEP(GEP); + // No need to split GEP ConstantExprs because all its indices are constant + // already. + } + + Changed |= reuniteExts(F); + + if (VerifyNoDeadCode) + verifyNoDeadCode(F); + + return Changed; +} + +Instruction *SeparateConstOffsetFromGEP::findClosestMatchingDominator( + const SCEV *Key, Instruction *Dominatee) { + auto Pos = DominatingExprs.find(Key); + if (Pos == DominatingExprs.end()) + return nullptr; + + auto &Candidates = Pos->second; + // Because we process the basic blocks in pre-order of the dominator tree, a + // candidate that doesn't dominate the current instruction won't dominate any + // future instruction either. Therefore, we pop it out of the stack. This + // optimization makes the algorithm O(n). + while (!Candidates.empty()) { + Instruction *Candidate = Candidates.back(); + if (DT->dominates(Candidate, Dominatee)) + return Candidate; + Candidates.pop_back(); + } + return nullptr; +} + +bool SeparateConstOffsetFromGEP::reuniteExts(Instruction *I) { + if (!SE->isSCEVable(I->getType())) + return false; + + // Dom: LHS+RHS + // I: sext(LHS)+sext(RHS) + // If Dom can't sign overflow and Dom dominates I, optimize I to sext(Dom). + // TODO: handle zext + Value *LHS = nullptr, *RHS = nullptr; + if (match(I, m_Add(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS)))) || + match(I, m_Sub(m_SExt(m_Value(LHS)), m_SExt(m_Value(RHS))))) { + if (LHS->getType() == RHS->getType()) { + const SCEV *Key = + SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS)); + if (auto *Dom = findClosestMatchingDominator(Key, I)) { + Instruction *NewSExt = new SExtInst(Dom, I->getType(), "", I); + NewSExt->takeName(I); + I->replaceAllUsesWith(NewSExt); + RecursivelyDeleteTriviallyDeadInstructions(I); + return true; + } + } + } + + // Add I to DominatingExprs if it's an add/sub that can't sign overflow. 
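+  // For example (hypothetical IR, not from an actual test): once a dominating
+  //   %dom = add nsw i32 %a, %b
+  // has been recorded by the code below, a later dominated
+  //   %wide = add i64 %sext.a, %sext.b   ; sign extensions of %a and %b
+  // is rewritten by the code above into
+  //   %wide = sext i32 %dom to i64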
+  if (match(I, m_NSWAdd(m_Value(LHS), m_Value(RHS))) ||
+      match(I, m_NSWSub(m_Value(LHS), m_Value(RHS)))) {
+    if (programUndefinedIfFullPoison(I)) {
+      const SCEV *Key =
+          SE->getAddExpr(SE->getUnknown(LHS), SE->getUnknown(RHS));
+      DominatingExprs[Key].push_back(I);
+    }
+  }
+  return false;
+}
+
+bool SeparateConstOffsetFromGEP::reuniteExts(Function &F) {
+  bool Changed = false;
+  DominatingExprs.clear();
+  for (const auto Node : depth_first(DT)) {
+    BasicBlock *BB = Node->getBlock();
+    for (auto I = BB->begin(); I != BB->end(); ) {
+      Instruction *Cur = &*I++;
+      Changed |= reuniteExts(Cur);
+    }
+  }
+  return Changed;
+}
+
+void SeparateConstOffsetFromGEP::verifyNoDeadCode(Function &F) {
+  for (BasicBlock &B : F) {
+    for (Instruction &I : B) {
+      if (isInstructionTriviallyDead(&I)) {
+        std::string ErrMessage;
+        raw_string_ostream RSO(ErrMessage);
+        RSO << "Dead instruction detected!\n" << I << "\n";
+        llvm_unreachable(RSO.str().c_str());
+      }
+    }
+  }
+}
+
+bool SeparateConstOffsetFromGEP::isLegalToSwapOperand(
+    GetElementPtrInst *FirstGEP, GetElementPtrInst *SecondGEP, Loop *CurLoop) {
+  if (!FirstGEP || !FirstGEP->hasOneUse())
+    return false;
+
+  if (!SecondGEP || FirstGEP->getParent() != SecondGEP->getParent())
+    return false;
+
+  if (FirstGEP == SecondGEP)
+    return false;
+
+  unsigned FirstNum = FirstGEP->getNumOperands();
+  unsigned SecondNum = SecondGEP->getNumOperands();
+  // Give up if the number of operands is not 2.
+  if (FirstNum != SecondNum || FirstNum != 2)
+    return false;
+
+  Value *FirstBase = FirstGEP->getOperand(0);
+  Value *SecondBase = SecondGEP->getOperand(0);
+  Value *FirstOffset = FirstGEP->getOperand(1);
+  // Give up if the index of the first GEP is loop invariant.
+  if (CurLoop->isLoopInvariant(FirstOffset))
+    return false;
+
+  // Give up if the bases don't have the same type.
+  if (FirstBase->getType() != SecondBase->getType())
+    return false;
+
+  Instruction *FirstOffsetDef = dyn_cast<Instruction>(FirstOffset);
+
+  // Check if the second operand of the first GEP has a constant coefficient.
+  // For example, in the following code we won't gain anything by hoisting
+  // the second GEP out because the second GEP can be folded away.
+  //   %scevgep.sum.ur159 = add i64 %idxprom48.ur, 256
+  //   %67 = shl i64 %scevgep.sum.ur159, 2
+  //   %uglygep160 = getelementptr i8* %65, i64 %67
+  //   %uglygep161 = getelementptr i8* %uglygep160, i64 -1024
+
+  // Skip a constant shift instruction that may have been generated by
+  // splitting GEPs.
+  if (FirstOffsetDef && FirstOffsetDef->isShift() &&
+      isa<ConstantInt>(FirstOffsetDef->getOperand(1)))
+    FirstOffsetDef = dyn_cast<Instruction>(FirstOffsetDef->getOperand(0));
+
+  // Give up if FirstOffsetDef is an Add or Sub with a constant, because the
+  // swap may not be profitable at all due to constant folding.
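+  // For example (hypothetical IR, not from an actual test): given
+  //   %off = add i64 %i, 16
+  //   %gep = getelementptr i8, i8* %base, i64 %off
+  // the constant 16 would simply be folded back into a GEP after the swap,
+  // so swapping would buy nothing.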
+  if (FirstOffsetDef)
+    if (BinaryOperator *BO = dyn_cast<BinaryOperator>(FirstOffsetDef)) {
+      unsigned opc = BO->getOpcode();
+      if ((opc == Instruction::Add || opc == Instruction::Sub) &&
+          (isa<ConstantInt>(BO->getOperand(0)) ||
+           isa<ConstantInt>(BO->getOperand(1))))
+        return false;
+    }
+  return true;
+}
+
+bool SeparateConstOffsetFromGEP::hasMoreThanOneUseInLoop(Value *V, Loop *L) {
+  int UsesInLoop = 0;
+  for (User *U : V->users()) {
+    if (Instruction *User = dyn_cast<Instruction>(U))
+      if (L->contains(User))
+        if (++UsesInLoop > 1)
+          return true;
+  }
+  return false;
+}
+
+void SeparateConstOffsetFromGEP::swapGEPOperand(GetElementPtrInst *First,
+                                                GetElementPtrInst *Second) {
+  Value *Offset1 = First->getOperand(1);
+  Value *Offset2 = Second->getOperand(1);
+  First->setOperand(1, Offset2);
+  Second->setOperand(1, Offset1);
+
+  // We changed p+o+c to p+c+o; p+c may no longer be in bounds.
+  const DataLayout &DAL = First->getModule()->getDataLayout();
+  APInt Offset(DAL.getPointerSizeInBits(
+                   cast<PointerType>(First->getType())->getAddressSpace()),
+               0);
+  Value *NewBase =
+      First->stripAndAccumulateInBoundsConstantOffsets(DAL, Offset);
+  uint64_t ObjectSize;
+  if (!getObjectSize(NewBase, ObjectSize, DAL, TLI) ||
+      Offset.ugt(ObjectSize)) {
+    First->setIsInBounds(false);
+    Second->setIsInBounds(false);
+  } else
+    First->setIsInBounds(true);
+}
diff --git a/contrib/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp b/contrib/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
new file mode 100644
index 000000000000..aba732bc413f
--- /dev/null
+++ b/contrib/llvm/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
@@ -0,0 +1,2171 @@
+//===- SimpleLoopUnswitch.cpp - Hoist loop-invariant control flow ---------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SetVector.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Analysis/AssumptionCache.h" +#include "llvm/Analysis/CodeMetrics.h" +#include "llvm/Analysis/LoopAnalysisManager.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/LoopPass.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/GenericDomTree.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar/SimpleLoopUnswitch.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LoopUtils.h" +#include "llvm/Transforms/Utils/ValueMapper.h" +#include <algorithm> +#include <cassert> +#include <iterator> +#include <numeric> +#include <utility> + +#define DEBUG_TYPE "simple-loop-unswitch" + +using namespace llvm; + +STATISTIC(NumBranches, "Number of branches unswitched"); +STATISTIC(NumSwitches, "Number of switches unswitched"); +STATISTIC(NumTrivial, "Number of unswitches that are trivial"); + +static cl::opt<bool> EnableNonTrivialUnswitch( + "enable-nontrivial-unswitch", cl::init(false), cl::Hidden, + cl::desc("Forcibly enables non-trivial loop unswitching rather than " + "following the configuration passed into the pass.")); + +static cl::opt<int> + UnswitchThreshold("unswitch-threshold", cl::init(50), cl::Hidden, + cl::desc("The cost threshold for unswitching a loop.")); + +static void replaceLoopUsesWithConstant(Loop &L, Value &LIC, + Constant &Replacement) { + assert(!isa<Constant>(LIC) && "Why are we unswitching on a constant?"); + + // Replace uses of LIC in the loop with the given constant. + for (auto UI = LIC.use_begin(), UE = LIC.use_end(); UI != UE;) { + // Grab the use and walk past it so we can clobber it in the use list. + Use *U = &*UI++; + Instruction *UserI = dyn_cast<Instruction>(U->getUser()); + if (!UserI || !L.contains(UserI)) + continue; + + // Replace this use within the loop body. + *U = &Replacement; + } +} + +/// Update the IDom for a basic block whose predecessor set has changed. +/// +/// This routine is designed to work when the domtree update is relatively +/// localized by leveraging a known common dominator, often a loop header. +/// +/// FIXME: Should consider hand-rolling a slightly more efficient non-DFS +/// approach here as we can do that easily by persisting the candidate IDom's +/// dominating set between each predecessor. +/// +/// FIXME: Longer term, many uses of this can be replaced by an incremental +/// domtree update strategy that starts from a known dominating block and +/// rebuilds that subtree. 
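+///
+/// For example (hypothetical CFG): if unswitching makes the old preheader a
+/// second predecessor of an exit block whose idom was the sole exiting block,
+/// the exit block's new idom becomes the nearest common dominator of those
+/// two predecessors, which is at worst `KnownDominatingBB` itself.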
+static bool updateIDomWithKnownCommonDominator(BasicBlock *BB, + BasicBlock *KnownDominatingBB, + DominatorTree &DT) { + assert(pred_begin(BB) != pred_end(BB) && + "This routine does not handle unreachable blocks!"); + + BasicBlock *OrigIDom = DT[BB]->getIDom()->getBlock(); + + BasicBlock *IDom = *pred_begin(BB); + assert(DT.dominates(KnownDominatingBB, IDom) && + "Bad known dominating block!"); + + // Walk all of the other predecessors finding the nearest common dominator + // until all predecessors are covered or we reach the loop header. The loop + // header necessarily dominates all loop exit blocks in loop simplified form + // so we can early-exit the moment we hit that block. + for (auto PI = std::next(pred_begin(BB)), PE = pred_end(BB); + PI != PE && IDom != KnownDominatingBB; ++PI) { + assert(DT.dominates(KnownDominatingBB, *PI) && + "Bad known dominating block!"); + IDom = DT.findNearestCommonDominator(IDom, *PI); + } + + if (IDom == OrigIDom) + return false; + + DT.changeImmediateDominator(BB, IDom); + return true; +} + +// Note that we don't currently use the IDFCalculator here for two reasons: +// 1) It computes dominator tree levels for the entire function on each run +// of 'compute'. While this isn't terrible, given that we expect to update +// relatively small subtrees of the domtree, it isn't necessarily the right +// tradeoff. +// 2) The interface doesn't fit this usage well. It doesn't operate in +// append-only, and builds several sets that we don't need. +// +// FIXME: Neither of these issues are a big deal and could be addressed with +// some amount of refactoring of IDFCalculator. That would allow us to share +// the core logic here (which is solving the same core problem). +static void appendDomFrontier(DomTreeNode *Node, + SmallSetVector<BasicBlock *, 4> &Worklist, + SmallVectorImpl<DomTreeNode *> &DomNodes, + SmallPtrSetImpl<BasicBlock *> &DomSet) { + assert(DomNodes.empty() && "Must start with no dominator nodes."); + assert(DomSet.empty() && "Must start with an empty dominator set."); + + // First flatten this subtree into sequence of nodes by doing a pre-order + // walk. + DomNodes.push_back(Node); + // We intentionally re-evaluate the size as each node can add new children. + // Because this is a tree walk, this cannot add any duplicates. + for (int i = 0; i < (int)DomNodes.size(); ++i) + DomNodes.insert(DomNodes.end(), DomNodes[i]->begin(), DomNodes[i]->end()); + + // Now create a set of the basic blocks so we can quickly test for + // dominated successors. We could in theory use the DFS numbers of the + // dominator tree for this, but we want this to remain predictably fast + // even while we mutate the dominator tree in ways that would invalidate + // the DFS numbering. + for (DomTreeNode *InnerN : DomNodes) + DomSet.insert(InnerN->getBlock()); + + // Now re-walk the nodes, appending every successor of every node that isn't + // in the set. Note that we don't append the node itself, even though if it + // is a successor it does not strictly dominate itself and thus it would be + // part of the dominance frontier. The reason we don't append it is that + // the node passed in came *from* the worklist and so it has already been + // processed. + for (DomTreeNode *InnerN : DomNodes) + for (BasicBlock *SuccBB : successors(InnerN->getBlock())) + if (!DomSet.count(SuccBB)) + Worklist.insert(SuccBB); + + DomNodes.clear(); + DomSet.clear(); +} + +/// Update the dominator tree after unswitching a particular former exit block. 
+/// +/// This handles the full update of the dominator tree after hoisting a block +/// that previously was an exit block (or split off of an exit block) up to be +/// reached from the new immediate dominator of the preheader. +/// +/// The common case is simple -- we just move the unswitched block to have an +/// immediate dominator of the old preheader. But in complex cases, there may +/// be other blocks reachable from the unswitched block that are immediately +/// dominated by some node between the unswitched one and the old preheader. +/// All of these also need to be hoisted in the dominator tree. We also want to +/// minimize queries to the dominator tree because each step of this +/// invalidates any DFS numbers that would make queries fast. +static void updateDTAfterUnswitch(BasicBlock *UnswitchedBB, BasicBlock *OldPH, + DominatorTree &DT) { + DomTreeNode *OldPHNode = DT[OldPH]; + DomTreeNode *UnswitchedNode = DT[UnswitchedBB]; + // If the dominator tree has already been updated for this unswitched node, + // we're done. This makes it easier to use this routine if there are multiple + // paths to the same unswitched destination. + if (UnswitchedNode->getIDom() == OldPHNode) + return; + + // First collect the domtree nodes that we are hoisting over. These are the + // set of nodes which may have children that need to be hoisted as well. + SmallPtrSet<DomTreeNode *, 4> DomChain; + for (auto *IDom = UnswitchedNode->getIDom(); IDom != OldPHNode; + IDom = IDom->getIDom()) + DomChain.insert(IDom); + + // The unswitched block ends up immediately dominated by the old preheader -- + // regardless of whether it is the loop exit block or split off of the loop + // exit block. + DT.changeImmediateDominator(UnswitchedNode, OldPHNode); + + // For everything that moves up the dominator tree, we need to examine the + // dominator frontier to see if it additionally should move up the dominator + // tree. This lambda appends the dominator frontier for a node on the + // worklist. + SmallSetVector<BasicBlock *, 4> Worklist; + + // Scratch data structures reused by domfrontier finding. + SmallVector<DomTreeNode *, 4> DomNodes; + SmallPtrSet<BasicBlock *, 4> DomSet; + + // Append the initial dom frontier nodes. + appendDomFrontier(UnswitchedNode, Worklist, DomNodes, DomSet); + + // Walk the worklist. We grow the list in the loop and so must recompute size. + for (int i = 0; i < (int)Worklist.size(); ++i) { + auto *BB = Worklist[i]; + + DomTreeNode *Node = DT[BB]; + assert(!DomChain.count(Node) && + "Cannot be dominated by a block you can reach!"); + + // If this block had an immediate dominator somewhere in the chain + // we hoisted over, then its position in the domtree needs to move as it is + // reachable from a node hoisted over this chain. + if (!DomChain.count(Node->getIDom())) + continue; + + DT.changeImmediateDominator(Node, OldPHNode); + + // Now add this node's dominator frontier to the worklist as well. + appendDomFrontier(Node, Worklist, DomNodes, DomSet); + } +} + +/// Check that all the LCSSA PHI nodes in the loop exit block have trivial +/// incoming values along this edge. +static bool areLoopExitPHIsLoopInvariant(Loop &L, BasicBlock &ExitingBB, + BasicBlock &ExitBB) { + for (Instruction &I : ExitBB) { + auto *PN = dyn_cast<PHINode>(&I); + if (!PN) + // No more PHIs to check. + return true; + + // If the incoming value for this edge isn't loop invariant the unswitch + // won't be trivial. 
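+    // For example (hypothetical IR, not from an actual test): an LCSSA PHI
+    // such as
+    //   %lcssa = phi i32 [ %outside.def, %exiting ]
+    // has a trivial incoming value when %outside.def is defined outside the
+    // loop (or is a constant), but not when it is computed inside the loop.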
+ if (!L.isLoopInvariant(PN->getIncomingValueForBlock(&ExitingBB))) + return false; + } + llvm_unreachable("Basic blocks should never be empty!"); +} + +/// Rewrite the PHI nodes in an unswitched loop exit basic block. +/// +/// Requires that the loop exit and unswitched basic block are the same, and +/// that the exiting block was a unique predecessor of that block. Rewrites the +/// PHI nodes in that block such that what were LCSSA PHI nodes become trivial +/// PHI nodes from the old preheader that now contains the unswitched +/// terminator. +static void rewritePHINodesForUnswitchedExitBlock(BasicBlock &UnswitchedBB, + BasicBlock &OldExitingBB, + BasicBlock &OldPH) { + for (PHINode &PN : UnswitchedBB.phis()) { + // When the loop exit is directly unswitched we just need to update the + // incoming basic block. We loop to handle weird cases with repeated + // incoming blocks, but expect to typically only have one operand here. + for (auto i : seq<int>(0, PN.getNumOperands())) { + assert(PN.getIncomingBlock(i) == &OldExitingBB && + "Found incoming block different from unique predecessor!"); + PN.setIncomingBlock(i, &OldPH); + } + } +} + +/// Rewrite the PHI nodes in the loop exit basic block and the split off +/// unswitched block. +/// +/// Because the exit block remains an exit from the loop, this rewrites the +/// LCSSA PHI nodes in it to remove the unswitched edge and introduces PHI +/// nodes into the unswitched basic block to select between the value in the +/// old preheader and the loop exit. +static void rewritePHINodesForExitAndUnswitchedBlocks(BasicBlock &ExitBB, + BasicBlock &UnswitchedBB, + BasicBlock &OldExitingBB, + BasicBlock &OldPH) { + assert(&ExitBB != &UnswitchedBB && + "Must have different loop exit and unswitched blocks!"); + Instruction *InsertPt = &*UnswitchedBB.begin(); + for (PHINode &PN : ExitBB.phis()) { + auto *NewPN = PHINode::Create(PN.getType(), /*NumReservedValues*/ 2, + PN.getName() + ".split", InsertPt); + + // Walk backwards over the old PHI node's inputs to minimize the cost of + // removing each one. We have to do this weird loop manually so that we + // create the same number of new incoming edges in the new PHI as we expect + // each case-based edge to be included in the unswitched switch in some + // cases. + // FIXME: This is really, really gross. It would be much cleaner if LLVM + // allowed us to create a single entry for a predecessor block without + // having separate entries for each "edge" even though these edges are + // required to produce identical results. + for (int i = PN.getNumIncomingValues() - 1; i >= 0; --i) { + if (PN.getIncomingBlock(i) != &OldExitingBB) + continue; + + Value *Incoming = PN.removeIncomingValue(i); + NewPN->addIncoming(Incoming, &OldPH); + } + + // Now replace the old PHI with the new one and wire the old one in as an + // input to the new one. + PN.replaceAllUsesWith(NewPN); + NewPN->addIncoming(&PN, &ExitBB); + } +} + +/// Unswitch a trivial branch if the condition is loop invariant. +/// +/// This routine should only be called when loop code leading to the branch has +/// been validated as trivial (no side effects). This routine checks if the +/// condition is invariant and one of the successors is a loop exit. This +/// allows us to unswitch without duplicating the loop, making it trivial. +/// +/// If this routine fails to unswitch the branch it returns false. +/// +/// If the branch can be unswitched, this routine splits the preheader and +/// hoists the branch above that split. 
Preserves loop simplified form +/// (splitting the exit block as necessary). It simplifies the branch within +/// the loop to an unconditional branch but doesn't remove it entirely. Further +/// cleanup can be done with some simplify-cfg like pass. +static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT, + LoopInfo &LI) { + assert(BI.isConditional() && "Can only unswitch a conditional branch!"); + DEBUG(dbgs() << " Trying to unswitch branch: " << BI << "\n"); + + Value *LoopCond = BI.getCondition(); + + // Need a trivial loop condition to unswitch. + if (!L.isLoopInvariant(LoopCond)) + return false; + + // FIXME: We should compute this once at the start and update it! + SmallVector<BasicBlock *, 16> ExitBlocks; + L.getExitBlocks(ExitBlocks); + SmallPtrSet<BasicBlock *, 16> ExitBlockSet(ExitBlocks.begin(), + ExitBlocks.end()); + + // Check to see if a successor of the branch is guaranteed to + // exit through a unique exit block without having any + // side-effects. If so, determine the value of Cond that causes + // it to do this. + ConstantInt *CondVal = ConstantInt::getTrue(BI.getContext()); + ConstantInt *Replacement = ConstantInt::getFalse(BI.getContext()); + int LoopExitSuccIdx = 0; + auto *LoopExitBB = BI.getSuccessor(0); + if (!ExitBlockSet.count(LoopExitBB)) { + std::swap(CondVal, Replacement); + LoopExitSuccIdx = 1; + LoopExitBB = BI.getSuccessor(1); + if (!ExitBlockSet.count(LoopExitBB)) + return false; + } + auto *ContinueBB = BI.getSuccessor(1 - LoopExitSuccIdx); + assert(L.contains(ContinueBB) && + "Cannot have both successors exit and still be in the loop!"); + + auto *ParentBB = BI.getParent(); + if (!areLoopExitPHIsLoopInvariant(L, *ParentBB, *LoopExitBB)) + return false; + + DEBUG(dbgs() << " unswitching trivial branch when: " << CondVal + << " == " << LoopCond << "\n"); + + // Split the preheader, so that we know that there is a safe place to insert + // the conditional branch. We will change the preheader to have a conditional + // branch on LoopCond. + BasicBlock *OldPH = L.getLoopPreheader(); + BasicBlock *NewPH = SplitEdge(OldPH, L.getHeader(), &DT, &LI); + + // Now that we have a place to insert the conditional branch, create a place + // to branch to: this is the exit block out of the loop that we are + // unswitching. We need to split this if there are other loop predecessors. + // Because the loop is in simplified form, *any* other predecessor is enough. + BasicBlock *UnswitchedBB; + if (BasicBlock *PredBB = LoopExitBB->getUniquePredecessor()) { + (void)PredBB; + assert(PredBB == BI.getParent() && + "A branch's parent isn't a predecessor!"); + UnswitchedBB = LoopExitBB; + } else { + UnswitchedBB = SplitBlock(LoopExitBB, &LoopExitBB->front(), &DT, &LI); + } + + // Now splice the branch to gate reaching the new preheader and re-point its + // successors. + OldPH->getInstList().splice(std::prev(OldPH->end()), + BI.getParent()->getInstList(), BI); + OldPH->getTerminator()->eraseFromParent(); + BI.setSuccessor(LoopExitSuccIdx, UnswitchedBB); + BI.setSuccessor(1 - LoopExitSuccIdx, NewPH); + + // Create a new unconditional branch that will continue the loop as a new + // terminator. + BranchInst::Create(ContinueBB, ParentBB); + + // Rewrite the relevant PHI nodes. + if (UnswitchedBB == LoopExitBB) + rewritePHINodesForUnswitchedExitBlock(*UnswitchedBB, *ParentBB, *OldPH); + else + rewritePHINodesForExitAndUnswitchedBlocks(*LoopExitBB, *UnswitchedBB, + *ParentBB, *OldPH); + + // Now we need to update the dominator tree. 
+ updateDTAfterUnswitch(UnswitchedBB, OldPH, DT); + // But if we split something off of the loop exit block then we also removed + // one of the predecessors for the loop exit block and may need to update its + // idom. + if (UnswitchedBB != LoopExitBB) + updateIDomWithKnownCommonDominator(LoopExitBB, L.getHeader(), DT); + + // Since this is an i1 condition we can also trivially replace uses of it + // within the loop with a constant. + replaceLoopUsesWithConstant(L, *LoopCond, *Replacement); + + ++NumTrivial; + ++NumBranches; + return true; +} + +/// Unswitch a trivial switch if the condition is loop invariant. +/// +/// This routine should only be called when loop code leading to the switch has +/// been validated as trivial (no side effects). This routine checks if the +/// condition is invariant and that at least one of the successors is a loop +/// exit. This allows us to unswitch without duplicating the loop, making it +/// trivial. +/// +/// If this routine fails to unswitch the switch it returns false. +/// +/// If the switch can be unswitched, this routine splits the preheader and +/// copies the switch above that split. If the default case is one of the +/// exiting cases, it copies the non-exiting cases and points them at the new +/// preheader. If the default case is not exiting, it copies the exiting cases +/// and points the default at the preheader. It preserves loop simplified form +/// (splitting the exit blocks as necessary). It simplifies the switch within +/// the loop by removing now-dead cases. If the default case is one of those +/// unswitched, it replaces its destination with a new basic block containing +/// only unreachable. Such basic blocks, while technically loop exits, are not +/// considered for unswitching so this is a stable transform and the same +/// switch will not be revisited. If after unswitching there is only a single +/// in-loop successor, the switch is further simplified to an unconditional +/// branch. Still more cleanup can be done with some simplify-cfg like pass. +static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT, + LoopInfo &LI) { + DEBUG(dbgs() << " Trying to unswitch switch: " << SI << "\n"); + Value *LoopCond = SI.getCondition(); + + // If this isn't switching on an invariant condition, we can't unswitch it. + if (!L.isLoopInvariant(LoopCond)) + return false; + + auto *ParentBB = SI.getParent(); + + // FIXME: We should compute this once at the start and update it! + SmallVector<BasicBlock *, 16> ExitBlocks; + L.getExitBlocks(ExitBlocks); + SmallPtrSet<BasicBlock *, 16> ExitBlockSet(ExitBlocks.begin(), + ExitBlocks.end()); + + SmallVector<int, 4> ExitCaseIndices; + for (auto Case : SI.cases()) { + auto *SuccBB = Case.getCaseSuccessor(); + if (ExitBlockSet.count(SuccBB) && + areLoopExitPHIsLoopInvariant(L, *ParentBB, *SuccBB)) + ExitCaseIndices.push_back(Case.getCaseIndex()); + } + BasicBlock *DefaultExitBB = nullptr; + if (ExitBlockSet.count(SI.getDefaultDest()) && + areLoopExitPHIsLoopInvariant(L, *ParentBB, *SI.getDefaultDest()) && + !isa<UnreachableInst>(SI.getDefaultDest()->getTerminator())) + DefaultExitBB = SI.getDefaultDest(); + else if (ExitCaseIndices.empty()) + return false; + + DEBUG(dbgs() << " unswitching trivial cases...\n"); + + SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4> ExitCases; + ExitCases.reserve(ExitCaseIndices.size()); + // We walk the case indices backwards so that we remove the last case first + // and don't disrupt the earlier indices. 
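+  // For example (hypothetical): if the cases at indices 1 and 3 both exit,
+  // removing index 3 first leaves index 1 valid, whereas removing index 1
+  // first would shift the old index 3 down to index 2.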
+ for (unsigned Index : reverse(ExitCaseIndices)) { + auto CaseI = SI.case_begin() + Index; + // Save the value of this case. + ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()}); + // Delete the unswitched cases. + SI.removeCase(CaseI); + } + + // Check if after this all of the remaining cases point at the same + // successor. + BasicBlock *CommonSuccBB = nullptr; + if (SI.getNumCases() > 0 && + std::all_of(std::next(SI.case_begin()), SI.case_end(), + [&SI](const SwitchInst::CaseHandle &Case) { + return Case.getCaseSuccessor() == + SI.case_begin()->getCaseSuccessor(); + })) + CommonSuccBB = SI.case_begin()->getCaseSuccessor(); + + if (DefaultExitBB) { + // We can't remove the default edge so replace it with an edge to either + // the single common remaining successor (if we have one) or an unreachable + // block. + if (CommonSuccBB) { + SI.setDefaultDest(CommonSuccBB); + } else { + BasicBlock *UnreachableBB = BasicBlock::Create( + ParentBB->getContext(), + Twine(ParentBB->getName()) + ".unreachable_default", + ParentBB->getParent()); + new UnreachableInst(ParentBB->getContext(), UnreachableBB); + SI.setDefaultDest(UnreachableBB); + DT.addNewBlock(UnreachableBB, ParentBB); + } + } else { + // If we're not unswitching the default, we need it to match any cases to + // have a common successor or if we have no cases it is the common + // successor. + if (SI.getNumCases() == 0) + CommonSuccBB = SI.getDefaultDest(); + else if (SI.getDefaultDest() != CommonSuccBB) + CommonSuccBB = nullptr; + } + + // Split the preheader, so that we know that there is a safe place to insert + // the switch. + BasicBlock *OldPH = L.getLoopPreheader(); + BasicBlock *NewPH = SplitEdge(OldPH, L.getHeader(), &DT, &LI); + OldPH->getTerminator()->eraseFromParent(); + + // Now add the unswitched switch. + auto *NewSI = SwitchInst::Create(LoopCond, NewPH, ExitCases.size(), OldPH); + + // Rewrite the IR for the unswitched basic blocks. This requires two steps. + // First, we split any exit blocks with remaining in-loop predecessors. Then + // we update the PHIs in one of two ways depending on if there was a split. + // We walk in reverse so that we split in the same order as the cases + // appeared. This is purely for convenience of reading the resulting IR, but + // it doesn't cost anything really. + SmallPtrSet<BasicBlock *, 2> UnswitchedExitBBs; + SmallDenseMap<BasicBlock *, BasicBlock *, 2> SplitExitBBMap; + // Handle the default exit if necessary. + // FIXME: It'd be great if we could merge this with the loop below but LLVM's + // ranges aren't quite powerful enough yet. + if (DefaultExitBB) { + if (pred_empty(DefaultExitBB)) { + UnswitchedExitBBs.insert(DefaultExitBB); + rewritePHINodesForUnswitchedExitBlock(*DefaultExitBB, *ParentBB, *OldPH); + } else { + auto *SplitBB = + SplitBlock(DefaultExitBB, &DefaultExitBB->front(), &DT, &LI); + rewritePHINodesForExitAndUnswitchedBlocks(*DefaultExitBB, *SplitBB, + *ParentBB, *OldPH); + updateIDomWithKnownCommonDominator(DefaultExitBB, L.getHeader(), DT); + DefaultExitBB = SplitExitBBMap[DefaultExitBB] = SplitBB; + } + } + // Note that we must use a reference in the for loop so that we update the + // container. + for (auto &CasePair : reverse(ExitCases)) { + // Grab a reference to the exit block in the pair so that we can update it. + BasicBlock *ExitBB = CasePair.second; + + // If this case is the last edge into the exit block, we can simply reuse it + // as it will no longer be a loop exit. No mapping necessary. 
+ if (pred_empty(ExitBB)) { + // Only rewrite once. + if (UnswitchedExitBBs.insert(ExitBB).second) + rewritePHINodesForUnswitchedExitBlock(*ExitBB, *ParentBB, *OldPH); + continue; + } + + // Otherwise we need to split the exit block so that we retain an exit + // block from the loop and a target for the unswitched condition. + BasicBlock *&SplitExitBB = SplitExitBBMap[ExitBB]; + if (!SplitExitBB) { + // If this is the first time we see this, do the split and remember it. + SplitExitBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI); + rewritePHINodesForExitAndUnswitchedBlocks(*ExitBB, *SplitExitBB, + *ParentBB, *OldPH); + updateIDomWithKnownCommonDominator(ExitBB, L.getHeader(), DT); + } + // Update the case pair to point to the split block. + CasePair.second = SplitExitBB; + } + + // Now add the unswitched cases. We do this in reverse order as we built them + // in reverse order. + for (auto CasePair : reverse(ExitCases)) { + ConstantInt *CaseVal = CasePair.first; + BasicBlock *UnswitchedBB = CasePair.second; + + NewSI->addCase(CaseVal, UnswitchedBB); + updateDTAfterUnswitch(UnswitchedBB, OldPH, DT); + } + + // If the default was unswitched, re-point it and add explicit cases for + // entering the loop. + if (DefaultExitBB) { + NewSI->setDefaultDest(DefaultExitBB); + updateDTAfterUnswitch(DefaultExitBB, OldPH, DT); + + // We removed all the exit cases, so we just copy the cases to the + // unswitched switch. + for (auto Case : SI.cases()) + NewSI->addCase(Case.getCaseValue(), NewPH); + } + + // If we ended up with a common successor for every path through the switch + // after unswitching, rewrite it to an unconditional branch to make it easy + // to recognize. Otherwise we potentially have to recognize the default case + // pointing at unreachable and other complexity. + if (CommonSuccBB) { + BasicBlock *BB = SI.getParent(); + SI.eraseFromParent(); + BranchInst::Create(CommonSuccBB, BB); + } + + DT.verifyDomTree(); + ++NumTrivial; + ++NumSwitches; + return true; +} + +/// This routine scans the loop to find a branch or switch which occurs before +/// any side effects occur. These can potentially be unswitched without +/// duplicating the loop. If a branch or switch is successfully unswitched the +/// scanning continues to see if subsequent branches or switches have become +/// trivial. Once all trivial candidates have been unswitched, this routine +/// returns. +/// +/// The return value indicates whether anything was unswitched (and therefore +/// changed). +static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT, + LoopInfo &LI) { + bool Changed = false; + + // If loop header has only one reachable successor we should keep looking for + // trivial condition candidates in the successor as well. An alternative is + // to constant fold conditions and merge successors into loop header (then we + // only need to check header's terminator). The reason for not doing this in + // LoopUnswitch pass is that it could potentially break LoopPassManager's + // invariants. Folding dead branches could either eliminate the current loop + // or make other loops unreachable. LCSSA form might also not be preserved + // after deleting branches. The following code keeps traversing loop header's + // successors until it finds the trivial condition candidate (condition that + // is not a constant). Since unswitching generates branches with constant + // conditions, this scenario could be very common in practice. 
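+  // For example (hypothetical IR): once a trivial unswitch rewrites the
+  // header terminator into
+  //   br label %next
+  // the walk below continues into %next to look for the next candidate
+  // condition there.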
+  BasicBlock *CurrentBB = L.getHeader();
+  SmallPtrSet<BasicBlock *, 8> Visited;
+  Visited.insert(CurrentBB);
+  do {
+    // Check if there are any side-effecting instructions (e.g. stores, calls,
+    // volatile loads) in the part of the loop that the code *would* execute
+    // without unswitching.
+    if (llvm::any_of(*CurrentBB,
+                     [](Instruction &I) { return I.mayHaveSideEffects(); }))
+      return Changed;
+
+    TerminatorInst *CurrentTerm = CurrentBB->getTerminator();
+
+    if (auto *SI = dyn_cast<SwitchInst>(CurrentTerm)) {
+      // Don't bother trying to unswitch past a switch with a constant
+      // condition. This should be removed prior to running this pass by
+      // simplify-cfg.
+      if (isa<Constant>(SI->getCondition()))
+        return Changed;
+
+      if (!unswitchTrivialSwitch(L, *SI, DT, LI))
+        // Couldn't unswitch this one so we're done.
+        return Changed;
+
+      // Mark that we managed to unswitch something.
+      Changed = true;
+
+      // If unswitching turned the terminator into an unconditional branch then
+      // we can continue. The unswitching logic specifically works to fold any
+      // cases it can into an unconditional branch to make it easier to
+      // recognize here.
+      auto *BI = dyn_cast<BranchInst>(CurrentBB->getTerminator());
+      if (!BI || BI->isConditional())
+        return Changed;
+
+      CurrentBB = BI->getSuccessor(0);
+      continue;
+    }
+
+    auto *BI = dyn_cast<BranchInst>(CurrentTerm);
+    if (!BI)
+      // We do not understand other terminator instructions.
+      return Changed;
+
+    // Don't bother trying to unswitch past an unconditional branch or a branch
+    // with a constant value. These should be removed by simplify-cfg prior to
+    // running this pass.
+    if (!BI->isConditional() || isa<Constant>(BI->getCondition()))
+      return Changed;
+
+    // Found a trivial condition candidate: non-foldable conditional branch. If
+    // we fail to unswitch this, we can't do anything else that is trivial.
+    if (!unswitchTrivialBranch(L, *BI, DT, LI))
+      return Changed;
+
+    // Mark that we managed to unswitch something.
+    Changed = true;
+
+    // We unswitched the branch. This should always leave us with an
+    // unconditional branch that we can follow now.
+    BI = cast<BranchInst>(CurrentBB->getTerminator());
+    assert(!BI->isConditional() &&
+           "Cannot form a conditional branch by unswitching!");
+    CurrentBB = BI->getSuccessor(0);
+
+    // When continuing, if we exit the loop or reach a previously visited
+    // block, then we cannot reach any trivial condition candidates (unfoldable
+    // branch instructions or switch instructions) and no unswitch can happen.
+  } while (L.contains(CurrentBB) && Visited.insert(CurrentBB).second);
+
+  return Changed;
+}
+
+/// Build the cloned blocks for an unswitched copy of the given loop.
+///
+/// The cloned blocks are inserted before the loop preheader (`LoopPH`) and
+/// after the split block (`SplitBB`) that will be used to select between the
+/// cloned and original loop.
+///
+/// This routine handles cloning all of the necessary loop blocks and exit
+/// blocks including rewriting their instructions and the relevant PHI nodes.
+/// It skips loop and exit blocks that are not necessary based on the provided
+/// set. It also correctly creates the unconditional branch in the cloned
+/// unswitched parent block to only point at the unswitched successor.
+///
+/// This does not handle most of the necessary updates to `LoopInfo`. Only exit
+/// block splitting is correctly reflected in `LoopInfo`; essentially all of
+/// the cloned blocks (and their loops) are left without full `LoopInfo`
+/// updates.
+/// This also doesn't fully update `DominatorTree`. It adds the cloned blocks
+/// to it, but doesn't create the cloned `DominatorTree` structure; instead
+/// the caller must recompute an accurate DT. It *does* correctly update the
+/// `AssumptionCache` provided in `AC`.
+static BasicBlock *buildClonedLoopBlocks(
+    Loop &L, BasicBlock *LoopPH, BasicBlock *SplitBB,
+    ArrayRef<BasicBlock *> ExitBlocks, BasicBlock *ParentBB,
+    BasicBlock *UnswitchedSuccBB, BasicBlock *ContinueSuccBB,
+    const SmallPtrSetImpl<BasicBlock *> &SkippedLoopAndExitBlocks,
+    ValueToValueMapTy &VMap, AssumptionCache &AC, DominatorTree &DT,
+    LoopInfo &LI) {
+  SmallVector<BasicBlock *, 4> NewBlocks;
+  NewBlocks.reserve(L.getNumBlocks() + ExitBlocks.size());
+
+  // We will need to clone a bunch of blocks, wrap up the clone operation in
+  // a helper.
+  auto CloneBlock = [&](BasicBlock *OldBB) {
+    // Clone the basic block and insert it before the new preheader.
+    BasicBlock *NewBB = CloneBasicBlock(OldBB, VMap, ".us", OldBB->getParent());
+    NewBB->moveBefore(LoopPH);
+
+    // Record this block and the mapping.
+    NewBlocks.push_back(NewBB);
+    VMap[OldBB] = NewBB;
+
+    // Add the block to the domtree. We'll move it to the correct position
+    // below.
+    DT.addNewBlock(NewBB, SplitBB);
+
+    return NewBB;
+  };
+
+  // First, clone the preheader.
+  auto *ClonedPH = CloneBlock(LoopPH);
+
+  // Then clone all the loop blocks, skipping the ones that aren't necessary.
+  for (auto *LoopBB : L.blocks())
+    if (!SkippedLoopAndExitBlocks.count(LoopBB))
+      CloneBlock(LoopBB);
+
+  // Split all the loop exit edges so that when we clone the exit blocks, if
+  // any of the exit blocks are *also* a preheader for some other loop, we
+  // don't create multiple predecessors entering the loop header.
+  for (auto *ExitBB : ExitBlocks) {
+    if (SkippedLoopAndExitBlocks.count(ExitBB))
+      continue;
+
+    // When we are going to clone an exit, we don't need to clone all the
+    // instructions in the exit block and we want to ensure we have an easy
+    // place to merge the CFG, so split the exit first. This is always safe to
+    // do because there cannot be any non-loop predecessors of a loop exit in
+    // loop simplified form.
+    auto *MergeBB = SplitBlock(ExitBB, &ExitBB->front(), &DT, &LI);
+
+    // Rearrange the names to make it easier to write test cases by having the
+    // exit block carry the suffix rather than the merge block carrying the
+    // suffix.
+    MergeBB->takeName(ExitBB);
+    ExitBB->setName(Twine(MergeBB->getName()) + ".split");
+
+    // Now clone the original exit block.
+    auto *ClonedExitBB = CloneBlock(ExitBB);
+    assert(ClonedExitBB->getTerminator()->getNumSuccessors() == 1 &&
+           "Exit block should have been split to have one successor!");
+    assert(ClonedExitBB->getTerminator()->getSuccessor(0) == MergeBB &&
+           "Cloned exit block has the wrong successor!");
+
+    // Move the merge block's idom to be the split point as one exit is
+    // dominated by one header, and the other by another, so we know the split
+    // point dominates both. While the dominator tree isn't fully accurate, we
+    // want sub-trees within the original loop to correctly reflect dominance
+    // within that original loop (at least) and that requires moving the merge
+    // block out of that subtree.
+    // FIXME: This is very brittle as we essentially have a partial contract on
+    // the dominator tree. We really need to instead update it and keep it
+    // valid or stop relying on it.
+    DT.changeImmediateDominator(MergeBB, SplitBB);
+
+    // Remap any cloned instructions and create a merge phi node for them.
+    for (auto ZippedInsts : llvm::zip_first(
+             llvm::make_range(ExitBB->begin(), std::prev(ExitBB->end())),
+             llvm::make_range(ClonedExitBB->begin(),
+                              std::prev(ClonedExitBB->end())))) {
+      Instruction &I = std::get<0>(ZippedInsts);
+      Instruction &ClonedI = std::get<1>(ZippedInsts);
+
+      // The only instructions in the exit block should be PHI nodes and
+      // potentially a landing pad.
+      assert(
+          (isa<PHINode>(I) || isa<LandingPadInst>(I) || isa<CatchPadInst>(I)) &&
+          "Bad instruction in exit block!");
+      // We should have a value map between the instruction and its clone.
+      assert(VMap.lookup(&I) == &ClonedI && "Mismatch in the value map!");
+
+      auto *MergePN =
+          PHINode::Create(I.getType(), /*NumReservedValues*/ 2, ".us-phi",
+                          &*MergeBB->getFirstInsertionPt());
+      I.replaceAllUsesWith(MergePN);
+      MergePN->addIncoming(&I, ExitBB);
+      MergePN->addIncoming(&ClonedI, ClonedExitBB);
+    }
+  }
+
+  // Rewrite the instructions in the cloned blocks to refer to the cloned
+  // values rather than the original ones. We have to do this as a second pass
+  // so that we have everything available. Also, we have inserted new
+  // instructions which may include assume intrinsics, so we update the
+  // assumption cache while processing this.
+  for (auto *ClonedBB : NewBlocks)
+    for (Instruction &I : *ClonedBB) {
+      RemapInstruction(&I, VMap,
+                       RF_NoModuleLevelChanges | RF_IgnoreMissingLocals);
+      if (auto *II = dyn_cast<IntrinsicInst>(&I))
+        if (II->getIntrinsicID() == Intrinsic::assume)
+          AC.registerAssumption(II);
+    }
+
+  // Remove the cloned parent as a predecessor of the cloned continue successor
+  // if we did in fact clone it.
+  auto *ClonedParentBB = cast<BasicBlock>(VMap.lookup(ParentBB));
+  if (auto *ClonedContinueSuccBB =
+          cast_or_null<BasicBlock>(VMap.lookup(ContinueSuccBB)))
+    ClonedContinueSuccBB->removePredecessor(ClonedParentBB,
+                                            /*DontDeleteUselessPHIs*/ true);
+  // Replace the cloned branch with an unconditional branch to the cloned
+  // unswitched successor.
+  auto *ClonedSuccBB = cast<BasicBlock>(VMap.lookup(UnswitchedSuccBB));
+  ClonedParentBB->getTerminator()->eraseFromParent();
+  BranchInst::Create(ClonedSuccBB, ClonedParentBB);
+
+  // Update any PHI nodes in the cloned successors of the skipped blocks to not
+  // have spurious incoming values.
+  for (auto *LoopBB : L.blocks())
+    if (SkippedLoopAndExitBlocks.count(LoopBB))
+      for (auto *SuccBB : successors(LoopBB))
+        if (auto *ClonedSuccBB = cast_or_null<BasicBlock>(VMap.lookup(SuccBB)))
+          for (PHINode &PN : ClonedSuccBB->phis())
+            PN.removeIncomingValue(LoopBB, /*DeletePHIIfEmpty*/ false);
+
+  return ClonedPH;
+}
+
+/// Recursively clone the specified loop and all of its children.
+///
+/// The target parent loop for the clone should be provided, or can be null if
+/// the clone is a top-level loop. While cloning, all the blocks are mapped
+/// with the provided value map. The entire original loop must be present in
+/// the value map. The cloned loop is returned.
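+///
+/// For example (hypothetical nest): cloning a loop L with children L1 and L2
+/// allocates clones L', L1', and L2', re-parents L1' and L2' under L', and
+/// fills in each clone's block list by mapping the original blocks through
+/// the value map.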
+static Loop *cloneLoopNest(Loop &OrigRootL, Loop *RootParentL, + const ValueToValueMapTy &VMap, LoopInfo &LI) { + auto AddClonedBlocksToLoop = [&](Loop &OrigL, Loop &ClonedL) { + assert(ClonedL.getBlocks().empty() && "Must start with an empty loop!"); + ClonedL.reserveBlocks(OrigL.getNumBlocks()); + for (auto *BB : OrigL.blocks()) { + auto *ClonedBB = cast<BasicBlock>(VMap.lookup(BB)); + ClonedL.addBlockEntry(ClonedBB); + if (LI.getLoopFor(BB) == &OrigL) { + assert(!LI.getLoopFor(ClonedBB) && + "Should not have an existing loop for this block!"); + LI.changeLoopFor(ClonedBB, &ClonedL); + } + } + }; + + // We specially handle the first loop because it may get cloned into + // a different parent and because we most commonly are cloning leaf loops. + Loop *ClonedRootL = LI.AllocateLoop(); + if (RootParentL) + RootParentL->addChildLoop(ClonedRootL); + else + LI.addTopLevelLoop(ClonedRootL); + AddClonedBlocksToLoop(OrigRootL, *ClonedRootL); + + if (OrigRootL.empty()) + return ClonedRootL; + + // If we have a nest, we can quickly clone the entire loop nest using an + // iterative approach because it is a tree. We keep the cloned parent in the + // data structure to avoid repeatedly querying through a map to find it. + SmallVector<std::pair<Loop *, Loop *>, 16> LoopsToClone; + // Build up the loops to clone in reverse order as we'll clone them from the + // back. + for (Loop *ChildL : llvm::reverse(OrigRootL)) + LoopsToClone.push_back({ClonedRootL, ChildL}); + do { + Loop *ClonedParentL, *L; + std::tie(ClonedParentL, L) = LoopsToClone.pop_back_val(); + Loop *ClonedL = LI.AllocateLoop(); + ClonedParentL->addChildLoop(ClonedL); + AddClonedBlocksToLoop(*L, *ClonedL); + for (Loop *ChildL : llvm::reverse(*L)) + LoopsToClone.push_back({ClonedL, ChildL}); + } while (!LoopsToClone.empty()); + + return ClonedRootL; +} + +/// Build the cloned loops of an original loop from unswitching. +/// +/// Because unswitching simplifies the CFG of the loop, this isn't a trivial +/// operation. We need to re-verify that there even is a loop (as the backedge +/// may not have been cloned), and even if there are remaining backedges the +/// backedge set may be different. However, we know that each child loop is +/// undisturbed, we only need to find where to place each child loop within +/// either any parent loop or within a cloned version of the original loop. +/// +/// Because child loops may end up cloned outside of any cloned version of the +/// original loop, multiple cloned sibling loops may be created. All of them +/// are returned so that the newly introduced loop nest roots can be +/// identified. +static Loop *buildClonedLoops(Loop &OrigL, ArrayRef<BasicBlock *> ExitBlocks, + const ValueToValueMapTy &VMap, LoopInfo &LI, + SmallVectorImpl<Loop *> &NonChildClonedLoops) { + Loop *ClonedL = nullptr; + + auto *OrigPH = OrigL.getLoopPreheader(); + auto *OrigHeader = OrigL.getHeader(); + + auto *ClonedPH = cast<BasicBlock>(VMap.lookup(OrigPH)); + auto *ClonedHeader = cast<BasicBlock>(VMap.lookup(OrigHeader)); + + // We need to know the loops of the cloned exit blocks to even compute the + // accurate parent loop. If we only clone exits to some parent of the + // original parent, we want to clone into that outer loop. We also keep track + // of the loops that our cloned exit blocks participate in. 
+ Loop *ParentL = nullptr; + SmallVector<BasicBlock *, 4> ClonedExitsInLoops; + SmallDenseMap<BasicBlock *, Loop *, 16> ExitLoopMap; + ClonedExitsInLoops.reserve(ExitBlocks.size()); + for (auto *ExitBB : ExitBlocks) + if (auto *ClonedExitBB = cast_or_null<BasicBlock>(VMap.lookup(ExitBB))) + if (Loop *ExitL = LI.getLoopFor(ExitBB)) { + ExitLoopMap[ClonedExitBB] = ExitL; + ClonedExitsInLoops.push_back(ClonedExitBB); + if (!ParentL || (ParentL != ExitL && ParentL->contains(ExitL))) + ParentL = ExitL; + } + assert((!ParentL || ParentL == OrigL.getParentLoop() || + ParentL->contains(OrigL.getParentLoop())) && + "The computed parent loop should always contain (or be) the parent of " + "the original loop."); + + // We build the set of blocks dominated by the cloned header from the set of + // cloned blocks out of the original loop. While not all of these will + // necessarily be in the cloned loop, it is enough to establish that they + // aren't in unreachable cycles, etc. + SmallSetVector<BasicBlock *, 16> ClonedLoopBlocks; + for (auto *BB : OrigL.blocks()) + if (auto *ClonedBB = cast_or_null<BasicBlock>(VMap.lookup(BB))) + ClonedLoopBlocks.insert(ClonedBB); + + // Rebuild the set of blocks that will end up in the cloned loop. We may have + // skipped cloning some region of this loop which can in turn skip some of + // the backedges so we have to rebuild the blocks in the loop based on the + // backedges that remain after cloning. + SmallVector<BasicBlock *, 16> Worklist; + SmallPtrSet<BasicBlock *, 16> BlocksInClonedLoop; + for (auto *Pred : predecessors(ClonedHeader)) { + // The only possible non-loop header predecessor is the preheader because + // we know we cloned the loop in simplified form. + if (Pred == ClonedPH) + continue; + + // Because the loop was in simplified form, the only non-loop predecessor + // should be the preheader. + assert(ClonedLoopBlocks.count(Pred) && "Found a predecessor of the loop " + "header other than the preheader " + "that is not part of the loop!"); + + // Insert this block into the loop set and on the first visit (and if it + // isn't the header we're currently walking) put it into the worklist to + // recurse through. + if (BlocksInClonedLoop.insert(Pred).second && Pred != ClonedHeader) + Worklist.push_back(Pred); + } + + // If we had any backedges then there *is* a cloned loop. Put the header into + // the loop set and then walk the worklist backwards to find all the blocks + // that remain within the loop after cloning. + if (!BlocksInClonedLoop.empty()) { + BlocksInClonedLoop.insert(ClonedHeader); + + while (!Worklist.empty()) { + BasicBlock *BB = Worklist.pop_back_val(); + assert(BlocksInClonedLoop.count(BB) && + "Didn't put block into the loop set!"); + + // Insert any predecessors that are in the possible set into the cloned + // set, and if the insert is successful, add them to the worklist. Note + // that we filter on the blocks that are definitely reachable via the + // backedge to the loop header so we may prune out dead code within the + // cloned loop. + for (auto *Pred : predecessors(BB)) + if (ClonedLoopBlocks.count(Pred) && + BlocksInClonedLoop.insert(Pred).second) + Worklist.push_back(Pred); + } + + ClonedL = LI.AllocateLoop(); + if (ParentL) { + ParentL->addBasicBlockToLoop(ClonedPH, LI); + ParentL->addChildLoop(ClonedL); + } else { + LI.addTopLevelLoop(ClonedL); + } + + ClonedL->reserveBlocks(BlocksInClonedLoop.size()); + // We don't want to just add the cloned loop blocks based on how we + // discovered them. 
+    // The original order of blocks was carefully built in a way that doesn't
+    // rely on predecessor ordering. Rather than re-invent that logic, we just
+    // re-walk the original blocks (and those of the child loops) and filter
+    // them as we add them into the cloned loop.
+    for (auto *BB : OrigL.blocks()) {
+      auto *ClonedBB = cast_or_null<BasicBlock>(VMap.lookup(BB));
+      if (!ClonedBB || !BlocksInClonedLoop.count(ClonedBB))
+        continue;
+
+      // Directly add the blocks that are only in this loop.
+      if (LI.getLoopFor(BB) == &OrigL) {
+        ClonedL->addBasicBlockToLoop(ClonedBB, LI);
+        continue;
+      }
+
+      // We want to manually add it to this loop and its parents.
+      // Registering it with LoopInfo will happen when we clone the top
+      // loop for this block.
+      for (Loop *PL = ClonedL; PL; PL = PL->getParentLoop())
+        PL->addBlockEntry(ClonedBB);
+    }
+
+    // Now add each child loop whose header remains within the cloned loop. All
+    // of the blocks within the loop must satisfy the same constraints as the
+    // header so once we pass the header checks we can just clone the entire
+    // child loop nest.
+    for (Loop *ChildL : OrigL) {
+      auto *ClonedChildHeader =
+          cast_or_null<BasicBlock>(VMap.lookup(ChildL->getHeader()));
+      if (!ClonedChildHeader || !BlocksInClonedLoop.count(ClonedChildHeader))
+        continue;
+
+#ifndef NDEBUG
+      // We should never have a cloned child loop header but fail to have
+      // all of the blocks for that child loop.
+      for (auto *ChildLoopBB : ChildL->blocks())
+        assert(BlocksInClonedLoop.count(
+                   cast<BasicBlock>(VMap.lookup(ChildLoopBB))) &&
+               "Child cloned loop has a header within the cloned outer "
+               "loop but not all of its blocks!");
+#endif
+
+      cloneLoopNest(*ChildL, ClonedL, VMap, LI);
+    }
+  }
+
+  // Now that we've handled all the components of the original loop that were
+  // cloned into a new loop, we still need to handle anything from the original
+  // loop that wasn't in a cloned loop.
+
+  // Figure out what blocks are left to place within any loop nest containing
+  // the unswitched loop. If we never formed a loop, the cloned PH is one of
+  // them.
+  SmallPtrSet<BasicBlock *, 16> UnloopedBlockSet;
+  if (BlocksInClonedLoop.empty())
+    UnloopedBlockSet.insert(ClonedPH);
+  for (auto *ClonedBB : ClonedLoopBlocks)
+    if (!BlocksInClonedLoop.count(ClonedBB))
+      UnloopedBlockSet.insert(ClonedBB);
+
+  // Copy the cloned exits and sort them in ascending loop depth; we'll work
+  // backwards across these to process them inside out. The order shouldn't
+  // matter as we're just trying to build up the map from inside-out; we use
+  // the map in a more stably ordered way below.
+  auto OrderedClonedExitsInLoops = ClonedExitsInLoops;
+  std::sort(OrderedClonedExitsInLoops.begin(), OrderedClonedExitsInLoops.end(),
+            [&](BasicBlock *LHS, BasicBlock *RHS) {
+              return ExitLoopMap.lookup(LHS)->getLoopDepth() <
+                     ExitLoopMap.lookup(RHS)->getLoopDepth();
+            });
+
+  // Populate the existing ExitLoopMap with everything reachable from each
+  // exit, starting from the innermost exit.
+  while (!UnloopedBlockSet.empty() && !OrderedClonedExitsInLoops.empty()) {
+    assert(Worklist.empty() && "Didn't clear worklist!");
+
+    BasicBlock *ExitBB = OrderedClonedExitsInLoops.pop_back_val();
+    Loop *ExitL = ExitLoopMap.lookup(ExitBB);
+
+    // Walk the CFG back until we hit the cloned PH adding everything reachable
+    // and in the unlooped set to this exit block's loop.
+    Worklist.push_back(ExitBB);
+    do {
+      BasicBlock *BB = Worklist.pop_back_val();
+      // We can stop recursing at the cloned preheader (if we get there).
+      if (BB == ClonedPH)
+        continue;
+
+      for (BasicBlock *PredBB : predecessors(BB)) {
+        // If this pred has already been moved to our set or is part of some
+        // (inner) loop, no update needed.
+        if (!UnloopedBlockSet.erase(PredBB)) {
+          assert(
+              (BlocksInClonedLoop.count(PredBB) || ExitLoopMap.count(PredBB)) &&
+              "Predecessor not mapped to a loop!");
+          continue;
+        }
+
+        // We just insert into the loop set here. We'll add these blocks to the
+        // exit loop after we build up the set in an order that doesn't rely on
+        // predecessor order (which in turn relies on use list order).
+        bool Inserted = ExitLoopMap.insert({PredBB, ExitL}).second;
+        (void)Inserted;
+        assert(Inserted && "Should only visit an unlooped block once!");
+
+        // And recurse through to its predecessors.
+        Worklist.push_back(PredBB);
+      }
+    } while (!Worklist.empty());
+  }
+
+  // Now that the ExitLoopMap gives us a mapping for all the non-looping cloned
+  // blocks to their outer loops, walk the cloned blocks and the cloned exits
+  // in their original order adding them to the correct loop.
+
+  // We need a stable insertion order. We use the original loop order and map
+  // into the correct parent loop.
+  for (auto *BB : llvm::concat<BasicBlock *const>(
+           makeArrayRef(ClonedPH), ClonedLoopBlocks, ClonedExitsInLoops))
+    if (Loop *OuterL = ExitLoopMap.lookup(BB))
+      OuterL->addBasicBlockToLoop(BB, LI);
+
+#ifndef NDEBUG
+  for (auto &BBAndL : ExitLoopMap) {
+    auto *BB = BBAndL.first;
+    auto *OuterL = BBAndL.second;
+    assert(LI.getLoopFor(BB) == OuterL &&
+           "Failed to put all blocks into outer loops!");
+  }
+#endif
+
+  // Now that all the blocks are placed into the correct containing loop in the
+  // absence of child loops, find all the potentially cloned child loops and
+  // clone them into whatever outer loop we placed their header into.
+  for (Loop *ChildL : OrigL) {
+    auto *ClonedChildHeader =
+        cast_or_null<BasicBlock>(VMap.lookup(ChildL->getHeader()));
+    if (!ClonedChildHeader || BlocksInClonedLoop.count(ClonedChildHeader))
+      continue;
+
+#ifndef NDEBUG
+    for (auto *ChildLoopBB : ChildL->blocks())
+      assert(VMap.count(ChildLoopBB) &&
+             "Cloned a child loop header but not all of that loop's blocks!");
+#endif
+
+    NonChildClonedLoops.push_back(cloneLoopNest(
+        *ChildL, ExitLoopMap.lookup(ClonedChildHeader), VMap, LI));
+  }
+
+  // Return the main cloned loop if any.
+  return ClonedL;
+}
+
+static void deleteDeadBlocksFromLoop(Loop &L, BasicBlock *DeadSubtreeRoot,
+                                     SmallVectorImpl<BasicBlock *> &ExitBlocks,
+                                     DominatorTree &DT, LoopInfo &LI) {
+  // Walk the dominator tree to build up the set of blocks we will delete here.
+  // The order is designed to allow us to always delete bottom-up and avoid any
+  // dangling uses.
+  SmallSetVector<BasicBlock *, 16> DeadBlocks;
+  DeadBlocks.insert(DeadSubtreeRoot);
+  for (int i = 0; i < (int)DeadBlocks.size(); ++i)
+    for (DomTreeNode *ChildN : *DT[DeadBlocks[i]]) {
+      // FIXME: This assert should pass and that means we don't change nearly
+      // as much below! Consider rewriting all of this to avoid deleting
+      // blocks. They are always cloned before being deleted, and so instead
+      // could just be moved.
+      // FIXME: This in turn means that we might actually be more able to
+      // update the domtree.
+      assert((L.contains(ChildN->getBlock()) ||
+              llvm::find(ExitBlocks, ChildN->getBlock()) != ExitBlocks.end()) &&
+             "Should never reach beyond the loop and exits when deleting!");
+      DeadBlocks.insert(ChildN->getBlock());
+    }
+
+  // Filter out the dead blocks from the exit blocks list so that it can be
+  // used in the caller.
+  llvm::erase_if(ExitBlocks,
+                 [&](BasicBlock *BB) { return DeadBlocks.count(BB); });
+
+  // Remove these blocks from their successors.
+  for (auto *BB : DeadBlocks)
+    for (BasicBlock *SuccBB : successors(BB))
+      SuccBB->removePredecessor(BB, /*DontDeleteUselessPHIs*/ true);
+
+  // Walk from this loop up through its parents removing all of the dead
+  // blocks.
+  for (Loop *ParentL = &L; ParentL; ParentL = ParentL->getParentLoop()) {
+    for (auto *BB : DeadBlocks)
+      ParentL->getBlocksSet().erase(BB);
+    llvm::erase_if(ParentL->getBlocksVector(),
+                   [&](BasicBlock *BB) { return DeadBlocks.count(BB); });
+  }
+
+  // Now delete the dead child loops. This raw delete will clear them
+  // recursively.
+  llvm::erase_if(L.getSubLoopsVector(), [&](Loop *ChildL) {
+    if (!DeadBlocks.count(ChildL->getHeader()))
+      return false;
+
+    assert(llvm::all_of(ChildL->blocks(),
+                        [&](BasicBlock *ChildBB) {
+                          return DeadBlocks.count(ChildBB);
+                        }) &&
+           "If the child loop header is dead all blocks in the child loop must "
+           "be dead as well!");
+    LI.destroy(ChildL);
+    return true;
+  });
+
+  // Remove the mappings for the dead blocks.
+  for (auto *BB : DeadBlocks)
+    LI.changeLoopFor(BB, nullptr);
+
+  // Drop all the references from these blocks to others to handle cyclic
+  // references as we start deleting the blocks themselves.
+  for (auto *BB : DeadBlocks)
+    BB->dropAllReferences();
+
+  for (auto *BB : llvm::reverse(DeadBlocks)) {
+    DT.eraseNode(BB);
+    BB->eraseFromParent();
+  }
+}
+
+/// Recompute the set of blocks in a loop after unswitching.
+///
+/// This walks from the original header's predecessors to rebuild the loop. We
+/// take advantage of the fact that new blocks can't have been added, and so we
+/// filter by the original loop's blocks. This also handles potentially
+/// unreachable code that we don't want to explore but might be found examining
+/// the predecessors of the header.
+///
+/// If the original loop is no longer a loop, this will return an empty set. If
+/// it remains a loop, all the blocks within it will be added to the set
+/// (including those blocks in inner loops).
+static SmallPtrSet<const BasicBlock *, 16> recomputeLoopBlockSet(Loop &L,
+                                                                 LoopInfo &LI) {
+  SmallPtrSet<const BasicBlock *, 16> LoopBlockSet;
+
+  auto *PH = L.getLoopPreheader();
+  auto *Header = L.getHeader();
+
+  // A worklist to use while walking backwards from the header.
+  SmallVector<BasicBlock *, 16> Worklist;
+
+  // First walk the predecessors of the header to find the backedges. This will
+  // form the basis of our walk.
+  for (auto *Pred : predecessors(Header)) {
+    // Skip the preheader.
+    if (Pred == PH)
+      continue;
+
+    // Because the loop was in simplified form, the only non-loop predecessor
+    // is the preheader.
+    assert(L.contains(Pred) && "Found a predecessor of the loop header other "
+                               "than the preheader that is not part of the "
+                               "loop!");
+
+    // Insert this block into the loop set on the first visit and, if it isn't
+    // the header we're currently walking, put it into the worklist to recurse
+    // through.
+    if (LoopBlockSet.insert(Pred).second && Pred != Header)
+      Worklist.push_back(Pred);
+  }
+
+  // If no backedges were found, we're done.
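+  // (With no backedges left, the header no longer closes a cycle, so the
+  // original loop has dissolved entirely; the caller interprets the empty set
+  // as "no longer a loop".)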
+ if (LoopBlockSet.empty()) + return LoopBlockSet; + + // Add the loop header to the set. + LoopBlockSet.insert(Header); + + // We found backedges, recurse through them to identify the loop blocks. + while (!Worklist.empty()) { + BasicBlock *BB = Worklist.pop_back_val(); + assert(LoopBlockSet.count(BB) && "Didn't put block into the loop set!"); + + // Because we know the inner loop structure remains valid we can use the + // loop structure to jump immediately across the entire nested loop. + // Further, because it is in loop simplified form, we can directly jump + // to its preheader afterward. + if (Loop *InnerL = LI.getLoopFor(BB)) + if (InnerL != &L) { + assert(L.contains(InnerL) && + "Should not reach a loop *outside* this loop!"); + // The preheader is the only possible predecessor of the loop so + // insert it into the set and check whether it was already handled. + auto *InnerPH = InnerL->getLoopPreheader(); + assert(L.contains(InnerPH) && "Cannot contain an inner loop block " + "but not contain the inner loop " + "preheader!"); + if (!LoopBlockSet.insert(InnerPH).second) + // The only way to reach the preheader is through the loop body + // itself so if it has been visited the loop is already handled. + continue; + + // Insert all of the blocks (other than those already present) into + // the loop set. The only block we expect to already be in the set is + // the one we used to find this loop as we immediately handle the + // others the first time we encounter the loop. + for (auto *InnerBB : InnerL->blocks()) { + if (InnerBB == BB) { + assert(LoopBlockSet.count(InnerBB) && + "Block should already be in the set!"); + continue; + } + + bool Inserted = LoopBlockSet.insert(InnerBB).second; + (void)Inserted; + assert(Inserted && "Should only insert an inner loop once!"); + } + + // Add the preheader to the worklist so we will continue past the + // loop body. + Worklist.push_back(InnerPH); + continue; + } + + // Insert any predecessors that were in the original loop into the new + // set, and if the insert is successful, add them to the worklist. + for (auto *Pred : predecessors(BB)) + if (L.contains(Pred) && LoopBlockSet.insert(Pred).second) + Worklist.push_back(Pred); + } + + // We've found all the blocks participating in the loop, return our completed + // set. + return LoopBlockSet; +} + +/// Rebuild a loop after unswitching removes some subset of blocks and edges. +/// +/// The removal may have removed some child loops entirely but cannot have +/// disturbed any remaining child loops. However, they may need to be hoisted +/// to the parent loop (or to be top-level loops). The original loop may be +/// completely removed. +/// +/// The sibling loops resulting from this update are returned. If the original +/// loop remains a valid loop, it will be the first entry in this list with all +/// of the newly sibling loops following it. +/// +/// Returns true if the loop remains a loop after unswitching, and false if it +/// is no longer a loop after unswitching (and should not continue to be +/// referenced). +static bool rebuildLoopAfterUnswitch(Loop &L, ArrayRef<BasicBlock *> ExitBlocks, + LoopInfo &LI, + SmallVectorImpl<Loop *> &HoistedLoops) { + auto *PH = L.getLoopPreheader(); + + // Compute the actual parent loop from the exit blocks. Because we may have + // pruned some exits the loop may be different from the original parent. 
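+  // For example (illustrative): if the remaining exits land in loops at
+  // depths 1 and 2, those loops necessarily nest, and the innermost (depth-2)
+  // one is chosen below; pruning exits can only move the loop *up* the nest,
+  // never deeper.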
+  Loop *ParentL = nullptr;
+  SmallVector<Loop *, 4> ExitLoops;
+  SmallVector<BasicBlock *, 4> ExitsInLoops;
+  ExitsInLoops.reserve(ExitBlocks.size());
+  for (auto *ExitBB : ExitBlocks)
+    if (Loop *ExitL = LI.getLoopFor(ExitBB)) {
+      ExitLoops.push_back(ExitL);
+      ExitsInLoops.push_back(ExitBB);
+      if (!ParentL || (ParentL != ExitL && ParentL->contains(ExitL)))
+        ParentL = ExitL;
+    }
+
+  // Recompute the blocks participating in this loop. This may be empty if it
+  // is no longer a loop.
+  auto LoopBlockSet = recomputeLoopBlockSet(L, LI);
+
+  // If we still have a loop, we need to re-set the loop's parent, as the
+  // change to the exit block set may have moved it within the loop nest. Note
+  // that this can only happen when this loop has a parent as it can only
+  // hoist the loop *up* the nest.
+  if (!LoopBlockSet.empty() && L.getParentLoop() != ParentL) {
+    // Remove this loop's (original) blocks from all of the intervening loops.
+    for (Loop *IL = L.getParentLoop(); IL != ParentL;
+         IL = IL->getParentLoop()) {
+      IL->getBlocksSet().erase(PH);
+      for (auto *BB : L.blocks())
+        IL->getBlocksSet().erase(BB);
+      llvm::erase_if(IL->getBlocksVector(), [&](BasicBlock *BB) {
+        return BB == PH || L.contains(BB);
+      });
+    }
+
+    LI.changeLoopFor(PH, ParentL);
+    L.getParentLoop()->removeChildLoop(&L);
+    if (ParentL)
+      ParentL->addChildLoop(&L);
+    else
+      LI.addTopLevelLoop(&L);
+  }
+
+  // Now we update all the blocks which are no longer within the loop.
+  auto &Blocks = L.getBlocksVector();
+  auto BlocksSplitI =
+      LoopBlockSet.empty()
+          ? Blocks.begin()
+          : std::stable_partition(
+                Blocks.begin(), Blocks.end(),
+                [&](BasicBlock *BB) { return LoopBlockSet.count(BB); });
+
+  // Before we erase the list of unlooped blocks, build a set of them.
+  SmallPtrSet<BasicBlock *, 16> UnloopedBlocks(BlocksSplitI, Blocks.end());
+  if (LoopBlockSet.empty())
+    UnloopedBlocks.insert(PH);
+
+  // Now erase these blocks from the loop.
+  for (auto *BB : make_range(BlocksSplitI, Blocks.end()))
+    L.getBlocksSet().erase(BB);
+  Blocks.erase(BlocksSplitI, Blocks.end());
+
+  // Sort the exits in ascending loop depth; we'll work backwards across these
+  // to process them inside out.
+  std::stable_sort(ExitsInLoops.begin(), ExitsInLoops.end(),
+                   [&](BasicBlock *LHS, BasicBlock *RHS) {
+                     return LI.getLoopDepth(LHS) < LI.getLoopDepth(RHS);
+                   });
+
+  // We'll build up a set for each exit loop.
+  SmallPtrSet<BasicBlock *, 16> NewExitLoopBlocks;
+  Loop *PrevExitL = L.getParentLoop(); // The deepest possible exit loop.
+
+  auto RemoveUnloopedBlocksFromLoop =
+      [](Loop &L, SmallPtrSetImpl<BasicBlock *> &UnloopedBlocks) {
+        for (auto *BB : UnloopedBlocks)
+          L.getBlocksSet().erase(BB);
+        llvm::erase_if(L.getBlocksVector(), [&](BasicBlock *BB) {
+          return UnloopedBlocks.count(BB);
+        });
+      };
+
+  SmallVector<BasicBlock *, 16> Worklist;
+  while (!UnloopedBlocks.empty() && !ExitsInLoops.empty()) {
+    assert(Worklist.empty() && "Didn't clear worklist!");
+    assert(NewExitLoopBlocks.empty() && "Didn't clear loop set!");
+
+    // Grab the next exit block, in decreasing loop depth order.
+    BasicBlock *ExitBB = ExitsInLoops.pop_back_val();
+    Loop &ExitL = *LI.getLoopFor(ExitBB);
+    assert(ExitL.contains(&L) && "Exit loop must contain the inner loop!");
+
+    // Erase all of the unlooped blocks from the loops between the previous
+    // exit loop and this exit loop. This works because the ExitsInLoops list
+    // is sorted in increasing order of loop depth and thus we visit loops in
+    // decreasing order of loop depth.
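+    // For example (illustrative): with exits in loops at depths 3, 2, and 1,
+    // the depth-3 exit is popped first; once it is handled, any unlooped
+    // blocks it did not claim are erased from the depth-3 loop before the
+    // depth-2 exit is processed, and so on outward.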
+    for (; PrevExitL != &ExitL; PrevExitL = PrevExitL->getParentLoop())
+      RemoveUnloopedBlocksFromLoop(*PrevExitL, UnloopedBlocks);
+
+    // Walk the CFG back until we hit the preheader, adding everything
+    // reachable and in the unlooped set to this exit block's loop.
+    Worklist.push_back(ExitBB);
+    do {
+      BasicBlock *BB = Worklist.pop_back_val();
+      // We can stop recursing at the preheader (if we get there).
+      if (BB == PH)
+        continue;
+
+      for (BasicBlock *PredBB : predecessors(BB)) {
+        // If this pred has already been moved to our set or is part of some
+        // (inner) loop, no update needed.
+        if (!UnloopedBlocks.erase(PredBB)) {
+          assert((NewExitLoopBlocks.count(PredBB) ||
+                  ExitL.contains(LI.getLoopFor(PredBB))) &&
+                 "Predecessor not in a nested loop (or already visited)!");
+          continue;
+        }
+
+        // We just insert into the loop set here. We'll add these blocks to the
+        // exit loop after we build up the set in a deterministic order rather
+        // than the predecessor-influenced visit order.
+        bool Inserted = NewExitLoopBlocks.insert(PredBB).second;
+        (void)Inserted;
+        assert(Inserted && "Should only visit an unlooped block once!");
+
+        // And recurse through to its predecessors.
+        Worklist.push_back(PredBB);
+      }
+    } while (!Worklist.empty());
+
+    // If blocks in this exit loop were directly part of the original loop (as
+    // opposed to a child loop) update the map to point to this exit loop. This
+    // just updates a map and so the fact that the order is unstable is fine.
+    for (auto *BB : NewExitLoopBlocks)
+      if (Loop *BBL = LI.getLoopFor(BB))
+        if (BBL == &L || !L.contains(BBL))
+          LI.changeLoopFor(BB, &ExitL);
+
+    // We will remove the remaining unlooped blocks from this loop in the next
+    // iteration or below.
+    NewExitLoopBlocks.clear();
+  }
+
+  // Any remaining unlooped blocks are no longer part of any loop unless they
+  // are part of some child loop.
+  for (; PrevExitL; PrevExitL = PrevExitL->getParentLoop())
+    RemoveUnloopedBlocksFromLoop(*PrevExitL, UnloopedBlocks);
+  for (auto *BB : UnloopedBlocks)
+    if (Loop *BBL = LI.getLoopFor(BB))
+      if (BBL == &L || !L.contains(BBL))
+        LI.changeLoopFor(BB, nullptr);
+
+  // Sink all the child loops whose headers are no longer in the loop set to
+  // the parent (or to be top level loops). We reach into the loop and directly
+  // update its subloop vector to make this batch update efficient.
+  auto &SubLoops = L.getSubLoopsVector();
+  auto SubLoopsSplitI =
+      LoopBlockSet.empty()
+          ? SubLoops.begin()
+          : std::stable_partition(
+                SubLoops.begin(), SubLoops.end(), [&](Loop *SubL) {
+                  return LoopBlockSet.count(SubL->getHeader());
+                });
+  for (auto *HoistedL : make_range(SubLoopsSplitI, SubLoops.end())) {
+    HoistedLoops.push_back(HoistedL);
+    HoistedL->setParentLoop(nullptr);
+
+    // To compute the new parent of this hoisted loop we look at where we
+    // placed the preheader above. We can't look up the header itself because
+    // we retained the mapping from the header to the hoisted loop. But the
+    // preheader and header should have the exact same new parent computed
+    // based on the set of exit blocks from the original loop as the preheader
+    // is a predecessor of the header and so reached in the reverse walk. And
+    // because the loops were all in simplified form the preheader of the
+    // hoisted loop can't be part of some *other* loop.
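+    // For example (illustrative): if the walk above mapped the hoisted loop's
+    // preheader into the depth-1 exit loop, the hoisted loop becomes a child
+    // of that loop below; if the preheader ended up in no loop at all, the
+    // hoisted loop becomes a top-level loop.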
+    if (auto *NewParentL = LI.getLoopFor(HoistedL->getLoopPreheader()))
+      NewParentL->addChildLoop(HoistedL);
+    else
+      LI.addTopLevelLoop(HoistedL);
+  }
+  SubLoops.erase(SubLoopsSplitI, SubLoops.end());
+
+  // Actually delete the loop if nothing remained within it.
+  if (Blocks.empty()) {
+    assert(SubLoops.empty() &&
+           "Failed to remove all subloops from the original loop!");
+    if (Loop *ParentL = L.getParentLoop())
+      ParentL->removeChildLoop(llvm::find(*ParentL, &L));
+    else
+      LI.removeLoop(llvm::find(LI, &L));
+    LI.destroy(&L);
+    return false;
+  }
+
+  return true;
+}
+
+/// Helper to visit a dominator subtree, invoking a callable on each node.
+///
+/// Returning false at any point will stop walking past that node of the tree.
+template <typename CallableT>
+void visitDomSubTree(DominatorTree &DT, BasicBlock *BB, CallableT Callable) {
+  SmallVector<DomTreeNode *, 4> DomWorklist;
+  DomWorklist.push_back(DT[BB]);
+#ifndef NDEBUG
+  SmallPtrSet<DomTreeNode *, 4> Visited;
+  Visited.insert(DT[BB]);
+#endif
+  do {
+    DomTreeNode *N = DomWorklist.pop_back_val();
+
+    // Visit this node.
+    if (!Callable(N->getBlock()))
+      continue;
+
+    // Accumulate the child nodes.
+    for (DomTreeNode *ChildN : *N) {
+      assert(Visited.insert(ChildN).second &&
+             "Cannot visit a node twice when walking a tree!");
+      DomWorklist.push_back(ChildN);
+    }
+  } while (!DomWorklist.empty());
+}
+
+/// Take an invariant branch that has been determined to be safe and worthwhile
+/// to unswitch, despite being non-trivial to do so, and perform the unswitch.
+///
+/// This directly updates the CFG to hoist the predicate out of the loop, and
+/// clone the necessary parts of the loop to maintain behavior.
+///
+/// It also updates both the dominator tree and LoopInfo based on the
+/// unswitching.
+///
+/// Once unswitching has been performed it runs the provided callback to report
+/// the new loops and no-longer valid loops to the caller.
+static bool unswitchInvariantBranch(
+    Loop &L, BranchInst &BI, DominatorTree &DT, LoopInfo &LI,
+    AssumptionCache &AC,
+    function_ref<void(bool, ArrayRef<Loop *>)> NonTrivialUnswitchCB) {
+  assert(BI.isConditional() && "Can only unswitch a conditional branch!");
+  assert(L.isLoopInvariant(BI.getCondition()) &&
+         "Can only unswitch an invariant branch condition!");
+
+  // A constant and BBs tracking the cloned and continuing successors.
+  const int ClonedSucc = 0;
+  auto *ParentBB = BI.getParent();
+  auto *UnswitchedSuccBB = BI.getSuccessor(ClonedSucc);
+  auto *ContinueSuccBB = BI.getSuccessor(1 - ClonedSucc);
+
+  assert(UnswitchedSuccBB != ContinueSuccBB &&
+         "Should not unswitch a branch that always goes to the same place!");
+
+  // The branch should be in this exact loop. Any inner loop's invariant branch
+  // should be handled by unswitching that inner loop. The caller of this
+  // routine should filter out any candidates that remain (but were skipped for
+  // whatever reason).
+  assert(LI.getLoopFor(ParentBB) == &L && "Branch in an inner loop!");
+
+  SmallVector<BasicBlock *, 4> ExitBlocks;
+  L.getUniqueExitBlocks(ExitBlocks);
+
+  // We cannot unswitch if exit blocks contain a cleanuppad instruction as we
+  // don't know how to split those exit blocks.
+  // FIXME: We should teach SplitBlock to handle this and remove this
+  // restriction.
+  for (auto *ExitBB : ExitBlocks)
+    if (isa<CleanupPadInst>(ExitBB->getFirstNonPHI()))
+      return false;
+
+  SmallPtrSet<BasicBlock *, 4> ExitBlockSet(ExitBlocks.begin(),
+                                            ExitBlocks.end());
+
+  // Compute the parent loop now before we start hacking on things.
+  Loop *ParentL = L.getParentLoop();
+
+  // Compute the outermost loop containing one of our exit blocks. This is the
+  // farthest up our loop nest that can be mutated, which we will use below to
+  // update things.
+  Loop *OuterExitL = &L;
+  for (auto *ExitBB : ExitBlocks) {
+    Loop *NewOuterExitL = LI.getLoopFor(ExitBB);
+    if (!NewOuterExitL) {
+      // We exited the entire nest with this block, so we're done.
+      OuterExitL = nullptr;
+      break;
+    }
+    if (NewOuterExitL != OuterExitL && NewOuterExitL->contains(OuterExitL))
+      OuterExitL = NewOuterExitL;
+  }
+
+  // If the edge we *aren't* cloning in the unswitch (the continuing edge)
+  // dominates its target, we can skip cloning the dominated region of the loop
+  // and its exits. We compute this as a set of nodes to be skipped.
+  SmallPtrSet<BasicBlock *, 4> SkippedLoopAndExitBlocks;
+  if (ContinueSuccBB->getUniquePredecessor() ||
+      llvm::all_of(predecessors(ContinueSuccBB), [&](BasicBlock *PredBB) {
+        return PredBB == ParentBB || DT.dominates(ContinueSuccBB, PredBB);
+      })) {
+    visitDomSubTree(DT, ContinueSuccBB, [&](BasicBlock *BB) {
+      SkippedLoopAndExitBlocks.insert(BB);
+      return true;
+    });
+  }
+  // Similarly, if the edge we *are* cloning in the unswitch (the unswitched
+  // edge) dominates its target, we will end up with dead nodes in the original
+  // loop and its exits that will need to be deleted. Here, we just record that
+  // the property holds and will compute the deleted set later.
+  bool DeleteUnswitchedSucc =
+      UnswitchedSuccBB->getUniquePredecessor() ||
+      llvm::all_of(predecessors(UnswitchedSuccBB), [&](BasicBlock *PredBB) {
+        return PredBB == ParentBB || DT.dominates(UnswitchedSuccBB, PredBB);
+      });
+
+  // Split the preheader, so that we know that there is a safe place to insert
+  // the conditional branch. We will change the preheader to have a conditional
+  // branch on LoopCond. The original preheader will become the split point
+  // between the unswitched versions, and we will have a new preheader for the
+  // original loop.
+  BasicBlock *SplitBB = L.getLoopPreheader();
+  BasicBlock *LoopPH = SplitEdge(SplitBB, L.getHeader(), &DT, &LI);
+
+  // Keep a mapping for the cloned values.
+  ValueToValueMapTy VMap;
+
+  // Build the cloned blocks from the loop.
+  auto *ClonedPH = buildClonedLoopBlocks(
+      L, LoopPH, SplitBB, ExitBlocks, ParentBB, UnswitchedSuccBB,
+      ContinueSuccBB, SkippedLoopAndExitBlocks, VMap, AC, DT, LI);
+
+  // Build the cloned loop structure itself. This may be substantially
+  // different from the original structure due to the simplified CFG. This also
+  // handles inserting all the cloned blocks into the correct loops.
+  SmallVector<Loop *, 4> NonChildClonedLoops;
+  Loop *ClonedL =
+      buildClonedLoops(L, ExitBlocks, VMap, LI, NonChildClonedLoops);
+
+  // Remove the parent as a predecessor of the unswitched successor.
+  UnswitchedSuccBB->removePredecessor(ParentBB, /*DontDeleteUselessPHIs*/ true);
+
+  // Now splice the branch from the original loop and use it to select between
+  // the two loops.
+  SplitBB->getTerminator()->eraseFromParent();
+  SplitBB->getInstList().splice(SplitBB->end(), ParentBB->getInstList(), BI);
+  BI.setSuccessor(ClonedSucc, ClonedPH);
+  BI.setSuccessor(1 - ClonedSucc, LoopPH);
+
+  // Create a new unconditional branch to the continuing block (as opposed to
+  // the one cloned).
+  BranchInst::Create(ContinueSuccBB, ParentBB);
+
+  // Delete anything that was made dead in the original loop due to
+  // unswitching.
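+  // For example (illustrative): if the unswitched successor was reached only
+  // through the branch we just hoisted, the original loop's copy of that
+  // entire dominated region is now unreachable; only the cloned loop keeps a
+  // live copy of it.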
+  if (DeleteUnswitchedSucc)
+    deleteDeadBlocksFromLoop(L, UnswitchedSuccBB, ExitBlocks, DT, LI);
+
+  SmallVector<Loop *, 4> HoistedLoops;
+  bool IsStillLoop = rebuildLoopAfterUnswitch(L, ExitBlocks, LI, HoistedLoops);
+
+  // This will have completely invalidated the dominator tree. We can't easily
+  // bound how much is invalid because in some cases we will refine the
+  // predecessor set of exit blocks of the loop which can move large unrelated
+  // regions of code into a new subtree.
+  //
+  // FIXME: Eventually, we should use an incremental update utility that
+  // leverages the existing information in the dominator tree (and potentially
+  // the nature of the change) to more efficiently update things.
+  DT.recalculate(*SplitBB->getParent());
+
+  // We can change which blocks are exit blocks of all the cloned sibling
+  // loops, the current loop, and any parent loops which shared exit blocks
+  // with the current loop. As a consequence, we need to re-form LCSSA for
+  // them. But we shouldn't need to re-form LCSSA for any child loops.
+  // FIXME: This could be made more efficient by tracking which exit blocks are
+  // new, and focusing on them, but that isn't likely to be necessary.
+  //
+  // In order to reasonably rebuild LCSSA we need to walk inside-out across the
+  // loop nest and update every loop that could have had its exits changed. We
+  // also need to cover any intervening loops. We add all of these loops to
+  // a list and sort them by loop depth to achieve this without updating
+  // unnecessary loops.
+  auto UpdateLCSSA = [&](Loop &UpdateL) {
+#ifndef NDEBUG
+    for (Loop *ChildL : UpdateL)
+      assert(ChildL->isRecursivelyLCSSAForm(DT, LI) &&
+             "Perturbed a child loop's LCSSA form!");
+#endif
+    formLCSSA(UpdateL, DT, &LI, nullptr);
+  };
+
+  // For non-child cloned loops and hoisted loops, we just need to update LCSSA
+  // and we can do it in any order as they don't nest relative to each other.
+  for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops))
+    UpdateLCSSA(*UpdatedL);
+
+  // If the original loop had exit blocks, walk up through the outermost loop
+  // of those exit blocks to update LCSSA and form updated dedicated exits.
+  if (OuterExitL != &L) {
+    SmallVector<Loop *, 4> OuterLoops;
+    // We start with the cloned loop and the current loop if they are loops and
+    // move toward OuterExitL. Also, if either the cloned loop or the current
+    // loop have become top level loops we need to walk all the way out.
+    if (ClonedL) {
+      OuterLoops.push_back(ClonedL);
+      if (!ClonedL->getParentLoop())
+        OuterExitL = nullptr;
+    }
+    if (IsStillLoop) {
+      OuterLoops.push_back(&L);
+      if (!L.getParentLoop())
+        OuterExitL = nullptr;
+    }
+    // Grab all of the enclosing loops now.
+    for (Loop *OuterL = ParentL; OuterL != OuterExitL;
+         OuterL = OuterL->getParentLoop())
+      OuterLoops.push_back(OuterL);
+
+    // Finally, update our list of outer loops. This is nicely ordered to work
+    // inside-out.
+    for (Loop *OuterL : OuterLoops) {
+      // First build LCSSA for this loop so that we can preserve it when
+      // forming dedicated exits. We don't want to perturb some other loop's
+      // LCSSA while doing that CFG edit.
+      UpdateLCSSA(*OuterL);
+
+      // For loops reached by this loop's original exit blocks we may have
+      // introduced new, non-dedicated exits. At least try to re-form dedicated
+      // exits for these loops. This may fail if they couldn't have dedicated
+      // exits to start with.
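+      // (A "dedicated" exit block is one whose predecessors are all inside
+      // the loop; the new edges from the cloned loop can break this property,
+      // which is why we try to re-establish it here.)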
+      formDedicatedExitBlocks(OuterL, &DT, &LI, /*PreserveLCSSA*/ true);
+    }
+  }
+
+#ifndef NDEBUG
+  // Verify the entire loop structure to catch any incorrect updates before we
+  // progress in the pass pipeline.
+  LI.verify(DT);
+#endif
+
+  // Now that we've unswitched something, make callbacks to report the changes.
+  // For that we need to merge together the updated loops and the cloned loops
+  // and check whether the original loop survived.
+  SmallVector<Loop *, 4> SibLoops;
+  for (Loop *UpdatedL : llvm::concat<Loop *>(NonChildClonedLoops, HoistedLoops))
+    if (UpdatedL->getParentLoop() == ParentL)
+      SibLoops.push_back(UpdatedL);
+  NonTrivialUnswitchCB(IsStillLoop, SibLoops);
+
+  ++NumBranches;
+  return true;
+}
+
+/// Recursively compute the cost of a dominator subtree based on the per-block
+/// cost map provided.
+///
+/// The recursive computation is memoized into the provided DT-indexed cost map
+/// to allow querying it for most nodes in the domtree without it becoming
+/// quadratic.
+static int
+computeDomSubtreeCost(DomTreeNode &N,
+                      const SmallDenseMap<BasicBlock *, int, 4> &BBCostMap,
+                      SmallDenseMap<DomTreeNode *, int, 4> &DTCostMap) {
+  // Don't accumulate cost (or recurse through) blocks not in our block cost
+  // map and thus not part of the duplication cost being considered.
+  auto BBCostIt = BBCostMap.find(N.getBlock());
+  if (BBCostIt == BBCostMap.end())
+    return 0;
+
+  // Look up this node to see if we already computed its cost.
+  auto DTCostIt = DTCostMap.find(&N);
+  if (DTCostIt != DTCostMap.end())
+    return DTCostIt->second;
+
+  // If not, we have to compute it. We can't use insert above and update
+  // because computing the cost may insert more things into the map.
+  int Cost = std::accumulate(
+      N.begin(), N.end(), BBCostIt->second, [&](int Sum, DomTreeNode *ChildN) {
+        return Sum + computeDomSubtreeCost(*ChildN, BBCostMap, DTCostMap);
+      });
+  bool Inserted = DTCostMap.insert({&N, Cost}).second;
+  (void)Inserted;
+  assert(Inserted && "Should not insert a node while visiting children!");
+  return Cost;
+}
+
+/// Unswitch control flow predicated on loop invariant conditions.
+///
+/// This first hoists all branches or switches which are trivial (i.e., do not
+/// require duplicating any part of the loop) out of the loop body. It then
+/// looks at other loop invariant control flows and tries to unswitch those as
+/// well by cloning the loop if the result is small enough.
+static bool
+unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
+             TargetTransformInfo &TTI, bool NonTrivial,
+             function_ref<void(bool, ArrayRef<Loop *>)> NonTrivialUnswitchCB) {
+  assert(L.isRecursivelyLCSSAForm(DT, LI) &&
+         "Loops must be in LCSSA form before unswitching.");
+  bool Changed = false;
+
+  // Must be in loop simplified form: we need a preheader and dedicated exits.
+  if (!L.isLoopSimplifyForm())
+    return false;
+
+  // Try trivial unswitching first, before looping over the other basic blocks
+  // in the loop.
+  Changed |= unswitchAllTrivialConditions(L, DT, LI);
+
+  // If we're not doing non-trivial unswitching, we're done. We both accept
+  // a parameter and check a local flag that can be used for testing and
+  // debugging.
+  if (!NonTrivial && !EnableNonTrivialUnswitch)
+    return Changed;
+
+  // Collect all remaining invariant branch conditions within this loop (as
+  // opposed to an inner loop which would be handled when visiting that inner
+  // loop).
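+  // For example (illustrative C-level sketch): in
+  //   for (...) { if (Flag) f(); else g(); }
+  // where Flag does not change inside the loop, the branch on Flag is a
+  // candidate; unswitching yields two specialized loop bodies selected by a
+  // single branch on Flag hoisted outside the loop.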
+  SmallVector<TerminatorInst *, 4> UnswitchCandidates;
+  for (auto *BB : L.blocks())
+    if (LI.getLoopFor(BB) == &L)
+      if (auto *BI = dyn_cast<BranchInst>(BB->getTerminator()))
+        if (BI->isConditional() && L.isLoopInvariant(BI->getCondition()) &&
+            BI->getSuccessor(0) != BI->getSuccessor(1))
+          UnswitchCandidates.push_back(BI);
+
+  // If we didn't find any candidates, we're done.
+  if (UnswitchCandidates.empty())
+    return Changed;
+
+  DEBUG(dbgs() << "Considering " << UnswitchCandidates.size()
+               << " non-trivial loop invariant conditions for unswitching.\n");
+
+  // Unswitching these terminators will require duplicating parts of the loop,
+  // so we need to be able to model that cost. Compute the ephemeral values and
+  // set up a data structure to hold per-BB costs. We cache each block's cost
+  // so that we don't recompute this when considering different subsets of the
+  // loop for duplication during unswitching.
+  SmallPtrSet<const Value *, 4> EphValues;
+  CodeMetrics::collectEphemeralValues(&L, &AC, EphValues);
+  SmallDenseMap<BasicBlock *, int, 4> BBCostMap;
+
+  // Compute the cost of each block, as well as the total loop cost. Also, bail
+  // out if we see instructions which are incompatible with loop unswitching
+  // (convergent, noduplicate, or cross-basic-block tokens).
+  // FIXME: We might be able to safely handle some of these in non-duplicated
+  // regions.
+  int LoopCost = 0;
+  for (auto *BB : L.blocks()) {
+    int Cost = 0;
+    for (auto &I : *BB) {
+      if (EphValues.count(&I))
+        continue;
+
+      if (I.getType()->isTokenTy() && I.isUsedOutsideOfBlock(BB))
+        return Changed;
+      if (auto CS = CallSite(&I))
+        if (CS.isConvergent() || CS.cannotDuplicate())
+          return Changed;
+
+      Cost += TTI.getUserCost(&I);
+    }
+    assert(Cost >= 0 && "Must not have negative costs!");
+    LoopCost += Cost;
+    assert(LoopCost >= 0 && "Must not have negative loop costs!");
+    BBCostMap[BB] = Cost;
+  }
+  DEBUG(dbgs() << "  Total loop cost: " << LoopCost << "\n");
+
+  // Now we find the best candidate by searching for the one with the following
+  // properties in order:
+  //
+  // 1) An unswitching cost below the threshold
+  // 2) The smallest number of duplicated unswitch candidates (to avoid
+  //    creating redundant subsequent unswitching)
+  // 3) The smallest cost after unswitching.
+  //
+  // We prioritize reducing fanout of unswitch candidates provided the cost
+  // remains below the threshold because this has a multiplicative effect.
+  //
+  // This requires memoizing each dominator subtree to avoid redundant work.
+  //
+  // FIXME: Need to actually do the number of candidates part above.
+  SmallDenseMap<DomTreeNode *, int, 4> DTCostMap;
+  // Given a terminator which might be unswitched, computes the non-duplicated
+  // cost for that terminator.
+  auto ComputeUnswitchedCost = [&](TerminatorInst *TI) {
+    BasicBlock &BB = *TI->getParent();
+    SmallPtrSet<BasicBlock *, 4> Visited;
+
+    int Cost = LoopCost;
+    for (BasicBlock *SuccBB : successors(&BB)) {
+      // Don't count successors more than once.
+      if (!Visited.insert(SuccBB).second)
+        continue;
+
+      // This successor's domtree will not need to be duplicated after
+      // unswitching if the edge to the successor dominates it (and thus the
+      // entire tree). This essentially means there is no other path into this
+      // subtree and so it will end up live in only one clone of the loop.
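+      // For example (illustrative numbers): with LoopCost == 100 and a
+      // successor whose dominated subtree costs 40, that subtree stays in a
+      // single clone, so Cost drops to 60; with two distinct successors the
+      // final estimate below is 60 * (2 - 1) == 60.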
+ if (SuccBB->getUniquePredecessor() || + llvm::all_of(predecessors(SuccBB), [&](BasicBlock *PredBB) { + return PredBB == &BB || DT.dominates(SuccBB, PredBB); + })) { + Cost -= computeDomSubtreeCost(*DT[SuccBB], BBCostMap, DTCostMap); + assert(Cost >= 0 && + "Non-duplicated cost should never exceed total loop cost!"); + } + } + + // Now scale the cost by the number of unique successors minus one. We + // subtract one because there is already at least one copy of the entire + // loop. This is computing the new cost of unswitching a condition. + assert(Visited.size() > 1 && + "Cannot unswitch a condition without multiple distinct successors!"); + return Cost * (Visited.size() - 1); + }; + TerminatorInst *BestUnswitchTI = nullptr; + int BestUnswitchCost; + for (TerminatorInst *CandidateTI : UnswitchCandidates) { + int CandidateCost = ComputeUnswitchedCost(CandidateTI); + DEBUG(dbgs() << " Computed cost of " << CandidateCost + << " for unswitch candidate: " << *CandidateTI << "\n"); + if (!BestUnswitchTI || CandidateCost < BestUnswitchCost) { + BestUnswitchTI = CandidateTI; + BestUnswitchCost = CandidateCost; + } + } + + if (BestUnswitchCost < UnswitchThreshold) { + DEBUG(dbgs() << " Trying to unswitch non-trivial (cost = " + << BestUnswitchCost << ") branch: " << *BestUnswitchTI + << "\n"); + Changed |= unswitchInvariantBranch(L, cast<BranchInst>(*BestUnswitchTI), DT, + LI, AC, NonTrivialUnswitchCB); + } else { + DEBUG(dbgs() << "Cannot unswitch, lowest cost found: " << BestUnswitchCost + << "\n"); + } + + return Changed; +} + +PreservedAnalyses SimpleLoopUnswitchPass::run(Loop &L, LoopAnalysisManager &AM, + LoopStandardAnalysisResults &AR, + LPMUpdater &U) { + Function &F = *L.getHeader()->getParent(); + (void)F; + + DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << L << "\n"); + + // Save the current loop name in a variable so that we can report it even + // after it has been deleted. + std::string LoopName = L.getName(); + + auto NonTrivialUnswitchCB = [&L, &U, &LoopName](bool CurrentLoopValid, + ArrayRef<Loop *> NewLoops) { + // If we did a non-trivial unswitch, we have added new (cloned) loops. + U.addSiblingLoops(NewLoops); + + // If the current loop remains valid, we should revisit it to catch any + // other unswitch opportunities. Otherwise, we need to mark it as deleted. + if (CurrentLoopValid) + U.revisitCurrentLoop(); + else + U.markLoopAsDeleted(L, LoopName); + }; + + if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, + NonTrivialUnswitchCB)) + return PreservedAnalyses::all(); + +#ifndef NDEBUG + // Historically this pass has had issues with the dominator tree so verify it + // in asserts builds. 
+ AR.DT.verifyDomTree(); +#endif + return getLoopPassPreservedAnalyses(); +} + +namespace { + +class SimpleLoopUnswitchLegacyPass : public LoopPass { + bool NonTrivial; + +public: + static char ID; // Pass ID, replacement for typeid + + explicit SimpleLoopUnswitchLegacyPass(bool NonTrivial = false) + : LoopPass(ID), NonTrivial(NonTrivial) { + initializeSimpleLoopUnswitchLegacyPassPass( + *PassRegistry::getPassRegistry()); + } + + bool runOnLoop(Loop *L, LPPassManager &LPM) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + getLoopAnalysisUsage(AU); + } +}; + +} // end anonymous namespace + +bool SimpleLoopUnswitchLegacyPass::runOnLoop(Loop *L, LPPassManager &LPM) { + if (skipLoop(L)) + return false; + + Function &F = *L->getHeader()->getParent(); + + DEBUG(dbgs() << "Unswitching loop in " << F.getName() << ": " << *L << "\n"); + + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + + auto NonTrivialUnswitchCB = [&L, &LPM](bool CurrentLoopValid, + ArrayRef<Loop *> NewLoops) { + // If we did a non-trivial unswitch, we have added new (cloned) loops. + for (auto *NewL : NewLoops) + LPM.addLoop(*NewL); + + // If the current loop remains valid, re-add it to the queue. This is + // a little wasteful as we'll finish processing the current loop as well, + // but it is the best we can do in the old PM. + if (CurrentLoopValid) + LPM.addLoop(*L); + else + LPM.markLoopAsDeleted(*L); + }; + + bool Changed = + unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, NonTrivialUnswitchCB); + + // If anything was unswitched, also clear any cached information about this + // loop. + LPM.deleteSimpleAnalysisLoop(L); + +#ifndef NDEBUG + // Historically this pass has had issues with the dominator tree so verify it + // in asserts builds. + DT.verifyDomTree(); +#endif + return Changed; +} + +char SimpleLoopUnswitchLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch", + "Simple unswitch loops", false, false) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(LoopPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(SimpleLoopUnswitchLegacyPass, "simple-loop-unswitch", + "Simple unswitch loops", false, false) + +Pass *llvm::createSimpleLoopUnswitchLegacyPass(bool NonTrivial) { + return new SimpleLoopUnswitchLegacyPass(NonTrivial); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp b/contrib/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp new file mode 100644 index 000000000000..1522170dc3b9 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/SimplifyCFGPass.cpp @@ -0,0 +1,296 @@ +//===- SimplifyCFGPass.cpp - CFG Simplification Pass ----------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. +// +//===----------------------------------------------------------------------===// +// +// This file implements dead code elimination and basic block merging, along +// with a collection of other peephole control flow optimizations. 
For example:
+//
+//   * Removes basic blocks with no predecessors.
+//   * Merges a basic block into its predecessor if there is only one and the
+//     predecessor only has one successor.
+//   * Eliminates PHI nodes for basic blocks with a single predecessor.
+//   * Eliminates a basic block that only contains an unconditional branch.
+//   * Changes invoke instructions to nounwind functions to be calls.
+//   * Changes things like "if (x) if (y)" into "if (x&y)".
+//   * etc.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AssumptionCache.h"
+#include "llvm/Analysis/CFG.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/IR/Attributes.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Pass.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Transforms/Scalar.h"
+#include "llvm/Transforms/Scalar/SimplifyCFG.h"
+#include "llvm/Transforms/Utils/Local.h"
+#include <utility>
+using namespace llvm;
+
+#define DEBUG_TYPE "simplifycfg"
+
+static cl::opt<unsigned> UserBonusInstThreshold(
+    "bonus-inst-threshold", cl::Hidden, cl::init(1),
+    cl::desc("Control the number of bonus instructions (default = 1)"));
+
+static cl::opt<bool> UserKeepLoops(
+    "keep-loops", cl::Hidden, cl::init(true),
+    cl::desc("Preserve canonical loop structure (default = true)"));
+
+static cl::opt<bool> UserSwitchToLookup(
+    "switch-to-lookup", cl::Hidden, cl::init(false),
+    cl::desc("Convert switches to lookup tables (default = false)"));
+
+static cl::opt<bool> UserForwardSwitchCond(
+    "forward-switch-cond", cl::Hidden, cl::init(false),
+    cl::desc("Forward switch condition to phi ops (default = false)"));
+
+static cl::opt<bool> UserSinkCommonInsts(
+    "sink-common-insts", cl::Hidden, cl::init(false),
+    cl::desc("Sink common instructions (default = false)"));
+
+STATISTIC(NumSimpl, "Number of blocks simplified");
+
+/// If we have more than one empty return block (empty except for a PHI node
+/// feeding the return), merge them together to promote recursive block
+/// merging.
+static bool mergeEmptyReturnBlocks(Function &F) {
+  bool Changed = false;
+
+  BasicBlock *RetBlock = nullptr;
+
+  // Scan all the blocks in the function, looking for empty return blocks.
+  for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; ) {
+    BasicBlock &BB = *BBI++;
+
+    // Only look at return blocks.
+    ReturnInst *Ret = dyn_cast<ReturnInst>(BB.getTerminator());
+    if (!Ret) continue;
+
+    // Only look at the block if it is empty or the only other thing in it is a
+    // single PHI node that is the operand to the return.
+    if (Ret != &BB.front()) {
+      // Check for something else in the block.
+      BasicBlock::iterator I(Ret);
+      --I;
+      // Skip over debug info.
+      while (isa<DbgInfoIntrinsic>(I) && I != BB.begin())
+        --I;
+      if (!isa<DbgInfoIntrinsic>(I) &&
+          (!isa<PHINode>(I) || I != BB.begin() || Ret->getNumOperands() == 0 ||
+           Ret->getOperand(0) != &*I))
+        continue;
+    }
+
+    // If this is the first returning block, remember it and keep going.
+    if (!RetBlock) {
+      RetBlock = &BB;
+      continue;
+    }
+
+    // Otherwise, we found a duplicate return block. Merge the two.
+    Changed = true;
+
+    // The case where there is no input to the return, or where the returned
+    // values agree, is trivial. Note that they can't agree if there are PHIs
+    // in the blocks.
+    if (Ret->getNumOperands() == 0 ||
+        Ret->getOperand(0) ==
+            cast<ReturnInst>(RetBlock->getTerminator())->getOperand(0)) {
+      BB.replaceAllUsesWith(RetBlock);
+      BB.eraseFromParent();
+      continue;
+    }
+
+    // If the canonical return block has no PHI node, create one now.
+    PHINode *RetBlockPHI = dyn_cast<PHINode>(RetBlock->begin());
+    if (!RetBlockPHI) {
+      Value *InVal = cast<ReturnInst>(RetBlock->getTerminator())->getOperand(0);
+      pred_iterator PB = pred_begin(RetBlock), PE = pred_end(RetBlock);
+      RetBlockPHI = PHINode::Create(Ret->getOperand(0)->getType(),
+                                    std::distance(PB, PE), "merge",
+                                    &RetBlock->front());
+
+      for (pred_iterator PI = PB; PI != PE; ++PI)
+        RetBlockPHI->addIncoming(InVal, *PI);
+      RetBlock->getTerminator()->setOperand(0, RetBlockPHI);
+    }
+
+    // Turn BB into a block that just unconditionally branches to the return
+    // block. This handles the case when the two return blocks have a common
+    // predecessor but return different things.
+    RetBlockPHI->addIncoming(Ret->getOperand(0), &BB);
+    BB.getTerminator()->eraseFromParent();
+    BranchInst::Create(RetBlock, &BB);
+  }
+
+  return Changed;
+}
+
+/// Call SimplifyCFG on all the blocks in the function,
+/// iterating until no more changes are made.
+static bool iterativelySimplifyCFG(Function &F, const TargetTransformInfo &TTI,
+                                   const SimplifyCFGOptions &Options) {
+  bool Changed = false;
+  bool LocalChange = true;
+
+  SmallVector<std::pair<const BasicBlock *, const BasicBlock *>, 32> Edges;
+  FindFunctionBackedges(F, Edges);
+  SmallPtrSet<BasicBlock *, 16> LoopHeaders;
+  for (unsigned i = 0, e = Edges.size(); i != e; ++i)
+    LoopHeaders.insert(const_cast<BasicBlock *>(Edges[i].second));
+
+  while (LocalChange) {
+    LocalChange = false;
+
+    // Loop over all of the basic blocks and remove them if they are unneeded.
+    for (Function::iterator BBIt = F.begin(); BBIt != F.end(); ) {
+      if (simplifyCFG(&*BBIt++, TTI, Options, &LoopHeaders)) {
+        LocalChange = true;
+        ++NumSimpl;
+      }
+    }
+    Changed |= LocalChange;
+  }
+  return Changed;
+}
+
+static bool simplifyFunctionCFG(Function &F, const TargetTransformInfo &TTI,
+                                const SimplifyCFGOptions &Options) {
+  bool EverChanged = removeUnreachableBlocks(F);
+  EverChanged |= mergeEmptyReturnBlocks(F);
+  EverChanged |= iterativelySimplifyCFG(F, TTI, Options);
+
+  // If neither pass changed anything, we're done.
+  if (!EverChanged) return false;
+
+  // iterativelySimplifyCFG can (rarely) make some loops dead. If this happens,
+  // removeUnreachableBlocks is needed to nuke them, which means we should
+  // iterate between the two optimizations. We structure the code like this to
+  // avoid rerunning iterativelySimplifyCFG if the second pass of
+  // removeUnreachableBlocks doesn't do anything.
+  if (!removeUnreachableBlocks(F))
+    return true;
+
+  do {
+    EverChanged = iterativelySimplifyCFG(F, TTI, Options);
+    EverChanged |= removeUnreachableBlocks(F);
+  } while (EverChanged);
+
+  return true;
+}
+
+// Command-line settings override compile-time settings.
+SimplifyCFGPass::SimplifyCFGPass(const SimplifyCFGOptions &Opts) {
+  Options.BonusInstThreshold = UserBonusInstThreshold.getNumOccurrences()
+                                   ? UserBonusInstThreshold
+                                   : Opts.BonusInstThreshold;
+  Options.ForwardSwitchCondToPhi = UserForwardSwitchCond.getNumOccurrences()
+                                       ? UserForwardSwitchCond
+                                       : Opts.ForwardSwitchCondToPhi;
+  Options.ConvertSwitchToLookupTable = UserSwitchToLookup.getNumOccurrences()
+                                           ? 
UserSwitchToLookup + : Opts.ConvertSwitchToLookupTable; + Options.NeedCanonicalLoop = UserKeepLoops.getNumOccurrences() + ? UserKeepLoops + : Opts.NeedCanonicalLoop; + Options.SinkCommonInsts = UserSinkCommonInsts.getNumOccurrences() + ? UserSinkCommonInsts + : Opts.SinkCommonInsts; +} + +PreservedAnalyses SimplifyCFGPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + Options.AC = &AM.getResult<AssumptionAnalysis>(F); + if (!simplifyFunctionCFG(F, TTI, Options)) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + return PA; +} + +namespace { +struct CFGSimplifyPass : public FunctionPass { + static char ID; + SimplifyCFGOptions Options; + std::function<bool(const Function &)> PredicateFtor; + + CFGSimplifyPass(unsigned Threshold = 1, bool ForwardSwitchCond = false, + bool ConvertSwitch = false, bool KeepLoops = true, + bool SinkCommon = false, + std::function<bool(const Function &)> Ftor = nullptr) + : FunctionPass(ID), PredicateFtor(std::move(Ftor)) { + + initializeCFGSimplifyPassPass(*PassRegistry::getPassRegistry()); + + // Check for command-line overrides of options for debug/customization. + Options.BonusInstThreshold = UserBonusInstThreshold.getNumOccurrences() + ? UserBonusInstThreshold + : Threshold; + + Options.ForwardSwitchCondToPhi = UserForwardSwitchCond.getNumOccurrences() + ? UserForwardSwitchCond + : ForwardSwitchCond; + + Options.ConvertSwitchToLookupTable = UserSwitchToLookup.getNumOccurrences() + ? UserSwitchToLookup + : ConvertSwitch; + + Options.NeedCanonicalLoop = + UserKeepLoops.getNumOccurrences() ? UserKeepLoops : KeepLoops; + + Options.SinkCommonInsts = UserSinkCommonInsts.getNumOccurrences() + ? UserSinkCommonInsts + : SinkCommon; + } + + bool runOnFunction(Function &F) override { + if (skipFunction(F) || (PredicateFtor && !PredicateFtor(F))) + return false; + + Options.AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F); + auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + return simplifyFunctionCFG(F, TTI, Options); + } + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<AssumptionCacheTracker>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + } +}; +} + +char CFGSimplifyPass::ID = 0; +INITIALIZE_PASS_BEGIN(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false, + false) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker) +INITIALIZE_PASS_END(CFGSimplifyPass, "simplifycfg", "Simplify the CFG", false, + false) + +// Public interface to the CFGSimplification pass +FunctionPass * +llvm::createCFGSimplificationPass(unsigned Threshold, bool ForwardSwitchCond, + bool ConvertSwitch, bool KeepLoops, + bool SinkCommon, + std::function<bool(const Function &)> Ftor) { + return new CFGSimplifyPass(Threshold, ForwardSwitchCond, ConvertSwitch, + KeepLoops, SinkCommon, std::move(Ftor)); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/Sink.cpp b/contrib/llvm/lib/Transforms/Scalar/Sink.cpp new file mode 100644 index 000000000000..cfb8a062299f --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/Sink.cpp @@ -0,0 +1,306 @@ +//===-- Sink.cpp - Code Sinking -------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details. 
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass moves instructions into successor blocks, when possible, so that
+// they aren't executed on paths where their results aren't needed.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/Sink.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/AliasAnalysis.h"
+#include "llvm/Analysis/LoopInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/CFG.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
+#include "llvm/Transforms/Scalar.h"
+using namespace llvm;
+
+#define DEBUG_TYPE "sink"
+
+STATISTIC(NumSunk, "Number of instructions sunk");
+STATISTIC(NumSinkIter, "Number of sinking iterations");
+
+/// AllUsesDominatedByBlock - Return true if all uses of the specified value
+/// occur in blocks dominated by the specified block.
+static bool AllUsesDominatedByBlock(Instruction *Inst, BasicBlock *BB,
+                                    DominatorTree &DT) {
+  // Ignoring debug uses is necessary so debug info doesn't affect the code.
+  // This may leave a referencing dbg_value in the original block, before
+  // the definition of the vreg. The DWARF generator handles this, although
+  // the user might not get the right info at runtime.
+  for (Use &U : Inst->uses()) {
+    // Determine the block of the use.
+    Instruction *UseInst = cast<Instruction>(U.getUser());
+    BasicBlock *UseBlock = UseInst->getParent();
+    if (PHINode *PN = dyn_cast<PHINode>(UseInst)) {
+      // PHI nodes use the operand in the predecessor block, not the block with
+      // the PHI.
+      unsigned Num = PHINode::getIncomingValueNumForOperand(U.getOperandNo());
+      UseBlock = PN->getIncomingBlock(Num);
+    }
+    // Check that it dominates.
+    if (!DT.dominates(BB, UseBlock))
+      return false;
+  }
+  return true;
+}
+
+static bool isSafeToMove(Instruction *Inst, AliasAnalysis &AA,
+                         SmallPtrSetImpl<Instruction *> &Stores) {
+
+  if (Inst->mayWriteToMemory()) {
+    Stores.insert(Inst);
+    return false;
+  }
+
+  if (LoadInst *L = dyn_cast<LoadInst>(Inst)) {
+    MemoryLocation Loc = MemoryLocation::get(L);
+    for (Instruction *S : Stores)
+      if (isModSet(AA.getModRefInfo(S, Loc)))
+        return false;
+  }
+
+  if (isa<TerminatorInst>(Inst) || isa<PHINode>(Inst) || Inst->isEHPad() ||
+      Inst->mayThrow())
+    return false;
+
+  if (auto CS = CallSite(Inst)) {
+    // Convergent operations cannot be made control-dependent on additional
+    // values.
+    if (CS.hasFnAttr(Attribute::Convergent))
+      return false;
+
+    for (Instruction *S : Stores)
+      if (isModSet(AA.getModRefInfo(S, CS)))
+        return false;
+  }
+
+  return true;
+}
+
+/// IsAcceptableTarget - Return true if it is possible to sink the instruction
+/// into the specified basic block.
+static bool IsAcceptableTarget(Instruction *Inst, BasicBlock *SuccToSinkTo,
+                               DominatorTree &DT, LoopInfo &LI) {
+  assert(Inst && "Instruction to be sunk is null");
+  assert(SuccToSinkTo && "Candidate sink target is null");
+
+  // It is not possible to sink an instruction into its own block. This can
+  // happen with loops.
+  if (Inst->getParent() == SuccToSinkTo)
+    return false;
+
+  // It's never legal to sink an instruction into a block which terminates in
+  // an EH-pad.
+  if (SuccToSinkTo->getTerminator()->isExceptional())
+    return false;
+
+  // If the block has multiple predecessors, this would introduce computation
+  // on different code paths. We could split the critical edge, but for now we
+  // just punt.
+  // FIXME: Split critical edges if not backedges.
+  if (SuccToSinkTo->getUniquePredecessor() != Inst->getParent()) {
+    // We cannot sink a load across a critical edge - there may be stores in
+    // other code paths.
+    if (isa<LoadInst>(Inst))
+      return false;
+
+    // We don't want to sink across a critical edge if we don't dominate the
+    // successor. We could be introducing calculations to new code paths.
+    if (!DT.dominates(Inst->getParent(), SuccToSinkTo))
+      return false;
+
+    // Don't sink instructions into a loop.
+    Loop *succ = LI.getLoopFor(SuccToSinkTo);
+    Loop *cur = LI.getLoopFor(Inst->getParent());
+    if (succ != nullptr && succ != cur)
+      return false;
+  }
+
+  // Finally, check that all the uses of the instruction are actually
+  // dominated by the candidate.
+  return AllUsesDominatedByBlock(Inst, SuccToSinkTo, DT);
+}
+
+/// SinkInstruction - Determine whether it is safe to sink the specified
+/// instruction out of its current block into a successor.
+static bool SinkInstruction(Instruction *Inst,
+                            SmallPtrSetImpl<Instruction *> &Stores,
+                            DominatorTree &DT, LoopInfo &LI, AAResults &AA) {
+
+  // Don't sink static alloca instructions. CodeGen assumes allocas outside the
+  // entry block are dynamically sized stack objects.
+  if (AllocaInst *AI = dyn_cast<AllocaInst>(Inst))
+    if (AI->isStaticAlloca())
+      return false;
+
+  // Check if it's safe to move the instruction.
+  if (!isSafeToMove(Inst, AA, Stores))
+    return false;
+
+  // FIXME: This should include support for sinking instructions within the
+  // block they are currently in to shorten the live ranges. We often get
+  // instructions sunk into the top of a large block, but it would be better to
+  // also sink them down before their first use in the block. This xform has to
+  // be careful not to *increase* register pressure though, e.g. sinking
+  // "x = y + z" down if it kills y and z would increase the live ranges of y
+  // and z and only shrink the live range of x.
+
+  // SuccToSinkTo - This is the successor to sink this instruction to, once we
+  // decide.
+  BasicBlock *SuccToSinkTo = nullptr;
+
+  // Instructions can only be sunk if all their uses are in blocks
+  // dominated by one of the successors.
+  // Look at all the dominated blocks and see if we can sink it in one.
+  DomTreeNode *DTN = DT.getNode(Inst->getParent());
+  for (DomTreeNode::iterator I = DTN->begin(), E = DTN->end();
+       I != E && SuccToSinkTo == nullptr; ++I) {
+    BasicBlock *Candidate = (*I)->getBlock();
+    // A node always immediately dominates its children in the dominator
+    // tree.
+    if (IsAcceptableTarget(Inst, Candidate, DT, LI))
+      SuccToSinkTo = Candidate;
+  }
+
+  // If no suitable dominated block was found, look at all the successors and
+  // decide which one we should sink to, if any.
+  for (succ_iterator I = succ_begin(Inst->getParent()),
+       E = succ_end(Inst->getParent()); I != E && !SuccToSinkTo; ++I) {
+    if (IsAcceptableTarget(Inst, *I, DT, LI))
+      SuccToSinkTo = *I;
+  }
+
+  // If we couldn't find a block to sink to, ignore this instruction.
+  if (!SuccToSinkTo)
+    return false;
+
+  DEBUG(dbgs() << "Sink" << *Inst << " (";
+        Inst->getParent()->printAsOperand(dbgs(), false);
+        dbgs() << " -> ";
+        SuccToSinkTo->printAsOperand(dbgs(), false);
+        dbgs() << ")\n");
+
+  // Move the instruction.
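+  // (getFirstInsertionPt skips any PHI nodes and EH pads at the top of the
+  // target block, so the moved instruction lands at a legal position.)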
+ Inst->moveBefore(&*SuccToSinkTo->getFirstInsertionPt()); + return true; +} + +static bool ProcessBlock(BasicBlock &BB, DominatorTree &DT, LoopInfo &LI, + AAResults &AA) { + // Can't sink anything out of a block that has less than two successors. + if (BB.getTerminator()->getNumSuccessors() <= 1) return false; + + // Don't bother sinking code out of unreachable blocks. In addition to being + // unprofitable, it can also lead to infinite looping, because in an + // unreachable loop there may be nowhere to stop. + if (!DT.isReachableFromEntry(&BB)) return false; + + bool MadeChange = false; + + // Walk the basic block bottom-up. Remember if we saw a store. + BasicBlock::iterator I = BB.end(); + --I; + bool ProcessedBegin = false; + SmallPtrSet<Instruction *, 8> Stores; + do { + Instruction *Inst = &*I; // The instruction to sink. + + // Predecrement I (if it's not begin) so that it isn't invalidated by + // sinking. + ProcessedBegin = I == BB.begin(); + if (!ProcessedBegin) + --I; + + if (isa<DbgInfoIntrinsic>(Inst)) + continue; + + if (SinkInstruction(Inst, Stores, DT, LI, AA)) { + ++NumSunk; + MadeChange = true; + } + + // If we just processed the first instruction in the block, we're done. + } while (!ProcessedBegin); + + return MadeChange; +} + +static bool iterativelySinkInstructions(Function &F, DominatorTree &DT, + LoopInfo &LI, AAResults &AA) { + bool MadeChange, EverMadeChange = false; + + do { + MadeChange = false; + DEBUG(dbgs() << "Sinking iteration " << NumSinkIter << "\n"); + // Process all basic blocks. + for (BasicBlock &I : F) + MadeChange |= ProcessBlock(I, DT, LI, AA); + EverMadeChange |= MadeChange; + NumSinkIter++; + } while (MadeChange); + + return EverMadeChange; +} + +PreservedAnalyses SinkingPass::run(Function &F, FunctionAnalysisManager &AM) { + auto &DT = AM.getResult<DominatorTreeAnalysis>(F); + auto &LI = AM.getResult<LoopAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + + if (!iterativelySinkInstructions(F, DT, LI, AA)) + return PreservedAnalyses::all(); + + PreservedAnalyses PA; + PA.preserveSet<CFGAnalyses>(); + return PA; +} + +namespace { + class SinkingLegacyPass : public FunctionPass { + public: + static char ID; // Pass identification + SinkingLegacyPass() : FunctionPass(ID) { + initializeSinkingLegacyPassPass(*PassRegistry::getPassRegistry()); + } + + bool runOnFunction(Function &F) override { + auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + + return iterativelySinkInstructions(F, DT, LI, AA); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesCFG(); + FunctionPass::getAnalysisUsage(AU); + AU.addRequired<AAResultsWrapperPass>(); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); + AU.addPreserved<LoopInfoWrapperPass>(); + } + }; +} // end anonymous namespace + +char SinkingLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(SinkingLegacyPass, "sink", "Code sinking", false, false) +INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) +INITIALIZE_PASS_END(SinkingLegacyPass, "sink", "Code sinking", false, false) + +FunctionPass *llvm::createSinkingPass() { return new SinkingLegacyPass(); } diff --git a/contrib/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp 
b/contrib/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp new file mode 100644 index 000000000000..23156d5a4d83 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/SpeculateAroundPHIs.cpp @@ -0,0 +1,811 @@
+//===- SpeculateAroundPHIs.cpp --------------------------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/SpeculateAroundPHIs.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/ADT/Sequence.h"
+#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/Statistic.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "spec-phis"
+
+STATISTIC(NumPHIsSpeculated, "Number of PHI nodes we speculated around");
+STATISTIC(NumEdgesSplit,
+          "Number of critical edges which were split for speculation");
+STATISTIC(NumSpeculatedInstructions,
+          "Number of instructions we speculated around the PHI nodes");
+STATISTIC(NumNewRedundantInstructions,
+          "Number of new, redundant instructions inserted");
+
+/// Check whether speculating the users of a PHI node around the PHI
+/// will be safe.
+///
+/// This checks both that all of the users are safe and also that all of their
+/// operands are either recursively safe or already available along an incoming
+/// edge to the PHI.
+///
+/// This routine caches both all the safe nodes explored in `PotentialSpecSet`
+/// and the chain of nodes that definitively reach any unsafe node in
+/// `UnsafeSet`. By preserving these between repeated calls to this routine for
+/// PHIs in the same basic block, the exploration here can be reused. However,
+/// these caches must not be reused for PHIs in a different basic block as they
+/// reflect what is available along incoming edges.
+static bool
+isSafeToSpeculatePHIUsers(PHINode &PN, DominatorTree &DT,
+                          SmallPtrSetImpl<Instruction *> &PotentialSpecSet,
+                          SmallPtrSetImpl<Instruction *> &UnsafeSet) {
+  auto *PhiBB = PN.getParent();
+  SmallPtrSet<Instruction *, 4> Visited;
+  SmallVector<std::pair<Instruction *, User::value_op_iterator>, 16> DFSStack;
+
+  // Walk each user of the PHI node.
+  for (Use &U : PN.uses()) {
+    auto *UI = cast<Instruction>(U.getUser());
+
+    // Ensure the use post-dominates the PHI node. This ensures that, in the
+    // absence of unwinding, the use will actually be reached.
+    // FIXME: We use a blunt hammer of requiring them to be in the same basic
+    // block. We should consider using actual post-dominance here in the
+    // future.
+    if (UI->getParent() != PhiBB) {
+      DEBUG(dbgs() << "  Unsafe: use in a different BB: " << *UI << "\n");
+      return false;
+    }
+
+    // FIXME: This check is much too conservative. We're not going to move these
+    // instructions onto new dynamic paths through the program unless there is
+    // a call instruction between the use and the PHI node. And memory isn't
+    // changing unless there is a store in that same sequence. We should
+    // probably change this to do at least a limited scan of the intervening
+    // instructions and allow handling stores in easily proven safe cases.
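+    // mayBeMemoryDependent() is true for any instruction that may read or
+    // write memory or that is otherwise unsafe to speculatively execute, so
+    // loads, stores, and most calls are all rejected here.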
+ if (mayBeMemoryDependent(*UI)) { + DEBUG(dbgs() << " Unsafe: can't speculate use: " << *UI << "\n"); + return false; + } + + // Now do a depth-first search of everything these users depend on to make + // sure they are transitively safe. This is a depth-first search, but we + // check nodes in preorder to minimize the amount of checking. + Visited.insert(UI); + DFSStack.push_back({UI, UI->value_op_begin()}); + do { + User::value_op_iterator OpIt; + std::tie(UI, OpIt) = DFSStack.pop_back_val(); + + while (OpIt != UI->value_op_end()) { + auto *OpI = dyn_cast<Instruction>(*OpIt); + // Increment to the next operand for whenever we continue. + ++OpIt; + // No need to visit non-instructions, which can't form dependencies. + if (!OpI) + continue; + + // Now do the main pre-order checks that this operand is a viable + // dependency of something we want to speculate. + + // First do a few checks for instructions that won't require + // speculation at all because they are trivially available on the + // incoming edge (either through dominance or through an incoming value + // to a PHI). + // + // The cases in the current block will be trivially dominated by the + // edge. + auto *ParentBB = OpI->getParent(); + if (ParentBB == PhiBB) { + if (isa<PHINode>(OpI)) { + // We can trivially map through phi nodes in the same block. + continue; + } + } else if (DT.dominates(ParentBB, PhiBB)) { + // Instructions from dominating blocks are already available. + continue; + } + + // Once we know that we're considering speculating the operand, check + // if we've already explored this subgraph and found it to be safe. + if (PotentialSpecSet.count(OpI)) + continue; + + // If we've already explored this subgraph and found it unsafe, bail. + // If when we directly test whether this is safe it fails, bail. + if (UnsafeSet.count(OpI) || ParentBB != PhiBB || + mayBeMemoryDependent(*OpI)) { + DEBUG(dbgs() << " Unsafe: can't speculate transitive use: " << *OpI + << "\n"); + // Record the stack of instructions which reach this node as unsafe + // so we prune subsequent searches. + UnsafeSet.insert(OpI); + for (auto &StackPair : DFSStack) { + Instruction *I = StackPair.first; + UnsafeSet.insert(I); + } + return false; + } + + // Skip any operands we're already recursively checking. + if (!Visited.insert(OpI).second) + continue; + + // Push onto the stack and descend. We can directly continue this + // loop when ascending. + DFSStack.push_back({UI, OpIt}); + UI = OpI; + OpIt = OpI->value_op_begin(); + } + + // This node and all its operands are safe. Go ahead and cache that for + // reuse later. + PotentialSpecSet.insert(UI); + + // Continue with the next node on the stack. + } while (!DFSStack.empty()); + } + +#ifndef NDEBUG + // Every visited operand should have been marked as safe for speculation at + // this point. Verify this and return success. + for (auto *I : Visited) + assert(PotentialSpecSet.count(I) && + "Failed to mark a visited instruction as safe!"); +#endif + return true; +} + +/// Check whether, in isolation, a given PHI node is both safe and profitable +/// to speculate users around. +/// +/// This handles checking whether there are any constant operands to a PHI +/// which could represent a useful speculation candidate, whether the users of +/// the PHI are safe to speculate including all their transitive dependencies, +/// and whether after speculation there will be some cost savings (profit) to +/// folding the operands into the users of the PHI node. 
+/// Returns true if both
+/// safe and profitable with relevant cost savings updated in the map and with
+/// an update to the `PotentialSpecSet`. Returns false if either safety or
+/// profitability are absent. Some new entries may be made to the
+/// `PotentialSpecSet` even when this routine returns false, but they remain
+/// conservatively correct.
+///
+/// The profitability check here is a local one, but it checks this in an
+/// interesting way. Beyond checking that the total cost of materializing the
+/// constants will be less than the cost of folding them into their users, it
+/// also checks that no one incoming constant will have a higher cost when
+/// folded into its users rather than materialized. This higher cost could
+/// result in a dynamic *path* that is more expensive even when the total cost
+/// is lower. Currently, all of the interesting cases where this optimization
+/// should fire are ones where it is a no-loss operation in this sense. If we
+/// ever want to be more aggressive here, we would need to balance the
+/// different incoming edges' cost by looking at their respective
+/// probabilities.
+static bool isSafeAndProfitableToSpeculateAroundPHI(
+    PHINode &PN, SmallDenseMap<PHINode *, int, 16> &CostSavingsMap,
+    SmallPtrSetImpl<Instruction *> &PotentialSpecSet,
+    SmallPtrSetImpl<Instruction *> &UnsafeSet, DominatorTree &DT,
+    TargetTransformInfo &TTI) {
+  // First see whether there is any cost savings to speculating around this
+  // PHI, and build up a map of the constant inputs to how many times they
+  // occur.
+  bool NonFreeMat = false;
+  struct CostsAndCount {
+    int MatCost = TargetTransformInfo::TCC_Free;
+    int FoldedCost = TargetTransformInfo::TCC_Free;
+    int Count = 0;
+  };
+  SmallDenseMap<ConstantInt *, CostsAndCount, 16> CostsAndCounts;
+  SmallPtrSet<BasicBlock *, 16> IncomingConstantBlocks;
+  for (int i : llvm::seq<int>(0, PN.getNumIncomingValues())) {
+    auto *IncomingC = dyn_cast<ConstantInt>(PN.getIncomingValue(i));
+    if (!IncomingC)
+      continue;
+
+    // Only visit each incoming edge with a constant input once.
+    if (!IncomingConstantBlocks.insert(PN.getIncomingBlock(i)).second)
+      continue;
+
+    auto InsertResult = CostsAndCounts.insert({IncomingC, {}});
+    // Count how many edges share a given incoming constant.
+    ++InsertResult.first->second.Count;
+    // Only compute the cost the first time we see a particular constant.
+    if (!InsertResult.second)
+      continue;
+
+    int &MatCost = InsertResult.first->second.MatCost;
+    MatCost = TTI.getIntImmCost(IncomingC->getValue(), IncomingC->getType());
+    NonFreeMat |= MatCost != TTI.TCC_Free;
+  }
+  if (!NonFreeMat) {
+    DEBUG(dbgs() << "  Free: " << PN << "\n");
+    // No profit in free materialization.
+    return false;
+  }
+
+  // Now check that the uses of this PHI can actually be speculated,
+  // otherwise we'll still have to materialize the PHI value.
+  if (!isSafeToSpeculatePHIUsers(PN, DT, PotentialSpecSet, UnsafeSet)) {
+    DEBUG(dbgs() << "  Unsafe PHI: " << PN << "\n");
+    return false;
+  }
+
+  // Compute how much (if any) savings are available by speculating around this
+  // PHI.
+  for (Use &U : PN.uses()) {
+    auto *UserI = cast<Instruction>(U.getUser());
+    // Now check whether there is any savings to folding the incoming constants
+    // into this use.
+    unsigned Idx = U.getOperandNo();
+
+    // If we have a binary operator that is commutative, an actual constant
+    // operand would end up on the RHS, so pretend the use of the PHI is on the
+    // RHS.
+    //
+    // Technically, this is a bit weird if *both* operands are PHIs we're
+    // speculating. But if that is the case, giving an "optimistic" cost isn't
+    // a bad thing because after speculation it will constant fold. And
+    // moreover, such cases should likely have been constant folded already by
+    // some other pass, so we shouldn't worry about "modeling" them terribly
+    // accurately here. Similarly, if the other operand is a constant, it still
+    // seems fine to be "optimistic" in our cost modeling, because when the
+    // incoming operand from the PHI node is also a constant, we will end up
+    // constant folding.
+    if (UserI->isBinaryOp() && UserI->isCommutative() && Idx != 1)
+      // Assume we will commute the constant to the RHS to be canonical.
+      Idx = 1;
+
+    // Get the intrinsic ID if this user is an intrinsic.
+    Intrinsic::ID IID = Intrinsic::not_intrinsic;
+    if (auto *UserII = dyn_cast<IntrinsicInst>(UserI))
+      IID = UserII->getIntrinsicID();
+
+    for (auto &IncomingConstantAndCostsAndCount : CostsAndCounts) {
+      ConstantInt *IncomingC = IncomingConstantAndCostsAndCount.first;
+      int MatCost = IncomingConstantAndCostsAndCount.second.MatCost;
+      int &FoldedCost = IncomingConstantAndCostsAndCount.second.FoldedCost;
+      if (IID)
+        FoldedCost += TTI.getIntImmCost(IID, Idx, IncomingC->getValue(),
+                                        IncomingC->getType());
+      else
+        FoldedCost +=
+            TTI.getIntImmCost(UserI->getOpcode(), Idx, IncomingC->getValue(),
+                              IncomingC->getType());
+
+      // If we accumulate more folded cost for this incoming constant than
+      // materialized cost, then we'll regress any edge with this constant so
+      // just bail. We're only interested in cases where folding the incoming
+      // constants is at least break-even on all paths.
+      if (FoldedCost > MatCost) {
+        DEBUG(dbgs() << "  Not profitable to fold imm: " << *IncomingC << "\n"
+                        "    Materializing cost:      " << MatCost << "\n"
+                        "    Accumulated folded cost: " << FoldedCost << "\n");
+        return false;
+      }
+    }
+  }
+
+  // Compute the total cost savings afforded by this PHI node.
+  int TotalMatCost = TTI.TCC_Free, TotalFoldedCost = TTI.TCC_Free;
+  for (auto IncomingConstantAndCostsAndCount : CostsAndCounts) {
+    int MatCost = IncomingConstantAndCostsAndCount.second.MatCost;
+    int FoldedCost = IncomingConstantAndCostsAndCount.second.FoldedCost;
+    int Count = IncomingConstantAndCostsAndCount.second.Count;
+
+    TotalMatCost += MatCost * Count;
+    TotalFoldedCost += FoldedCost * Count;
+  }
+  assert(TotalFoldedCost <= TotalMatCost && "If each constant's folded cost is "
+                                            "less than its materialized cost, "
+                                            "the sum must be as well.");
+
+  DEBUG(dbgs() << "  Cost savings " << (TotalMatCost - TotalFoldedCost)
+               << ": " << PN << "\n");
+  CostSavingsMap[&PN] = TotalMatCost - TotalFoldedCost;
+  return true;
+}
+
+/// Simple helper to walk all the users of a list of phis depth first, and call
+/// a visit function on each one in post-order.
+///
+/// All of the PHIs should be in the same basic block, and this is primarily
+/// used to make a single depth-first walk across their collective users
+/// without revisiting any subgraphs. Callers should provide a fast, idempotent
+/// callable to test whether a node has been visited and the more important
+/// callable to actually visit a particular node.
+///
+/// Depth-first and postorder here refer to the *operand* graph -- we start
+/// from a collection of users of PHI nodes and walk "up" the operands
+/// depth-first.
+template <typename IsVisitedT, typename VisitT> +static void visitPHIUsersAndDepsInPostOrder(ArrayRef<PHINode *> PNs, + IsVisitedT IsVisited, + VisitT Visit) { + SmallVector<std::pair<Instruction *, User::value_op_iterator>, 16> DFSStack; + for (auto *PN : PNs) + for (Use &U : PN->uses()) { + auto *UI = cast<Instruction>(U.getUser()); + if (IsVisited(UI)) + // Already visited this user, continue across the roots. + continue; + + // Otherwise, walk the operand graph depth-first and visit each + // dependency in postorder. + DFSStack.push_back({UI, UI->value_op_begin()}); + do { + User::value_op_iterator OpIt; + std::tie(UI, OpIt) = DFSStack.pop_back_val(); + while (OpIt != UI->value_op_end()) { + auto *OpI = dyn_cast<Instruction>(*OpIt); + // Increment to the next operand for whenever we continue. + ++OpIt; + // No need to visit non-instructions, which can't form dependencies, + // or instructions outside of our potential dependency set that we + // were given. Finally, if we've already visited the node, continue + // to the next. + if (!OpI || IsVisited(OpI)) + continue; + + // Push onto the stack and descend. We can directly continue this + // loop when ascending. + DFSStack.push_back({UI, OpIt}); + UI = OpI; + OpIt = OpI->value_op_begin(); + } + + // Finished visiting children, visit this node. + assert(!IsVisited(UI) && "Should not have already visited a node!"); + Visit(UI); + } while (!DFSStack.empty()); + } +} + +/// Find profitable PHIs to speculate. +/// +/// For a PHI node to be profitable, we need the cost of speculating its users +/// (and their dependencies) to not exceed the savings of folding the PHI's +/// constant operands into the speculated users. +/// +/// Computing this is surprisingly challenging. Because users of two different +/// PHI nodes can depend on each other or on common other instructions, it may +/// be profitable to speculate two PHI nodes together even though neither one +/// in isolation is profitable. The straightforward way to find all the +/// profitable PHIs would be to check each combination of PHIs' cost, but this +/// is exponential in complexity. +/// +/// Even if we assume that we only care about cases where we can consider each +/// PHI node in isolation (rather than considering cases where none are +/// profitable in isolation but some subset are profitable as a set), we still +/// have a challenge. The obvious way to find all individually profitable PHIs +/// is to iterate until reaching a fixed point, but this will be quadratic in +/// complexity. =/ +/// +/// This code currently uses a linear-to-compute order for a greedy approach. +/// It won't find cases where a set of PHIs must be considered together, but it +/// handles most cases of order dependence without quadratic iteration. The +/// specific order used is the post-order across the operand DAG. When the last +/// user of a PHI is visited in this postorder walk, we check it for +/// profitability. +/// +/// There is an orthogonal extra complexity to all of this: computing the cost +/// itself can easily become a linear computation making everything again (at +/// best) quadratic. Using a postorder over the operand graph makes it +/// particularly easy to avoid this through dynamic programming. As we do the +/// postorder walk, we build the transitive cost of that subgraph. It is also +/// straightforward to then update these costs when we mark a PHI for +/// speculation so that subsequent PHIs don't re-pay the cost of already +/// speculated instructions. 
+static SmallVector<PHINode *, 16> +findProfitablePHIs(ArrayRef<PHINode *> PNs, + const SmallDenseMap<PHINode *, int, 16> &CostSavingsMap, + const SmallPtrSetImpl<Instruction *> &PotentialSpecSet, + int NumPreds, DominatorTree &DT, TargetTransformInfo &TTI) { + SmallVector<PHINode *, 16> SpecPNs; + + // First, establish a reverse mapping from immediate users of the PHI nodes + // to the nodes themselves, and count how many users each PHI node has in + // a way we can update while processing them. + SmallDenseMap<Instruction *, TinyPtrVector<PHINode *>, 16> UserToPNMap; + SmallDenseMap<PHINode *, int, 16> PNUserCountMap; + SmallPtrSet<Instruction *, 16> UserSet; + for (auto *PN : PNs) { + assert(UserSet.empty() && "Must start with an empty user set!"); + for (Use &U : PN->uses()) + UserSet.insert(cast<Instruction>(U.getUser())); + PNUserCountMap[PN] = UserSet.size(); + for (auto *UI : UserSet) + UserToPNMap.insert({UI, {}}).first->second.push_back(PN); + UserSet.clear(); + } + + // Now do a DFS across the operand graph of the users, computing cost as we + // go and when all costs for a given PHI are known, checking that PHI for + // profitability. + SmallDenseMap<Instruction *, int, 16> SpecCostMap; + visitPHIUsersAndDepsInPostOrder( + PNs, + /*IsVisited*/ + [&](Instruction *I) { + // We consider anything that isn't potentially speculated to be + // "visited" as it is already handled. Similarly, anything that *is* + // potentially speculated but for which we have an entry in our cost + // map, we're done. + return !PotentialSpecSet.count(I) || SpecCostMap.count(I); + }, + /*Visit*/ + [&](Instruction *I) { + // We've fully visited the operands, so sum their cost with this node + // and update the cost map. + int Cost = TTI.TCC_Free; + for (Value *OpV : I->operand_values()) + if (auto *OpI = dyn_cast<Instruction>(OpV)) { + auto CostMapIt = SpecCostMap.find(OpI); + if (CostMapIt != SpecCostMap.end()) + Cost += CostMapIt->second; + } + Cost += TTI.getUserCost(I); + bool Inserted = SpecCostMap.insert({I, Cost}).second; + (void)Inserted; + assert(Inserted && "Must not re-insert a cost during the DFS!"); + + // Now check if this node had a corresponding PHI node using it. If so, + // we need to decrement the outstanding user count for it. + auto UserPNsIt = UserToPNMap.find(I); + if (UserPNsIt == UserToPNMap.end()) + return; + auto &UserPNs = UserPNsIt->second; + auto UserPNsSplitIt = std::stable_partition( + UserPNs.begin(), UserPNs.end(), [&](PHINode *UserPN) { + int &PNUserCount = PNUserCountMap.find(UserPN)->second; + assert( + PNUserCount > 0 && + "Should never re-visit a PN after its user count hits zero!"); + --PNUserCount; + return PNUserCount != 0; + }); + + // FIXME: Rather than one at a time, we should sum the savings as the + // cost will be completely shared. + SmallVector<Instruction *, 16> SpecWorklist; + for (auto *PN : llvm::make_range(UserPNsSplitIt, UserPNs.end())) { + int SpecCost = TTI.TCC_Free; + for (Use &U : PN->uses()) + SpecCost += + SpecCostMap.find(cast<Instruction>(U.getUser()))->second; + SpecCost *= (NumPreds - 1); + // When the user count of a PHI node hits zero, we should check its + // profitability. If profitable, we should mark it for speculation + // and zero out the cost of everything it depends on. 
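+          // For example (illustrative): with three incoming edges, each
+          // speculated instruction is cloned into all three predecessors
+          // while the one original is removed, so the marginal cost is
+          // SpecCost scaled by (3 - 1); that is what must not exceed the
+          // cost savings checked below.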
+          int CostSavings = CostSavingsMap.find(PN)->second;
+          if (SpecCost > CostSavings) {
+            DEBUG(dbgs() << "  Not profitable, speculation cost: " << *PN << "\n"
+                            "    Cost savings:     " << CostSavings << "\n"
+                            "    Speculation cost: " << SpecCost << "\n");
+            continue;
+          }
+
+          // We're going to speculate this user-associated PHI. Copy it out and
+          // add its users to the worklist to update their cost.
+          SpecPNs.push_back(PN);
+          for (Use &U : PN->uses()) {
+            auto *UI = cast<Instruction>(U.getUser());
+            auto CostMapIt = SpecCostMap.find(UI);
+            if (CostMapIt->second == 0)
+              continue;
+            // Zero out this cost entry to avoid duplicates.
+            CostMapIt->second = 0;
+            SpecWorklist.push_back(UI);
+          }
+        }
+
+        // Now walk all the operands of the users in the worklist transitively
+        // to zero out all the memoized costs.
+        while (!SpecWorklist.empty()) {
+          Instruction *SpecI = SpecWorklist.pop_back_val();
+          assert(SpecCostMap.find(SpecI)->second == 0 &&
+                 "Didn't zero out a cost!");
+
+          // Walk the operands recursively to zero out their cost as well.
+          for (auto *OpV : SpecI->operand_values()) {
+            auto *OpI = dyn_cast<Instruction>(OpV);
+            if (!OpI)
+              continue;
+            auto CostMapIt = SpecCostMap.find(OpI);
+            if (CostMapIt == SpecCostMap.end() || CostMapIt->second == 0)
+              continue;
+            CostMapIt->second = 0;
+            SpecWorklist.push_back(OpI);
+          }
+        }
+      });
+
+  return SpecPNs;
+}
+
+/// Speculate users around a set of PHI nodes.
+///
+/// This routine does the actual speculation around a set of PHI nodes where we
+/// have determined this to be both safe and profitable.
+///
+/// This routine handles any splitting of critical edges necessary to create
+/// a safe block to speculate into as well as cloning the instructions and
+/// rewriting all uses.
+static void speculatePHIs(ArrayRef<PHINode *> SpecPNs,
+                          SmallPtrSetImpl<Instruction *> &PotentialSpecSet,
+                          SmallSetVector<BasicBlock *, 16> &PredSet,
+                          DominatorTree &DT) {
+  DEBUG(dbgs() << "  Speculating around " << SpecPNs.size() << " PHIs!\n");
+  NumPHIsSpeculated += SpecPNs.size();
+
+  // Split any critical edges so that we have a block to hoist into.
+  auto *ParentBB = SpecPNs[0]->getParent();
+  SmallVector<BasicBlock *, 16> SpecPreds;
+  SpecPreds.reserve(PredSet.size());
+  for (auto *PredBB : PredSet) {
+    auto *NewPredBB = SplitCriticalEdge(
+        PredBB, ParentBB,
+        CriticalEdgeSplittingOptions(&DT).setMergeIdenticalEdges());
+    if (NewPredBB) {
+      ++NumEdgesSplit;
+      DEBUG(dbgs() << "  Split critical edge from: " << PredBB->getName()
+                   << "\n");
+      SpecPreds.push_back(NewPredBB);
+    } else {
+      assert(PredBB->getSingleSuccessor() == ParentBB &&
+             "We need a non-critical predecessor to speculate into.");
+      assert(!isa<InvokeInst>(PredBB->getTerminator()) &&
+             "Cannot have a non-critical invoke!");
+
+      // Already non-critical, use existing pred.
+      SpecPreds.push_back(PredBB);
+    }
+  }
+
+  SmallPtrSet<Instruction *, 16> SpecSet;
+  SmallVector<Instruction *, 16> SpecList;
+  visitPHIUsersAndDepsInPostOrder(SpecPNs,
+                                  /*IsVisited*/
+                                  [&](Instruction *I) {
+                                    // This is visited if we don't need to
+                                    // speculate it or we already have
+                                    // speculated it.
+                                    return !PotentialSpecSet.count(I) ||
+                                           SpecSet.count(I);
+                                  },
+                                  /*Visit*/
+                                  [&](Instruction *I) {
+                                    // All operands scheduled, schedule this
+                                    // node.
+                                    SpecSet.insert(I);
+                                    SpecList.push_back(I);
+                                  });
+
+  int NumSpecInsts = SpecList.size() * SpecPreds.size();
+  int NumRedundantInsts = NumSpecInsts - SpecList.size();
+  DEBUG(dbgs() << "  Inserting " << NumSpecInsts << " speculated instructions, "
+               << NumRedundantInsts << " redundancies\n");
+  NumSpeculatedInstructions += NumSpecInsts;
+  NumNewRedundantInstructions += NumRedundantInsts;
+
+  // Each predecessor is numbered by its index in `SpecPreds`, so for each
+  // instruction we speculate, the speculated instruction is stored in that
+  // index of the vector associated with the original instruction. We also
+  // store the incoming values for each predecessor from any PHIs used.
+  SmallDenseMap<Instruction *, SmallVector<Value *, 2>, 16> SpeculatedValueMap;
+
+  // Inject the synthetic mappings to rewrite PHIs to the appropriate incoming
+  // value. This handles both the PHIs we are speculating around and any other
+  // PHIs that happen to be used.
+  for (auto *OrigI : SpecList)
+    for (auto *OpV : OrigI->operand_values()) {
+      auto *OpPN = dyn_cast<PHINode>(OpV);
+      if (!OpPN || OpPN->getParent() != ParentBB)
+        continue;
+
+      auto InsertResult = SpeculatedValueMap.insert({OpPN, {}});
+      if (!InsertResult.second)
+        continue;
+
+      auto &SpeculatedVals = InsertResult.first->second;
+
+      // Populating our structure for mapping is particularly annoying because
+      // finding an incoming value for a particular predecessor block in a PHI
+      // node is a linear time operation! To avoid quadratic behavior, we build
+      // a map for this PHI node's incoming values and then translate it into
+      // the more compact representation used below.
+      SmallDenseMap<BasicBlock *, Value *, 16> IncomingValueMap;
+      for (int i : llvm::seq<int>(0, OpPN->getNumIncomingValues()))
+        IncomingValueMap[OpPN->getIncomingBlock(i)] = OpPN->getIncomingValue(i);
+
+      for (auto *PredBB : SpecPreds)
+        SpeculatedVals.push_back(IncomingValueMap.find(PredBB)->second);
+    }
+
+  // Speculate into each predecessor.
+  for (int PredIdx : llvm::seq<int>(0, SpecPreds.size())) {
+    auto *PredBB = SpecPreds[PredIdx];
+    assert(PredBB->getSingleSuccessor() == ParentBB &&
+           "We need a non-critical predecessor to speculate into.");
+
+    for (auto *OrigI : SpecList) {
+      auto *NewI = OrigI->clone();
+      NewI->setName(Twine(OrigI->getName()) + "." + Twine(PredIdx));
+      NewI->insertBefore(PredBB->getTerminator());
+
+      // Rewrite all the operands to the previously speculated instructions.
+      // Because we're walking in-order, the defs must precede the uses and we
+      // should already have these mappings.
+      for (Use &U : NewI->operands()) {
+        auto *OpI = dyn_cast<Instruction>(U.get());
+        if (!OpI)
+          continue;
+        auto MapIt = SpeculatedValueMap.find(OpI);
+        if (MapIt == SpeculatedValueMap.end())
+          continue;
+        const auto &SpeculatedVals = MapIt->second;
+        assert(SpeculatedVals[PredIdx] &&
+               "Must have a speculated value for this predecessor!");
+        assert(SpeculatedVals[PredIdx]->getType() == OpI->getType() &&
+               "Speculated value has the wrong type!");
+
+        // Rewrite the use to this predecessor's speculated instruction.
+        U.set(SpeculatedVals[PredIdx]);
+      }
+
+      // Commute instructions which now have a constant in the LHS but not the
+      // RHS.
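+      // For example (illustrative): if the PHI fed operand 0 of a cloned
+      // `add i32 %p, %y` and this predecessor's incoming value is 42, the
+      // clone now reads `add i32 42, %y`; swapping restores the canonical
+      // `add i32 %y, 42`.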
+      if (NewI->isBinaryOp() && NewI->isCommutative() &&
+          isa<Constant>(NewI->getOperand(0)) &&
+          !isa<Constant>(NewI->getOperand(1)))
+        NewI->getOperandUse(0).swap(NewI->getOperandUse(1));
+
+      SpeculatedValueMap[OrigI].push_back(NewI);
+      assert(SpeculatedValueMap[OrigI][PredIdx] == NewI &&
+             "Mismatched speculated instruction index!");
+    }
+  }
+
+  // Walk the speculated instruction list and if they have uses, insert a PHI
+  // for them from the speculated versions, and replace the uses with the PHI.
+  // Then erase the instructions as they have been fully speculated. The walk
+  // needs to be in reverse so that we don't think there are users when we'll
+  // actually eventually remove them later.
+  IRBuilder<> IRB(SpecPNs[0]);
+  for (auto *OrigI : llvm::reverse(SpecList)) {
+    // Check if we need a PHI for any remaining users and if so, insert it.
+    if (!OrigI->use_empty()) {
+      auto *SpecIPN = IRB.CreatePHI(OrigI->getType(), SpecPreds.size(),
+                                    Twine(OrigI->getName()) + ".phi");
+      // Add the incoming values we speculated.
+      auto &SpeculatedVals = SpeculatedValueMap.find(OrigI)->second;
+      for (int PredIdx : llvm::seq<int>(0, SpecPreds.size()))
+        SpecIPN->addIncoming(SpeculatedVals[PredIdx], SpecPreds[PredIdx]);
+
+      // And replace the uses with the PHI node.
+      OrigI->replaceAllUsesWith(SpecIPN);
+    }
+
+    // It is important to immediately erase this so that it stops using other
+    // instructions. This avoids inserting needless PHIs of them.
+    OrigI->eraseFromParent();
+  }
+
+  // All of the uses of the speculated phi nodes should be removed at this
+  // point, so erase them.
+  for (auto *SpecPN : SpecPNs) {
+    assert(SpecPN->use_empty() && "All users should have been speculated!");
+    SpecPN->eraseFromParent();
+  }
+}
+
+/// Try to speculate around a series of PHIs from a single basic block.
+///
+/// This routine checks whether any of these PHIs are profitable to speculate
+/// users around. If safe and profitable, it does the speculation. It returns
+/// true when at least some speculation occurs.
+static bool tryToSpeculatePHIs(SmallVectorImpl<PHINode *> &PNs,
+                               DominatorTree &DT, TargetTransformInfo &TTI) {
+  DEBUG(dbgs() << "Evaluating phi nodes for speculation:\n");
+
+  // Savings in cost from speculating around a PHI node.
+  SmallDenseMap<PHINode *, int, 16> CostSavingsMap;
+
+  // Remember the set of instructions that are candidates for speculation so
+  // that we can quickly walk things within that space. This prunes out
+  // instructions already available along edges, etc.
+  SmallPtrSet<Instruction *, 16> PotentialSpecSet;
+
+  // Remember the set of instructions that are (transitively) unsafe to
+  // speculate into the incoming edges of this basic block. This avoids
+  // recomputing them for each PHI node we check. This set is specific to this
+  // block though as things are pruned out of it based on what is available
+  // along incoming edges.
+  SmallPtrSet<Instruction *, 16> UnsafeSet;
+
+  // For each PHI node in this block, check whether there are immediate folding
+  // opportunities from speculation, and whether that speculation will be
+  // valid. This determines the set of safe PHIs to speculate.
+  PNs.erase(llvm::remove_if(PNs,
+                            [&](PHINode *PN) {
+                              return !isSafeAndProfitableToSpeculateAroundPHI(
+                                  *PN, CostSavingsMap, PotentialSpecSet,
+                                  UnsafeSet, DT, TTI);
+                            }),
+            PNs.end());
+  // If no PHIs were profitable, skip.
+  if (PNs.empty()) {
+    DEBUG(dbgs() << "  No safe and profitable PHIs found!\n");
+    return false;
+  }
+
+  // We need to know how much speculation will cost which is determined by how
+  // many incoming edges will need a copy of each speculated instruction.
+  SmallSetVector<BasicBlock *, 16> PredSet;
+  for (auto *PredBB : PNs[0]->blocks()) {
+    if (!PredSet.insert(PredBB))
+      continue;
+
+    // We cannot speculate when a predecessor is an indirect branch.
+    // FIXME: We also can't reliably create a non-critical edge block for
+    // speculation if the predecessor is an invoke. This doesn't seem
+    // fundamental and we should probably be splitting critical edges
+    // differently.
+    if (isa<IndirectBrInst>(PredBB->getTerminator()) ||
+        isa<InvokeInst>(PredBB->getTerminator())) {
+      DEBUG(dbgs() << "  Invalid: predecessor terminator: " << PredBB->getName()
+                   << "\n");
+      return false;
+    }
+  }
+  if (PredSet.size() < 2) {
+    DEBUG(dbgs() << "  Unimportant: phi with only one predecessor\n");
+    return false;
+  }
+
+  SmallVector<PHINode *, 16> SpecPNs = findProfitablePHIs(
+      PNs, CostSavingsMap, PotentialSpecSet, PredSet.size(), DT, TTI);
+  if (SpecPNs.empty())
+    // Nothing to do.
+    return false;
+
+  speculatePHIs(SpecPNs, PotentialSpecSet, PredSet, DT);
+  return true;
+}
+
+PreservedAnalyses SpeculateAroundPHIsPass::run(Function &F,
+                                               FunctionAnalysisManager &AM) {
+  auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
+  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
+
+  bool Changed = false;
+  for (auto *BB : ReversePostOrderTraversal<Function *>(&F)) {
+    SmallVector<PHINode *, 16> PNs;
+    auto BBI = BB->begin();
+    while (auto *PN = dyn_cast<PHINode>(&*BBI)) {
+      PNs.push_back(PN);
+      ++BBI;
+    }
+
+    if (PNs.empty())
+      continue;
+
+    Changed |= tryToSpeculatePHIs(PNs, DT, TTI);
+  }
+
+  if (!Changed)
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses PA;
+  return PA;
+}
diff --git a/contrib/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp b/contrib/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp new file mode 100644 index 000000000000..a7c308b59877 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/SpeculativeExecution.cpp @@ -0,0 +1,319 @@
+//===- SpeculativeExecution.cpp ---------------------------------*- C++ -*-===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass hoists instructions to enable speculative execution on
+// targets where branches are expensive. This is aimed at GPUs. It
+// currently works on simple if-then and if-then-else
+// patterns.
+//
+// Removing branches is not the only motivation for this
+// pass. E.g. consider this code and assume that there is no
+// addressing mode for multiplying by sizeof(*a):
+//
+//   if (b > 0)
+//     c = a[i + 1]
+//   if (d > 0)
+//     e = a[i + 2]
+//
+// turns into
+//
+//   p = &a[i + 1];
+//   if (b > 0)
+//     c = *p;
+//   q = &a[i + 2];
+//   if (d > 0)
+//     e = *q;
+//
+// which could later be optimized to
+//
+//   r = &a[i];
+//   if (b > 0)
+//     c = r[1];
+//   if (d > 0)
+//     e = r[2];
+//
+// Later passes sink back much of the speculated code that did not enable
+// further optimization.
+//
+// This pass is more aggressive than the function SpeculativelyExecuteBB in
+// SimplifyCFG. SimplifyCFG will not speculate if no selects are introduced and
+// it will speculate at most one instruction. It also will not speculate if
+// there is a value defined in the if-block that is only used in the then-block.
+// These restrictions make sense since the speculation in SimplifyCFG seems
+// aimed at introducing cheap selects, while this pass is intended to do more
+// aggressive speculation while counting on later passes to either capitalize on
+// that or clean it up.
+//
+// If the pass was created by calling
+// createSpeculativeExecutionIfHasBranchDivergencePass or the
+// -spec-exec-only-if-divergent-target option is present, this pass only has an
+// effect on targets where TargetTransformInfo::hasBranchDivergence() is true;
+// on other targets, it is a nop.
+//
+// This lets you include this pass unconditionally in the IR pass pipeline, but
+// only enable it for relevant targets.
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Transforms/Scalar/SpeculativeExecution.h"
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/Analysis/GlobalsModRef.h"
+#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Operator.h"
+#include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "speculative-execution"
+
+// The risk that speculation will not pay off increases with the
+// number of instructions speculated, so we put a limit on that.
+static cl::opt<unsigned> SpecExecMaxSpeculationCost(
+    "spec-exec-max-speculation-cost", cl::init(7), cl::Hidden,
+    cl::desc("Speculative execution is not applied to basic blocks where "
+             "the cost of the instructions to speculatively execute "
+             "exceeds this limit."));
+
+// Speculating just a few instructions from a larger block tends not
+// to be profitable and this limit prevents that. A reason for that is
+// that small basic blocks are more likely to be candidates for
+// further optimization.
+static cl::opt<unsigned> SpecExecMaxNotHoisted(
+    "spec-exec-max-not-hoisted", cl::init(5), cl::Hidden,
+    cl::desc("Speculative execution is not applied to basic blocks where the "
+             "number of instructions that would not be speculatively executed "
+             "exceeds this limit."));
+
+static cl::opt<bool> SpecExecOnlyIfDivergentTarget(
+    "spec-exec-only-if-divergent-target", cl::init(false), cl::Hidden,
+    cl::desc("Speculative execution is applied only to targets with divergent "
+             "branches, even if the pass was configured to apply to all "
+             "targets."));
+
+namespace {
+
+class SpeculativeExecutionLegacyPass : public FunctionPass {
+public:
+  static char ID;
+  explicit SpeculativeExecutionLegacyPass(bool OnlyIfDivergentTarget = false)
+      : FunctionPass(ID), OnlyIfDivergentTarget(OnlyIfDivergentTarget ||
+                                                SpecExecOnlyIfDivergentTarget),
+        Impl(OnlyIfDivergentTarget) {}
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+  bool runOnFunction(Function &F) override;
+
+  StringRef getPassName() const override {
+    if (OnlyIfDivergentTarget)
+      return "Speculatively execute instructions if target has divergent "
+             "branches";
+    return "Speculatively execute instructions";
+  }
+
+private:
+  // Variable preserved purely for correct name printing.
+  const bool OnlyIfDivergentTarget;
+
+  SpeculativeExecutionPass Impl;
+};
+} // namespace
+
+char SpeculativeExecutionLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(SpeculativeExecutionLegacyPass, "speculative-execution",
+                      "Speculatively execute instructions", false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_END(SpeculativeExecutionLegacyPass, "speculative-execution",
+                    "Speculatively execute instructions", false, false)
+
+void SpeculativeExecutionLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addRequired<TargetTransformInfoWrapperPass>();
+  AU.addPreserved<GlobalsAAWrapperPass>();
+}
+
+bool SpeculativeExecutionLegacyPass::runOnFunction(Function &F) {
+  if (skipFunction(F))
+    return false;
+
+  auto *TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
+  return Impl.runImpl(F, TTI);
+}
+
+namespace llvm {
+
+bool SpeculativeExecutionPass::runImpl(Function &F, TargetTransformInfo *TTI) {
+  if (OnlyIfDivergentTarget && !TTI->hasBranchDivergence()) {
+    DEBUG(dbgs() << "Not running SpeculativeExecution because "
+                    "TTI->hasBranchDivergence() is false.\n");
+    return false;
+  }
+
+  this->TTI = TTI;
+  bool Changed = false;
+  for (auto& B : F) {
+    Changed |= runOnBasicBlock(B);
+  }
+  return Changed;
+}
+
+bool SpeculativeExecutionPass::runOnBasicBlock(BasicBlock &B) {
+  BranchInst *BI = dyn_cast<BranchInst>(B.getTerminator());
+  if (BI == nullptr)
+    return false;
+
+  if (BI->getNumSuccessors() != 2)
+    return false;
+  BasicBlock &Succ0 = *BI->getSuccessor(0);
+  BasicBlock &Succ1 = *BI->getSuccessor(1);
+
+  if (&B == &Succ0 || &B == &Succ1 || &Succ0 == &Succ1) {
+    return false;
+  }
+
+  // Hoist from if-then (triangle).
+  if (Succ0.getSinglePredecessor() != nullptr &&
+      Succ0.getSingleSuccessor() == &Succ1) {
+    return considerHoistingFromTo(Succ0, B);
+  }
+
+  // Hoist from if-else (triangle).
+  if (Succ1.getSinglePredecessor() != nullptr &&
+      Succ1.getSingleSuccessor() == &Succ0) {
+    return considerHoistingFromTo(Succ1, B);
+  }
+
+  // Hoist from if-then-else (diamond), but only if it is equivalent to
+  // an if-else or if-then due to one of the branches doing nothing.
+  if (Succ0.getSinglePredecessor() != nullptr &&
+      Succ1.getSinglePredecessor() != nullptr &&
+      Succ1.getSingleSuccessor() != nullptr &&
+      Succ1.getSingleSuccessor() != &B &&
+      Succ1.getSingleSuccessor() == Succ0.getSingleSuccessor()) {
+    // If a block has only one instruction, that instruction is its
+    // terminator, so the block does nothing. This does happen in practice.
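+    // Illustrative shape: B branches to Succ0 and Succ1, Succ1 contains
+    // nothing but `br label %tail`, and both successors flow into the same
+    // %tail block, so the diamond is effectively an if-then over Succ0.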
+ if (Succ1.size() == 1) // equivalent to if-then + return considerHoistingFromTo(Succ0, B); + if (Succ0.size() == 1) // equivalent to if-else + return considerHoistingFromTo(Succ1, B); + } + + return false; +} + +static unsigned ComputeSpeculationCost(const Instruction *I, + const TargetTransformInfo &TTI) { + switch (Operator::getOpcode(I)) { + case Instruction::GetElementPtr: + case Instruction::Add: + case Instruction::Mul: + case Instruction::And: + case Instruction::Or: + case Instruction::Select: + case Instruction::Shl: + case Instruction::Sub: + case Instruction::LShr: + case Instruction::AShr: + case Instruction::Xor: + case Instruction::ZExt: + case Instruction::SExt: + case Instruction::Call: + case Instruction::BitCast: + case Instruction::PtrToInt: + case Instruction::IntToPtr: + case Instruction::AddrSpaceCast: + case Instruction::FPToUI: + case Instruction::FPToSI: + case Instruction::UIToFP: + case Instruction::SIToFP: + case Instruction::FPExt: + case Instruction::FPTrunc: + case Instruction::FAdd: + case Instruction::FSub: + case Instruction::FMul: + case Instruction::FDiv: + case Instruction::FRem: + case Instruction::ICmp: + case Instruction::FCmp: + return TTI.getUserCost(I); + + default: + return UINT_MAX; // Disallow anything not whitelisted. + } +} + +bool SpeculativeExecutionPass::considerHoistingFromTo( + BasicBlock &FromBlock, BasicBlock &ToBlock) { + SmallSet<const Instruction *, 8> NotHoisted; + const auto AllPrecedingUsesFromBlockHoisted = [&NotHoisted](User *U) { + for (Value* V : U->operand_values()) { + if (Instruction *I = dyn_cast<Instruction>(V)) { + if (NotHoisted.count(I) > 0) + return false; + } + } + return true; + }; + + unsigned TotalSpeculationCost = 0; + for (auto& I : FromBlock) { + const unsigned Cost = ComputeSpeculationCost(&I, *TTI); + if (Cost != UINT_MAX && isSafeToSpeculativelyExecute(&I) && + AllPrecedingUsesFromBlockHoisted(&I)) { + TotalSpeculationCost += Cost; + if (TotalSpeculationCost > SpecExecMaxSpeculationCost) + return false; // too much to hoist + } else { + NotHoisted.insert(&I); + if (NotHoisted.size() > SpecExecMaxNotHoisted) + return false; // too much left behind + } + } + + if (TotalSpeculationCost == 0) + return false; // nothing to hoist + + for (auto I = FromBlock.begin(); I != FromBlock.end();) { + // We have to increment I before moving Current as moving Current + // changes the list that I is iterating through. 
+    auto Current = I;
+    ++I;
+    if (!NotHoisted.count(&*Current)) {
+      Current->moveBefore(ToBlock.getTerminator());
+    }
+  }
+  return true;
+}
+
+FunctionPass *createSpeculativeExecutionPass() {
+  return new SpeculativeExecutionLegacyPass();
+}
+
+FunctionPass *createSpeculativeExecutionIfHasBranchDivergencePass() {
+  return new SpeculativeExecutionLegacyPass(/* OnlyIfDivergentTarget = */ true);
+}
+
+SpeculativeExecutionPass::SpeculativeExecutionPass(bool OnlyIfDivergentTarget)
+    : OnlyIfDivergentTarget(OnlyIfDivergentTarget ||
+                            SpecExecOnlyIfDivergentTarget) {}
+
+PreservedAnalyses SpeculativeExecutionPass::run(Function &F,
+                                                FunctionAnalysisManager &AM) {
+  auto *TTI = &AM.getResult<TargetIRAnalysis>(F);
+
+  bool Changed = runImpl(F, TTI);
+
+  if (!Changed)
+    return PreservedAnalyses::all();
+  PreservedAnalyses PA;
+  PA.preserve<GlobalsAA>();
+  return PA;
+}
+} // namespace llvm
diff --git a/contrib/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp b/contrib/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp new file mode 100644 index 000000000000..ce40af1223f6 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/StraightLineStrengthReduce.cpp @@ -0,0 +1,737 @@
+//===- StraightLineStrengthReduce.cpp - -----------------------------------===//
+//
+// The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements straight-line strength reduction (SLSR). Unlike loop
+// strength reduction, this algorithm is designed to reduce arithmetic
+// redundancy in straight-line code instead of loops. It has proven to be
+// effective in simplifying arithmetic statements derived from an unrolled loop.
+// It can also simplify the logic of SeparateConstOffsetFromGEP.
+//
+// There are many optimizations we can perform in the domain of SLSR. This file
+// for now contains only an initial step. Specifically, we look for strength
+// reduction candidates in the following forms:
+//
+// Form 1: B + i * S
+// Form 2: (B + i) * S
+// Form 3: &B[i * S]
+//
+// where S is an integer variable, and i is a constant integer. If we find two
+// candidates S1 and S2 in the same form and S1 dominates S2, we may rewrite S2
+// in a simpler way with respect to S1. For example,
+//
+// S1: X = B + i * S
+// S2: Y = B + i' * S   => X + (i' - i) * S
+//
+// S1: X = (B + i) * S
+// S2: Y = (B + i') * S => X + (i' - i) * S
+//
+// S1: X = &B[i * S]
+// S2: Y = &B[i' * S]   => &X[(i' - i) * S]
+//
+// Note: (i' - i) * S is folded to the extent possible.
+//
+// This rewriting is in general a good idea. The code patterns we focus on
+// usually come from loop unrolling, so (i' - i) * S is likely the same
+// across iterations and can be reused. When that happens, the optimized form
+// takes only one add starting from the second iteration.
+//
+// When such rewriting is possible, we call S1 a "basis" of S2. When S2 has
+// multiple bases, we choose to rewrite S2 with respect to its "immediate"
+// basis, the basis that is the closest ancestor in the dominator tree.
+//
+// TODO:
+//
+// - Floating-point arithmetic when fast math is enabled.
+//
+// - SLSR may decrease ILP at the architecture level. Targets that are very
+//   sensitive to ILP may want to disable it. Having SLSR consider ILP is
+//   left as future work.
+//
+// - When (i' - i) is constant but i and i' are not, we could still perform
+//   SLSR.
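+//
+// As an illustrative IR sketch of Form 1 (the value names are hypothetical),
+// before SLSR:
+//
+//   %t1 = mul i64 %s, 4
+//   %x  = add i64 %b, %t1      ; S1: B + 4 * S
+//   %t2 = mul i64 %s, 6
+//   %y  = add i64 %b, %t2      ; S2: B + 6 * S
+//
+// and after rewriting S2 with respect to its basis S1:
+//
+//   %bump = mul i64 %s, 2      ; (i' - i) * S = (6 - 4) * S
+//   %y    = add i64 %x, %bump  ; X + (i' - i) * S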
+ +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/DepthFirstIterator.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/ScalarEvolution.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/Analysis/ValueTracking.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/GetElementPtrTypeIterator.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/Operator.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/Local.h" +#include <cassert> +#include <cstdint> +#include <limits> +#include <list> +#include <vector> + +using namespace llvm; +using namespace PatternMatch; + +static const unsigned UnknownAddressSpace = + std::numeric_limits<unsigned>::max(); + +namespace { + +class StraightLineStrengthReduce : public FunctionPass { +public: + // SLSR candidate. Such a candidate must be in one of the forms described in + // the header comments. + struct Candidate { + enum Kind { + Invalid, // reserved for the default constructor + Add, // B + i * S + Mul, // (B + i) * S + GEP, // &B[..][i * S][..] + }; + + Candidate() = default; + Candidate(Kind CT, const SCEV *B, ConstantInt *Idx, Value *S, + Instruction *I) + : CandidateKind(CT), Base(B), Index(Idx), Stride(S), Ins(I) {} + + Kind CandidateKind = Invalid; + + const SCEV *Base = nullptr; + + // Note that Index and Stride of a GEP candidate do not necessarily have the + // same integer type. In that case, during rewriting, Stride will be + // sign-extended or truncated to Index's type. + ConstantInt *Index = nullptr; + + Value *Stride = nullptr; + + // The instruction this candidate corresponds to. It helps us to rewrite a + // candidate with respect to its immediate basis. Note that one instruction + // can correspond to multiple candidates depending on how you associate the + // expression. For instance, + // + // (a + 1) * (b + 2) + // + // can be treated as + // + // <Base: a, Index: 1, Stride: b + 2> + // + // or + // + // <Base: b, Index: 2, Stride: a + 1> + Instruction *Ins = nullptr; + + // Points to the immediate basis of this candidate, or nullptr if we cannot + // find any basis for this candidate. + Candidate *Basis = nullptr; + }; + + static char ID; + + StraightLineStrengthReduce() : FunctionPass(ID) { + initializeStraightLineStrengthReducePass(*PassRegistry::getPassRegistry()); + } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<ScalarEvolutionWrapperPass>(); + AU.addRequired<TargetTransformInfoWrapperPass>(); + // We do not modify the shape of the CFG. + AU.setPreservesCFG(); + } + + bool doInitialization(Module &M) override { + DL = &M.getDataLayout(); + return false; + } + + bool runOnFunction(Function &F) override; + +private: + // Returns true if Basis is a basis for C, i.e., Basis dominates C and they + // share the same base and stride. + bool isBasisFor(const Candidate &Basis, const Candidate &C); + + // Returns whether the candidate can be folded into an addressing mode. 
+  bool isFoldable(const Candidate &C, TargetTransformInfo *TTI,
+                  const DataLayout *DL);
+
+  // Returns true if C is already in a simplest form and not worth being
+  // rewritten.
+  bool isSimplestForm(const Candidate &C);
+
+  // Checks whether I is in a candidate form. If so, adds all the matching forms
+  // to Candidates, and tries to find the immediate basis for each of them.
+  void allocateCandidatesAndFindBasis(Instruction *I);
+
+  // Allocate candidates and find bases for Add instructions.
+  void allocateCandidatesAndFindBasisForAdd(Instruction *I);
+
+  // Given I = LHS + RHS, factors RHS into i * S and makes (LHS + i * S) a
+  // candidate.
+  void allocateCandidatesAndFindBasisForAdd(Value *LHS, Value *RHS,
+                                            Instruction *I);
+  // Allocate candidates and find bases for Mul instructions.
+  void allocateCandidatesAndFindBasisForMul(Instruction *I);
+
+  // Splits LHS into Base + Index and, if that succeeds, calls
+  // allocateCandidatesAndFindBasis.
+  void allocateCandidatesAndFindBasisForMul(Value *LHS, Value *RHS,
+                                            Instruction *I);
+
+  // Allocate candidates and find bases for GetElementPtr instructions.
+  void allocateCandidatesAndFindBasisForGEP(GetElementPtrInst *GEP);
+
+  // A helper function that scales Idx with ElementSize before invoking
+  // allocateCandidatesAndFindBasis.
+  void allocateCandidatesAndFindBasisForGEP(const SCEV *B, ConstantInt *Idx,
+                                            Value *S, uint64_t ElementSize,
+                                            Instruction *I);
+
+  // Adds the given form <CT, B, Idx, S> to Candidates, and finds its immediate
+  // basis.
+  void allocateCandidatesAndFindBasis(Candidate::Kind CT, const SCEV *B,
+                                      ConstantInt *Idx, Value *S,
+                                      Instruction *I);
+
+  // Rewrites candidate C with respect to Basis.
+  void rewriteCandidateWithBasis(const Candidate &C, const Candidate &Basis);
+
+  // A helper function that factors ArrayIdx to a product of a stride and a
+  // constant index, and invokes allocateCandidatesAndFindBasis with the
+  // factorings.
+  void factorArrayIndex(Value *ArrayIdx, const SCEV *Base, uint64_t ElementSize,
+                        GetElementPtrInst *GEP);
+
+  // Emit code that computes the "bump" from Basis to C. If the candidate is a
+  // GEP and the bump is not divisible by the element size of the GEP, this
+  // function sets the BumpWithUglyGEP flag to notify its caller to bump the
+  // basis using an ugly GEP.
+  static Value *emitBump(const Candidate &Basis, const Candidate &C,
+                         IRBuilder<> &Builder, const DataLayout *DL,
+                         bool &BumpWithUglyGEP);
+
+  const DataLayout *DL = nullptr;
+  DominatorTree *DT = nullptr;
+  ScalarEvolution *SE = nullptr;
+  TargetTransformInfo *TTI = nullptr;
+  std::list<Candidate> Candidates;
+
+  // Temporarily holds all instructions that are unlinked (but not deleted) by
+  // rewriteCandidateWithBasis. These instructions will actually be removed
+  // after all rewriting finishes.
+ std::vector<Instruction *> UnlinkedInstructions; +}; + +} // end anonymous namespace + +char StraightLineStrengthReduce::ID = 0; + +INITIALIZE_PASS_BEGIN(StraightLineStrengthReduce, "slsr", + "Straight line strength reduction", false, false) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass) +INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) +INITIALIZE_PASS_END(StraightLineStrengthReduce, "slsr", + "Straight line strength reduction", false, false) + +FunctionPass *llvm::createStraightLineStrengthReducePass() { + return new StraightLineStrengthReduce(); +} + +bool StraightLineStrengthReduce::isBasisFor(const Candidate &Basis, + const Candidate &C) { + return (Basis.Ins != C.Ins && // skip the same instruction + // They must have the same type too. Basis.Base == C.Base doesn't + // guarantee their types are the same (PR23975). + Basis.Ins->getType() == C.Ins->getType() && + // Basis must dominate C in order to rewrite C with respect to Basis. + DT->dominates(Basis.Ins->getParent(), C.Ins->getParent()) && + // They share the same base, stride, and candidate kind. + Basis.Base == C.Base && Basis.Stride == C.Stride && + Basis.CandidateKind == C.CandidateKind); +} + +static bool isGEPFoldable(GetElementPtrInst *GEP, + const TargetTransformInfo *TTI) { + SmallVector<const Value*, 4> Indices; + for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I) + Indices.push_back(*I); + return TTI->getGEPCost(GEP->getSourceElementType(), GEP->getPointerOperand(), + Indices) == TargetTransformInfo::TCC_Free; +} + +// Returns whether (Base + Index * Stride) can be folded to an addressing mode. +static bool isAddFoldable(const SCEV *Base, ConstantInt *Index, Value *Stride, + TargetTransformInfo *TTI) { + // Index->getSExtValue() may crash if Index is wider than 64-bit. + return Index->getBitWidth() <= 64 && + TTI->isLegalAddressingMode(Base->getType(), nullptr, 0, true, + Index->getSExtValue(), UnknownAddressSpace); +} + +bool StraightLineStrengthReduce::isFoldable(const Candidate &C, + TargetTransformInfo *TTI, + const DataLayout *DL) { + if (C.CandidateKind == Candidate::Add) + return isAddFoldable(C.Base, C.Index, C.Stride, TTI); + if (C.CandidateKind == Candidate::GEP) + return isGEPFoldable(cast<GetElementPtrInst>(C.Ins), TTI); + return false; +} + +// Returns true if GEP has zero or one non-zero index. +static bool hasOnlyOneNonZeroIndex(GetElementPtrInst *GEP) { + unsigned NumNonZeroIndices = 0; + for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I) { + ConstantInt *ConstIdx = dyn_cast<ConstantInt>(*I); + if (ConstIdx == nullptr || !ConstIdx->isZero()) + ++NumNonZeroIndices; + } + return NumNonZeroIndices <= 1; +} + +bool StraightLineStrengthReduce::isSimplestForm(const Candidate &C) { + if (C.CandidateKind == Candidate::Add) { + // B + 1 * S or B + (-1) * S + return C.Index->isOne() || C.Index->isMinusOne(); + } + if (C.CandidateKind == Candidate::Mul) { + // (B + 0) * S + return C.Index->isZero(); + } + if (C.CandidateKind == Candidate::GEP) { + // (char*)B + S or (char*)B - S + return ((C.Index->isOne() || C.Index->isMinusOne()) && + hasOnlyOneNonZeroIndex(cast<GetElementPtrInst>(C.Ins))); + } + return false; +} + +// TODO: We currently implement an algorithm whose time complexity is linear in +// the number of existing candidates. However, we could do better by using +// ScopedHashTable. 
+// Specifically, while traversing the dominator tree, we could
+// maintain all the candidates that dominate the basic block being traversed in
+// a ScopedHashTable. This hash table is indexed by the base and the stride of
+// a candidate. Therefore, finding the immediate basis of a candidate boils down
+// to one hash-table look up.
+void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
+    Candidate::Kind CT, const SCEV *B, ConstantInt *Idx, Value *S,
+    Instruction *I) {
+  Candidate C(CT, B, Idx, S, I);
+  // SLSR can complicate an instruction in two cases:
+  //
+  // 1. If we can fold I into an addressing mode, computing I is likely free or
+  //    takes only one instruction.
+  //
+  // 2. I is already in a simplest form. For example, when
+  //      X = B + 8 * S
+  //      Y = B + S,
+  //    rewriting Y to X - 7 * S is probably a bad idea.
+  //
+  // In the above cases, we still add I to the candidate list so that I can be
+  // the basis of other candidates, but we leave I's basis blank so that I
+  // won't be rewritten.
+  if (!isFoldable(C, TTI, DL) && !isSimplestForm(C)) {
+    // Try to compute the immediate basis of C.
+    unsigned NumIterations = 0;
+    // Limit the scan radius to avoid running in quadratic time.
+    static const unsigned MaxNumIterations = 50;
+    for (auto Basis = Candidates.rbegin();
+         Basis != Candidates.rend() && NumIterations < MaxNumIterations;
+         ++Basis, ++NumIterations) {
+      if (isBasisFor(*Basis, C)) {
+        C.Basis = &(*Basis);
+        break;
+      }
+    }
+  }
+  // Regardless of whether we find a basis for C, we need to push C to the
+  // candidate list so that it can be the basis of other candidates.
+  Candidates.push_back(C);
+}
+
+void StraightLineStrengthReduce::allocateCandidatesAndFindBasis(
+    Instruction *I) {
+  switch (I->getOpcode()) {
+  case Instruction::Add:
+    allocateCandidatesAndFindBasisForAdd(I);
+    break;
+  case Instruction::Mul:
+    allocateCandidatesAndFindBasisForMul(I);
+    break;
+  case Instruction::GetElementPtr:
+    allocateCandidatesAndFindBasisForGEP(cast<GetElementPtrInst>(I));
+    break;
+  }
+}
+
+void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
+    Instruction *I) {
+  // Try matching B + i * S.
+  if (!isa<IntegerType>(I->getType()))
+    return;
+
+  assert(I->getNumOperands() == 2 && "isn't I an add?");
+  Value *LHS = I->getOperand(0), *RHS = I->getOperand(1);
+  allocateCandidatesAndFindBasisForAdd(LHS, RHS, I);
+  if (LHS != RHS)
+    allocateCandidatesAndFindBasisForAdd(RHS, LHS, I);
+}
+
+void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForAdd(
+    Value *LHS, Value *RHS, Instruction *I) {
+  Value *S = nullptr;
+  ConstantInt *Idx = nullptr;
+  if (match(RHS, m_Mul(m_Value(S), m_ConstantInt(Idx)))) {
+    // I = LHS + RHS = LHS + Idx * S
+    allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I);
+  } else if (match(RHS, m_Shl(m_Value(S), m_ConstantInt(Idx)))) {
+    // I = LHS + RHS = LHS + (S << Idx) = LHS + S * (1 << Idx)
+    APInt One(Idx->getBitWidth(), 1);
+    Idx = ConstantInt::get(Idx->getContext(), One << Idx->getValue());
+    allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), Idx, S, I);
+  } else {
+    // At least, I = LHS + 1 * RHS
+    ConstantInt *One = ConstantInt::get(cast<IntegerType>(I->getType()), 1);
+    allocateCandidatesAndFindBasis(Candidate::Add, SE->getSCEV(LHS), One, RHS,
+                                   I);
+  }
+}
+
+// Returns true if A matches B + C where C is constant.
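+// For example (illustrative): A = `add i64 %b, 12` matches with B = %b and
+// C = i64 12; both operand orders are handled by the two patterns below.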
+static bool matchesAdd(Value *A, Value *&B, ConstantInt *&C) { + return (match(A, m_Add(m_Value(B), m_ConstantInt(C))) || + match(A, m_Add(m_ConstantInt(C), m_Value(B)))); +} + +// Returns true if A matches B | C where C is constant. +static bool matchesOr(Value *A, Value *&B, ConstantInt *&C) { + return (match(A, m_Or(m_Value(B), m_ConstantInt(C))) || + match(A, m_Or(m_ConstantInt(C), m_Value(B)))); +} + +void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul( + Value *LHS, Value *RHS, Instruction *I) { + Value *B = nullptr; + ConstantInt *Idx = nullptr; + if (matchesAdd(LHS, B, Idx)) { + // If LHS is in the form of "Base + Index", then I is in the form of + // "(Base + Index) * RHS". + allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS, I); + } else if (matchesOr(LHS, B, Idx) && haveNoCommonBitsSet(B, Idx, *DL)) { + // If LHS is in the form of "Base | Index" and Base and Index have no common + // bits set, then + // Base | Index = Base + Index + // and I is thus in the form of "(Base + Index) * RHS". + allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(B), Idx, RHS, I); + } else { + // Otherwise, at least try the form (LHS + 0) * RHS. + ConstantInt *Zero = ConstantInt::get(cast<IntegerType>(I->getType()), 0); + allocateCandidatesAndFindBasis(Candidate::Mul, SE->getSCEV(LHS), Zero, RHS, + I); + } +} + +void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForMul( + Instruction *I) { + // Try matching (B + i) * S. + // TODO: we could extend SLSR to float and vector types. + if (!isa<IntegerType>(I->getType())) + return; + + assert(I->getNumOperands() == 2 && "isn't I a mul?"); + Value *LHS = I->getOperand(0), *RHS = I->getOperand(1); + allocateCandidatesAndFindBasisForMul(LHS, RHS, I); + if (LHS != RHS) { + // Symmetrically, try to split RHS to Base + Index. + allocateCandidatesAndFindBasisForMul(RHS, LHS, I); + } +} + +void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP( + const SCEV *B, ConstantInt *Idx, Value *S, uint64_t ElementSize, + Instruction *I) { + // I = B + sext(Idx *nsw S) * ElementSize + // = B + (sext(Idx) * sext(S)) * ElementSize + // = B + (sext(Idx) * ElementSize) * sext(S) + // Casting to IntegerType is safe because we skipped vector GEPs. + IntegerType *IntPtrTy = cast<IntegerType>(DL->getIntPtrType(I->getType())); + ConstantInt *ScaledIdx = ConstantInt::get( + IntPtrTy, Idx->getSExtValue() * (int64_t)ElementSize, true); + allocateCandidatesAndFindBasis(Candidate::GEP, B, ScaledIdx, S, I); +} + +void StraightLineStrengthReduce::factorArrayIndex(Value *ArrayIdx, + const SCEV *Base, + uint64_t ElementSize, + GetElementPtrInst *GEP) { + // At least, ArrayIdx = ArrayIdx *nsw 1. + allocateCandidatesAndFindBasisForGEP( + Base, ConstantInt::get(cast<IntegerType>(ArrayIdx->getType()), 1), + ArrayIdx, ElementSize, GEP); + Value *LHS = nullptr; + ConstantInt *RHS = nullptr; + // One alternative is matching the SCEV of ArrayIdx instead of ArrayIdx + // itself. This would allow us to handle the shl case for free. However, + // matching SCEVs has two issues: + // + // 1. this would complicate rewriting because the rewriting procedure + // would have to translate SCEVs back to IR instructions. This translation + // is difficult when LHS is further evaluated to a composite SCEV. + // + // 2. ScalarEvolution is designed to be control-flow oblivious. It tends + // to strip nsw/nuw flags which are critical for SLSR to trace into + // sext'ed multiplication. 
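+ // + // Editor's sketch (illustrative, not from a test case): given IR such as + // %i = mul nsw i64 %s, 4 + // %p = getelementptr inbounds float, float* %b, i64 %i + // matching the instruction directly preserves the nsw flag, so the index + // factors as i = 4 * s and the match below allocates a GEP candidate with + // Stride = %s.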
+ if (match(ArrayIdx, m_NSWMul(m_Value(LHS), m_ConstantInt(RHS)))) { + // SLSR is currently unsafe if i * S may overflow. + // GEP = Base + sext(LHS *nsw RHS) * ElementSize + allocateCandidatesAndFindBasisForGEP(Base, RHS, LHS, ElementSize, GEP); + } else if (match(ArrayIdx, m_NSWShl(m_Value(LHS), m_ConstantInt(RHS)))) { + // GEP = Base + sext(LHS <<nsw RHS) * ElementSize + // = Base + sext(LHS *nsw (1 << RHS)) * ElementSize + APInt One(RHS->getBitWidth(), 1); + ConstantInt *PowerOf2 = + ConstantInt::get(RHS->getContext(), One << RHS->getValue()); + allocateCandidatesAndFindBasisForGEP(Base, PowerOf2, LHS, ElementSize, GEP); + } +} + +void StraightLineStrengthReduce::allocateCandidatesAndFindBasisForGEP( + GetElementPtrInst *GEP) { + // TODO: handle vector GEPs + if (GEP->getType()->isVectorTy()) + return; + + SmallVector<const SCEV *, 4> IndexExprs; + for (auto I = GEP->idx_begin(); I != GEP->idx_end(); ++I) + IndexExprs.push_back(SE->getSCEV(*I)); + + gep_type_iterator GTI = gep_type_begin(GEP); + for (unsigned I = 1, E = GEP->getNumOperands(); I != E; ++I, ++GTI) { + if (GTI.isStruct()) + continue; + + const SCEV *OrigIndexExpr = IndexExprs[I - 1]; + IndexExprs[I - 1] = SE->getZero(OrigIndexExpr->getType()); + + // The base of this candidate is GEP's base plus the offsets of all + // indices except this current one. + const SCEV *BaseExpr = SE->getGEPExpr(cast<GEPOperator>(GEP), IndexExprs); + Value *ArrayIdx = GEP->getOperand(I); + uint64_t ElementSize = DL->getTypeAllocSize(GTI.getIndexedType()); + if (ArrayIdx->getType()->getIntegerBitWidth() <= + DL->getPointerSizeInBits(GEP->getAddressSpace())) { + // Skip factoring if ArrayIdx is wider than the pointer size, because + // ArrayIdx is implicitly truncated to the pointer size. + factorArrayIndex(ArrayIdx, BaseExpr, ElementSize, GEP); + } + // When ArrayIdx is the sext of a value, we try to factor that value as + // well. Handling this case is important because array indices are + // typically sign-extended to the pointer size. + Value *TruncatedArrayIdx = nullptr; + if (match(ArrayIdx, m_SExt(m_Value(TruncatedArrayIdx))) && + TruncatedArrayIdx->getType()->getIntegerBitWidth() <= + DL->getPointerSizeInBits(GEP->getAddressSpace())) { + // Skip factoring if TruncatedArrayIdx is wider than the pointer size, + // because TruncatedArrayIdx is implicitly truncated to the pointer size. + factorArrayIndex(TruncatedArrayIdx, BaseExpr, ElementSize, GEP); + } + + IndexExprs[I - 1] = OrigIndexExpr; + } +} + +// A helper function that unifies the bitwidth of A and B. +static void unifyBitWidth(APInt &A, APInt &B) { + if (A.getBitWidth() < B.getBitWidth()) + A = A.sext(B.getBitWidth()); + else if (A.getBitWidth() > B.getBitWidth()) + B = B.sext(A.getBitWidth()); +} + +Value *StraightLineStrengthReduce::emitBump(const Candidate &Basis, + const Candidate &C, + IRBuilder<> &Builder, + const DataLayout *DL, + bool &BumpWithUglyGEP) { + APInt Idx = C.Index->getValue(), BasisIdx = Basis.Index->getValue(); + unifyBitWidth(Idx, BasisIdx); + APInt IndexOffset = Idx - BasisIdx; + + BumpWithUglyGEP = false; + if (Basis.CandidateKind == Candidate::GEP) { + APInt ElementSize( + IndexOffset.getBitWidth(), + DL->getTypeAllocSize( + cast<GetElementPtrInst>(Basis.Ins)->getResultElementType())); + APInt Q, R; + APInt::sdivrem(IndexOffset, ElementSize, Q, R); + if (R == 0) + IndexOffset = Q; + else + BumpWithUglyGEP = true; + } + + // Compute Bump = C - Basis = (i' - i) * S. + // Common case 1: if (i' - i) is 1, Bump = S. 
+ if (IndexOffset == 1) + return C.Stride; + // Common case 2: if (i' - i) is -1, Bump = -S. + if (IndexOffset.isAllOnesValue()) + return Builder.CreateNeg(C.Stride); + + // Otherwise, Bump = (i' - i) * sext/trunc(S). Note that (i' - i) and S may + // have different bit widths. + IntegerType *DeltaType = + IntegerType::get(Basis.Ins->getContext(), IndexOffset.getBitWidth()); + Value *ExtendedStride = Builder.CreateSExtOrTrunc(C.Stride, DeltaType); + if (IndexOffset.isPowerOf2()) { + // If (i' - i) is a power of 2, Bump = sext/trunc(S) << log(i' - i). + ConstantInt *Exponent = ConstantInt::get(DeltaType, IndexOffset.logBase2()); + return Builder.CreateShl(ExtendedStride, Exponent); + } + if ((-IndexOffset).isPowerOf2()) { + // If (i - i') is a power of 2, Bump = -sext/trunc(S) << log(i - i'). + ConstantInt *Exponent = + ConstantInt::get(DeltaType, (-IndexOffset).logBase2()); + return Builder.CreateNeg(Builder.CreateShl(ExtendedStride, Exponent)); + } + Constant *Delta = ConstantInt::get(DeltaType, IndexOffset); + return Builder.CreateMul(ExtendedStride, Delta); +} + +void StraightLineStrengthReduce::rewriteCandidateWithBasis( + const Candidate &C, const Candidate &Basis) { + assert(C.CandidateKind == Basis.CandidateKind && C.Base == Basis.Base && + C.Stride == Basis.Stride); + // We run rewriteCandidateWithBasis on all candidates in a post-order, so the + // basis of a candidate cannot be unlinked before the candidate. + assert(Basis.Ins->getParent() != nullptr && "the basis is unlinked"); + + // An instruction can correspond to multiple candidates. Therefore, instead of + // simply deleting an instruction when we rewrite it, we mark its parent as + // nullptr (i.e. unlink it) so that we can skip the candidates whose + // instruction is already rewritten. + if (!C.Ins->getParent()) + return; + + IRBuilder<> Builder(C.Ins); + bool BumpWithUglyGEP; + Value *Bump = emitBump(Basis, C, Builder, DL, BumpWithUglyGEP); + Value *Reduced = nullptr; // equivalent to but weaker than C.Ins + switch (C.CandidateKind) { + case Candidate::Add: + case Candidate::Mul: + // C = Basis + Bump + if (BinaryOperator::isNeg(Bump)) { + // If Bump is a neg instruction, emit C = Basis - (-Bump). + Reduced = + Builder.CreateSub(Basis.Ins, BinaryOperator::getNegArgument(Bump)); + // We only use the negative argument of Bump, and Bump itself may be + // trivially dead. + RecursivelyDeleteTriviallyDeadInstructions(Bump); + } else { + // It's tempting to preserve nsw on Bump and/or Reduced. However, it's + // usually unsound, e.g., + // + // X = (-2 +nsw 1) *nsw INT_MAX + // Y = (-2 +nsw 3) *nsw INT_MAX + // => + // Y = X + 2 * INT_MAX + // + // Neither the + nor the * in the resultant expression is nsw. + Reduced = Builder.CreateAdd(Basis.Ins, Bump); + } + break; + case Candidate::GEP: + { + Type *IntPtrTy = DL->getIntPtrType(C.Ins->getType()); + bool InBounds = cast<GetElementPtrInst>(C.Ins)->isInBounds(); + if (BumpWithUglyGEP) { + // C = (char *)Basis + Bump + unsigned AS = Basis.Ins->getType()->getPointerAddressSpace(); + Type *CharTy = Type::getInt8PtrTy(Basis.Ins->getContext(), AS); + Reduced = Builder.CreateBitCast(Basis.Ins, CharTy); + if (InBounds) + Reduced = + Builder.CreateInBoundsGEP(Builder.getInt8Ty(), Reduced, Bump); + else + Reduced = Builder.CreateGEP(Builder.getInt8Ty(), Reduced, Bump); + Reduced = Builder.CreateBitCast(Reduced, C.Ins->getType()); + } else { + // C = gep Basis, Bump + // Canonicalize bump to pointer size.
+ Bump = Builder.CreateSExtOrTrunc(Bump, IntPtrTy); + if (InBounds) + Reduced = Builder.CreateInBoundsGEP(nullptr, Basis.Ins, Bump); + else + Reduced = Builder.CreateGEP(nullptr, Basis.Ins, Bump); + } + break; + } + default: + llvm_unreachable("C.CandidateKind is invalid"); + }; + Reduced->takeName(C.Ins); + C.Ins->replaceAllUsesWith(Reduced); + // Unlink C.Ins so that we can skip other candidates also corresponding to + // C.Ins. The actual deletion is postponed to the end of runOnFunction. + C.Ins->removeFromParent(); + UnlinkedInstructions.push_back(C.Ins); +} + +bool StraightLineStrengthReduce::runOnFunction(Function &F) { + if (skipFunction(F)) + return false; + + TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); + DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); + // Traverse the dominator tree in depth-first order. This order makes sure + // all bases of a candidate are in Candidates when we process it. + for (const auto Node : depth_first(DT)) + for (auto &I : *(Node->getBlock())) + allocateCandidatesAndFindBasis(&I); + + // Rewrite candidates in reverse depth-first order. This order makes sure + // a candidate being rewritten is not a basis for any other candidate. + while (!Candidates.empty()) { + const Candidate &C = Candidates.back(); + if (C.Basis != nullptr) { + rewriteCandidateWithBasis(C, *C.Basis); + } + Candidates.pop_back(); + } + + // Delete all unlinked instructions. + for (auto *UnlinkedInst : UnlinkedInstructions) { + for (unsigned I = 0, E = UnlinkedInst->getNumOperands(); I != E; ++I) { + Value *Op = UnlinkedInst->getOperand(I); + UnlinkedInst->setOperand(I, nullptr); + RecursivelyDeleteTriviallyDeadInstructions(Op); + } + UnlinkedInst->deleteValue(); + } + bool Ret = !UnlinkedInstructions.empty(); + UnlinkedInstructions.clear(); + return Ret; +} diff --git a/contrib/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp b/contrib/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp new file mode 100644 index 000000000000..b8fb80b6cc26 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/StructurizeCFG.cpp @@ -0,0 +1,953 @@ +//===- StructurizeCFG.cpp -------------------------------------------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details.
+// +//===----------------------------------------------------------------------===// + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" +#include "llvm/ADT/PostOrderIterator.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Analysis/DivergenceAnalysis.h" +#include "llvm/Analysis/LoopInfo.h" +#include "llvm/Analysis/RegionInfo.h" +#include "llvm/Analysis/RegionIterator.h" +#include "llvm/Analysis/RegionPass.h" +#include "llvm/IR/Argument.h" +#include "llvm/IR/BasicBlock.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/Constant.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/Dominators.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstrTypes.h" +#include "llvm/IR/Instruction.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/Metadata.h" +#include "llvm/IR/PatternMatch.h" +#include "llvm/IR/Type.h" +#include "llvm/IR/Use.h" +#include "llvm/IR/User.h" +#include "llvm/IR/Value.h" +#include "llvm/Pass.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/SSAUpdater.h" +#include <algorithm> +#include <cassert> +#include <utility> + +using namespace llvm; +using namespace llvm::PatternMatch; + +#define DEBUG_TYPE "structurizecfg" + +// The name for newly created blocks. +static const char *const FlowBlockName = "Flow"; + +namespace { + +// Definition of the complex types used in this pass. + +using BBValuePair = std::pair<BasicBlock *, Value *>; + +using RNVector = SmallVector<RegionNode *, 8>; +using BBVector = SmallVector<BasicBlock *, 8>; +using BranchVector = SmallVector<BranchInst *, 8>; +using BBValueVector = SmallVector<BBValuePair, 2>; + +using BBSet = SmallPtrSet<BasicBlock *, 8>; + +using PhiMap = MapVector<PHINode *, BBValueVector>; +using BB2BBVecMap = MapVector<BasicBlock *, BBVector>; + +using BBPhiMap = DenseMap<BasicBlock *, PhiMap>; +using BBPredicates = DenseMap<BasicBlock *, Value *>; +using PredMap = DenseMap<BasicBlock *, BBPredicates>; +using BB2BBMap = DenseMap<BasicBlock *, BasicBlock *>; + +/// Finds the nearest common dominator of a set of BasicBlocks. +/// +/// For every BB you add to the set, you can specify whether we "remember" the +/// block. When you get the common dominator, you can also ask whether it's one +/// of the blocks we remembered. +class NearestCommonDominator { + DominatorTree *DT; + BasicBlock *Result = nullptr; + bool ResultIsRemembered = false; + + /// Add BB to the resulting dominator. + void addBlock(BasicBlock *BB, bool Remember) { + if (!Result) { + Result = BB; + ResultIsRemembered = Remember; + return; + } + + BasicBlock *NewResult = DT->findNearestCommonDominator(Result, BB); + if (NewResult != Result) + ResultIsRemembered = false; + if (NewResult == BB) + ResultIsRemembered |= Remember; + Result = NewResult; + } + +public: + explicit NearestCommonDominator(DominatorTree *DomTree) : DT(DomTree) {} + + void addBlock(BasicBlock *BB) { + addBlock(BB, /* Remember = */ false); + } + + void addAndRememberBlock(BasicBlock *BB) { + addBlock(BB, /* Remember = */ true); + } + + /// Get the nearest common dominator of all the BBs added via addBlock() and + /// addAndRememberBlock(). + BasicBlock *result() { return Result; } + + /// Is the BB returned by getResult() one of the blocks we added to the set + /// with addAndRememberBlock()? 
+ bool resultIsRememberedBlock() { return ResultIsRemembered; } +}; + +/// @brief Transforms the control flow graph on one single entry/exit region +/// at a time. +/// +/// After the transform all "If"/"Then"/"Else" style control flow looks like +/// this: +/// +/// \verbatim +/// 1 +/// || +/// | | +/// 2 | +/// | / +/// |/ +/// 3 +/// || Where: +/// | | 1 = "If" block, calculates the condition +/// 4 | 2 = "Then" subregion, runs if the condition is true +/// | / 3 = "Flow" blocks, newly inserted flow blocks, rejoin the flow +/// |/ 4 = "Else" optional subregion, runs if the condition is false +/// 5 5 = "End" block, also rejoins the control flow +/// \endverbatim +/// +/// Control flow is expressed as a branch where the true exit goes into the +/// "Then"/"Else" region, while the false exit skips the region. +/// The condition for the optional "Else" region is expressed as a PHI node. +/// The incoming values of the PHI node are true for the "If" edge and false +/// for the "Then" edge. +/// +/// In addition, even complicated loops look like this: +/// +/// \verbatim +/// 1 +/// || +/// | | +/// 2 ^ Where: +/// | / 1 = "Entry" block +/// |/ 2 = "Loop" optional subregion, with all exits at "Flow" block +/// 3 3 = "Flow" block, with back edge to entry block +/// | +/// \endverbatim +/// +/// The back edge of the "Flow" block is always on the false side of the branch +/// while the true side continues the general flow. So the loop condition +/// consists of a network of PHI nodes where the true incoming values express +/// breaks and the false values express continue states. +class StructurizeCFG : public RegionPass { + bool SkipUniformRegions; + + Type *Boolean; + ConstantInt *BoolTrue; + ConstantInt *BoolFalse; + UndefValue *BoolUndef; + + Function *Func; + Region *ParentRegion; + + DominatorTree *DT; + LoopInfo *LI; + + SmallVector<RegionNode *, 8> Order; + BBSet Visited; + + BBPhiMap DeletedPhis; + BB2BBVecMap AddedPhis; + + PredMap Predicates; + BranchVector Conditions; + + BB2BBMap Loops; + PredMap LoopPreds; + BranchVector LoopConds; + + RegionNode *PrevNode; + + void orderNodes(); + + void analyzeLoops(RegionNode *N); + + Value *invert(Value *Condition); + + Value *buildCondition(BranchInst *Term, unsigned Idx, bool Invert); + + void gatherPredicates(RegionNode *N); + + void collectInfos(); + + void insertConditions(bool Loops); + + void delPhiValues(BasicBlock *From, BasicBlock *To); + + void addPhiValues(BasicBlock *From, BasicBlock *To); + + void setPhiValues(); + + void killTerminator(BasicBlock *BB); + + void changeExit(RegionNode *Node, BasicBlock *NewExit, + bool IncludeDominator); + + BasicBlock *getNextFlow(BasicBlock *Dominator); + + BasicBlock *needPrefix(bool NeedEmpty); + + BasicBlock *needPostfix(BasicBlock *Flow, bool ExitUseAllowed); + + void setPrevNode(BasicBlock *BB); + + bool dominatesPredicates(BasicBlock *BB, RegionNode *Node); + + bool isPredictableTrue(RegionNode *Node); + + void wireFlow(bool ExitUseAllowed, BasicBlock *LoopEnd); + + void handleLoops(bool ExitUseAllowed, BasicBlock *LoopEnd); + + void createFlow(); + + void rebuildSSA(); + +public: + static char ID; + + explicit StructurizeCFG(bool SkipUniformRegions = false) + : RegionPass(ID), SkipUniformRegions(SkipUniformRegions) { + initializeStructurizeCFGPass(*PassRegistry::getPassRegistry()); + } + + bool doInitialization(Region *R, RGPassManager &RGM) override; + + bool runOnRegion(Region *R, RGPassManager &RGM) override; + + StringRef getPassName() const override { return "Structurize 
control flow"; } + + void getAnalysisUsage(AnalysisUsage &AU) const override { + if (SkipUniformRegions) + AU.addRequired<DivergenceAnalysis>(); + AU.addRequiredID(LowerSwitchID); + AU.addRequired<DominatorTreeWrapperPass>(); + AU.addRequired<LoopInfoWrapperPass>(); + + AU.addPreserved<DominatorTreeWrapperPass>(); + RegionPass::getAnalysisUsage(AU); + } +}; + +} // end anonymous namespace + +char StructurizeCFG::ID = 0; + +INITIALIZE_PASS_BEGIN(StructurizeCFG, "structurizecfg", "Structurize the CFG", + false, false) +INITIALIZE_PASS_DEPENDENCY(DivergenceAnalysis) +INITIALIZE_PASS_DEPENDENCY(LowerSwitch) +INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass) +INITIALIZE_PASS_DEPENDENCY(RegionInfoPass) +INITIALIZE_PASS_END(StructurizeCFG, "structurizecfg", "Structurize the CFG", + false, false) + +/// \brief Initialize the types and constants used in the pass +bool StructurizeCFG::doInitialization(Region *R, RGPassManager &RGM) { + LLVMContext &Context = R->getEntry()->getContext(); + + Boolean = Type::getInt1Ty(Context); + BoolTrue = ConstantInt::getTrue(Context); + BoolFalse = ConstantInt::getFalse(Context); + BoolUndef = UndefValue::get(Boolean); + + return false; +} + +/// \brief Build up the general order of nodes +void StructurizeCFG::orderNodes() { + ReversePostOrderTraversal<Region*> RPOT(ParentRegion); + SmallDenseMap<Loop*, unsigned, 8> LoopBlocks; + + // The reverse post-order traversal of the list gives us an ordering close + // to what we want. The only problem with it is that sometimes backedges + // for outer loops will be visited before backedges for inner loops. + for (RegionNode *RN : RPOT) { + BasicBlock *BB = RN->getEntry(); + Loop *Loop = LI->getLoopFor(BB); + ++LoopBlocks[Loop]; + } + + unsigned CurrentLoopDepth = 0; + Loop *CurrentLoop = nullptr; + for (auto I = RPOT.begin(), E = RPOT.end(); I != E; ++I) { + BasicBlock *BB = (*I)->getEntry(); + unsigned LoopDepth = LI->getLoopDepth(BB); + + if (is_contained(Order, *I)) + continue; + + if (LoopDepth < CurrentLoopDepth) { + // Make sure we have visited all blocks in this loop before moving back to + // the outer loop. + + auto LoopI = I; + while (unsigned &BlockCount = LoopBlocks[CurrentLoop]) { + LoopI++; + BasicBlock *LoopBB = (*LoopI)->getEntry(); + if (LI->getLoopFor(LoopBB) == CurrentLoop) { + --BlockCount; + Order.push_back(*LoopI); + } + } + } + + CurrentLoop = LI->getLoopFor(BB); + if (CurrentLoop) + LoopBlocks[CurrentLoop]--; + + CurrentLoopDepth = LoopDepth; + Order.push_back(*I); + } + + // This pass originally used a post-order traversal and then operated on + // the list in reverse. Now that we are using a reverse post-order traversal + // rather than re-working the whole pass to operate on the list in order, + // we just reverse the list and continue to operate on it in reverse. 
+ std::reverse(Order.begin(), Order.end()); +} + +/// \brief Determine the end of the loops +void StructurizeCFG::analyzeLoops(RegionNode *N) { + if (N->isSubRegion()) { + // Test for exit as back edge + BasicBlock *Exit = N->getNodeAs<Region>()->getExit(); + if (Visited.count(Exit)) + Loops[Exit] = N->getEntry(); + + } else { + // Test for successors as back edge + BasicBlock *BB = N->getNodeAs<BasicBlock>(); + BranchInst *Term = cast<BranchInst>(BB->getTerminator()); + + for (BasicBlock *Succ : Term->successors()) + if (Visited.count(Succ)) + Loops[Succ] = BB; + } +} + +/// \brief Invert the given condition +Value *StructurizeCFG::invert(Value *Condition) { + // First: Check if it's a constant + if (Constant *C = dyn_cast<Constant>(Condition)) + return ConstantExpr::getNot(C); + + // Second: If the condition is already inverted, return the original value + if (match(Condition, m_Not(m_Value(Condition)))) + return Condition; + + if (Instruction *Inst = dyn_cast<Instruction>(Condition)) { + // Third: Check all the users for an invert + BasicBlock *Parent = Inst->getParent(); + for (User *U : Condition->users()) + if (Instruction *I = dyn_cast<Instruction>(U)) + if (I->getParent() == Parent && match(I, m_Not(m_Specific(Condition)))) + return I; + + // Last option: Create a new instruction + return BinaryOperator::CreateNot(Condition, "", Parent->getTerminator()); + } + + if (Argument *Arg = dyn_cast<Argument>(Condition)) { + BasicBlock &EntryBlock = Arg->getParent()->getEntryBlock(); + return BinaryOperator::CreateNot(Condition, + Arg->getName() + ".inv", + EntryBlock.getTerminator()); + } + + llvm_unreachable("Unhandled condition to invert"); +} + +/// \brief Build the condition for one edge +Value *StructurizeCFG::buildCondition(BranchInst *Term, unsigned Idx, + bool Invert) { + Value *Cond = Invert ? 
BoolFalse : BoolTrue; + if (Term->isConditional()) { + Cond = Term->getCondition(); + + if (Idx != (unsigned)Invert) + Cond = invert(Cond); + } + return Cond; +} + +/// \brief Analyze the predecessors of each block and build up predicates +void StructurizeCFG::gatherPredicates(RegionNode *N) { + RegionInfo *RI = ParentRegion->getRegionInfo(); + BasicBlock *BB = N->getEntry(); + BBPredicates &Pred = Predicates[BB]; + BBPredicates &LPred = LoopPreds[BB]; + + for (BasicBlock *P : predecessors(BB)) { + // Ignore it if it's a branch from outside into our region entry + if (!ParentRegion->contains(P)) + continue; + + Region *R = RI->getRegionFor(P); + if (R == ParentRegion) { + // It's a top level block in our region + BranchInst *Term = cast<BranchInst>(P->getTerminator()); + for (unsigned i = 0, e = Term->getNumSuccessors(); i != e; ++i) { + BasicBlock *Succ = Term->getSuccessor(i); + if (Succ != BB) + continue; + + if (Visited.count(P)) { + // Normal forward edge + if (Term->isConditional()) { + // Try to treat it like an ELSE block + BasicBlock *Other = Term->getSuccessor(!i); + if (Visited.count(Other) && !Loops.count(Other) && + !Pred.count(Other) && !Pred.count(P)) { + + Pred[Other] = BoolFalse; + Pred[P] = BoolTrue; + continue; + } + } + Pred[P] = buildCondition(Term, i, false); + } else { + // Back edge + LPred[P] = buildCondition(Term, i, true); + } + } + } else { + // It's an exit from a sub region + while (R->getParent() != ParentRegion) + R = R->getParent(); + + // Edge from inside a subregion to its entry, ignore it + if (*R == *N) + continue; + + BasicBlock *Entry = R->getEntry(); + if (Visited.count(Entry)) + Pred[Entry] = BoolTrue; + else + LPred[Entry] = BoolFalse; + } + } +} + +/// \brief Collect various loop and predicate infos +void StructurizeCFG::collectInfos() { + // Reset predicate + Predicates.clear(); + + // and loop infos + Loops.clear(); + LoopPreds.clear(); + + // Reset the visited nodes + Visited.clear(); + + for (RegionNode *RN : reverse(Order)) { + DEBUG(dbgs() << "Visiting: " + << (RN->isSubRegion() ? "SubRegion with entry: " : "") + << RN->getEntry()->getName() << " Loop Depth: " + << LI->getLoopDepth(RN->getEntry()) << "\n"); + + // Analyze all the conditions leading to a node + gatherPredicates(RN); + + // Remember that we've seen this node + Visited.insert(RN->getEntry()); + + // Find the last back edges + analyzeLoops(RN); + } +} + +/// \brief Insert the missing branch conditions +void StructurizeCFG::insertConditions(bool Loops) { + BranchVector &Conds = Loops ? LoopConds : Conditions; + Value *Default = Loops ? BoolTrue : BoolFalse; + SSAUpdater PhiInserter; + + for (BranchInst *Term : Conds) { + assert(Term->isConditional()); + + BasicBlock *Parent = Term->getParent(); + BasicBlock *SuccTrue = Term->getSuccessor(0); + BasicBlock *SuccFalse = Term->getSuccessor(1); + + PhiInserter.Initialize(Boolean, ""); + PhiInserter.AddAvailableValue(&Func->getEntryBlock(), Default); + PhiInserter.AddAvailableValue(Loops ? SuccFalse : Parent, Default); + + BBPredicates &Preds = Loops ? 
LoopPreds[SuccFalse] : Predicates[SuccTrue]; + + NearestCommonDominator Dominator(DT); + Dominator.addBlock(Parent); + + Value *ParentValue = nullptr; + for (std::pair<BasicBlock *, Value *> BBAndPred : Preds) { + BasicBlock *BB = BBAndPred.first; + Value *Pred = BBAndPred.second; + + if (BB == Parent) { + ParentValue = Pred; + break; + } + PhiInserter.AddAvailableValue(BB, Pred); + Dominator.addAndRememberBlock(BB); + } + + if (ParentValue) { + Term->setCondition(ParentValue); + } else { + if (!Dominator.resultIsRememberedBlock()) + PhiInserter.AddAvailableValue(Dominator.result(), Default); + + Term->setCondition(PhiInserter.GetValueInMiddleOfBlock(Parent)); + } + } +} + +/// \brief Remove all PHI values coming from "From" into "To" and remember +/// them in DeletedPhis +void StructurizeCFG::delPhiValues(BasicBlock *From, BasicBlock *To) { + PhiMap &Map = DeletedPhis[To]; + for (PHINode &Phi : To->phis()) { + while (Phi.getBasicBlockIndex(From) != -1) { + Value *Deleted = Phi.removeIncomingValue(From, false); + Map[&Phi].push_back(std::make_pair(From, Deleted)); + } + } +} + +/// \brief Add a dummy PHI value as soon as we know the new predecessor +void StructurizeCFG::addPhiValues(BasicBlock *From, BasicBlock *To) { + for (PHINode &Phi : To->phis()) { + Value *Undef = UndefValue::get(Phi.getType()); + Phi.addIncoming(Undef, From); + } + AddedPhis[To].push_back(From); +} + +/// \brief Add the real PHI value as soon as everything is set up +void StructurizeCFG::setPhiValues() { + SSAUpdater Updater; + for (const auto &AddedPhi : AddedPhis) { + BasicBlock *To = AddedPhi.first; + const BBVector &From = AddedPhi.second; + + if (!DeletedPhis.count(To)) + continue; + + PhiMap &Map = DeletedPhis[To]; + for (const auto &PI : Map) { + PHINode *Phi = PI.first; + Value *Undef = UndefValue::get(Phi->getType()); + Updater.Initialize(Phi->getType(), ""); + Updater.AddAvailableValue(&Func->getEntryBlock(), Undef); + Updater.AddAvailableValue(To, Undef); + + NearestCommonDominator Dominator(DT); + Dominator.addBlock(To); + for (const auto &VI : PI.second) { + Updater.AddAvailableValue(VI.first, VI.second); + Dominator.addAndRememberBlock(VI.first); + } + + if (!Dominator.resultIsRememberedBlock()) + Updater.AddAvailableValue(Dominator.result(), Undef); + + for (BasicBlock *FI : From) { + int Idx = Phi->getBasicBlockIndex(FI); + assert(Idx != -1); + Phi->setIncomingValue(Idx, Updater.GetValueAtEndOfBlock(FI)); + } + } + + DeletedPhis.erase(To); + } + assert(DeletedPhis.empty()); +} + +/// \brief Remove phi values from all successors and then remove the terminator. +void StructurizeCFG::killTerminator(BasicBlock *BB) { + TerminatorInst *Term = BB->getTerminator(); + if (!Term) + return; + + for (succ_iterator SI = succ_begin(BB), SE = succ_end(BB); + SI != SE; ++SI) + delPhiValues(BB, *SI); + + Term->eraseFromParent(); +} + +/// \brief Let node exit(s) point to NewExit +void StructurizeCFG::changeExit(RegionNode *Node, BasicBlock *NewExit, + bool IncludeDominator) { + if (Node->isSubRegion()) { + Region *SubRegion = Node->getNodeAs<Region>(); + BasicBlock *OldExit = SubRegion->getExit(); + BasicBlock *Dominator = nullptr; + + // Find all the edges from the sub region to the exit + for (auto BBI = pred_begin(OldExit), E = pred_end(OldExit); BBI != E;) { + // Increment BBI before mucking with BB's terminator.
+ BasicBlock *BB = *BBI++; + + if (!SubRegion->contains(BB)) + continue; + + // Modify the edges to point to the new exit + delPhiValues(BB, OldExit); + BB->getTerminator()->replaceUsesOfWith(OldExit, NewExit); + addPhiValues(BB, NewExit); + + // Find the new dominator (if requested) + if (IncludeDominator) { + if (!Dominator) + Dominator = BB; + else + Dominator = DT->findNearestCommonDominator(Dominator, BB); + } + } + + // Change the dominator (if requested) + if (Dominator) + DT->changeImmediateDominator(NewExit, Dominator); + + // Update the region info + SubRegion->replaceExit(NewExit); + } else { + BasicBlock *BB = Node->getNodeAs<BasicBlock>(); + killTerminator(BB); + BranchInst::Create(NewExit, BB); + addPhiValues(BB, NewExit); + if (IncludeDominator) + DT->changeImmediateDominator(NewExit, BB); + } +} + +/// \brief Create a new flow node and update dominator tree and region info +BasicBlock *StructurizeCFG::getNextFlow(BasicBlock *Dominator) { + LLVMContext &Context = Func->getContext(); + BasicBlock *Insert = Order.empty() ? ParentRegion->getExit() : + Order.back()->getEntry(); + BasicBlock *Flow = BasicBlock::Create(Context, FlowBlockName, + Func, Insert); + DT->addNewBlock(Flow, Dominator); + ParentRegion->getRegionInfo()->setRegionFor(Flow, ParentRegion); + return Flow; +} + +/// \brief Create a new or reuse the previous node as flow node +BasicBlock *StructurizeCFG::needPrefix(bool NeedEmpty) { + BasicBlock *Entry = PrevNode->getEntry(); + + if (!PrevNode->isSubRegion()) { + killTerminator(Entry); + if (!NeedEmpty || Entry->getFirstInsertionPt() == Entry->end()) + return Entry; + } + + // create a new flow node + BasicBlock *Flow = getNextFlow(Entry); + + // and wire it up + changeExit(PrevNode, Flow, true); + PrevNode = ParentRegion->getBBNode(Flow); + return Flow; +} + +/// \brief Returns the region exit if possible, otherwise just a new flow node +BasicBlock *StructurizeCFG::needPostfix(BasicBlock *Flow, + bool ExitUseAllowed) { + if (!Order.empty() || !ExitUseAllowed) + return getNextFlow(Flow); + + BasicBlock *Exit = ParentRegion->getExit(); + DT->changeImmediateDominator(Exit, Flow); + addPhiValues(Flow, Exit); + return Exit; +} + +/// \brief Set the previous node +void StructurizeCFG::setPrevNode(BasicBlock *BB) { + PrevNode = ParentRegion->contains(BB) ? ParentRegion->getBBNode(BB) + : nullptr; +} + +/// \brief Does BB dominate all the predicates of Node? +bool StructurizeCFG::dominatesPredicates(BasicBlock *BB, RegionNode *Node) { + BBPredicates &Preds = Predicates[Node->getEntry()]; + return llvm::all_of(Preds, [&](std::pair<BasicBlock *, Value *> Pred) { + return DT->dominates(BB, Pred.first); + }); +} + +/// \brief Can we predict that this node will always be called? 
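+/// +/// (Editor's paraphrase of the implementation below:) a node is predictably +/// reached when every predicate feeding it is BoolTrue and at least one of +/// the predicating blocks dominates the previous node, so no conditional +/// Flow branch needs to be materialized for it.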
+bool StructurizeCFG::isPredictableTrue(RegionNode *Node) { + BBPredicates &Preds = Predicates[Node->getEntry()]; + bool Dominated = false; + + // Region entry is always true + if (!PrevNode) + return true; + + for (std::pair<BasicBlock*, Value*> Pred : Preds) { + BasicBlock *BB = Pred.first; + Value *V = Pred.second; + + if (V != BoolTrue) + return false; + + if (!Dominated && DT->dominates(BB, PrevNode->getEntry())) + Dominated = true; + } + + // TODO: The dominator check is too strict + return Dominated; +} + +/// Take one node from the order vector and wire it up +void StructurizeCFG::wireFlow(bool ExitUseAllowed, + BasicBlock *LoopEnd) { + RegionNode *Node = Order.pop_back_val(); + Visited.insert(Node->getEntry()); + + if (isPredictableTrue(Node)) { + // Just a linear flow + if (PrevNode) { + changeExit(PrevNode, Node->getEntry(), true); + } + PrevNode = Node; + } else { + // Insert extra prefix node (or reuse last one) + BasicBlock *Flow = needPrefix(false); + + // Insert extra postfix node (or use exit instead) + BasicBlock *Entry = Node->getEntry(); + BasicBlock *Next = needPostfix(Flow, ExitUseAllowed); + + // let it point to entry and next block + Conditions.push_back(BranchInst::Create(Entry, Next, BoolUndef, Flow)); + addPhiValues(Flow, Entry); + DT->changeImmediateDominator(Entry, Flow); + + PrevNode = Node; + while (!Order.empty() && !Visited.count(LoopEnd) && + dominatesPredicates(Entry, Order.back())) { + handleLoops(false, LoopEnd); + } + + changeExit(PrevNode, Next, false); + setPrevNode(Next); + } +} + +void StructurizeCFG::handleLoops(bool ExitUseAllowed, + BasicBlock *LoopEnd) { + RegionNode *Node = Order.back(); + BasicBlock *LoopStart = Node->getEntry(); + + if (!Loops.count(LoopStart)) { + wireFlow(ExitUseAllowed, LoopEnd); + return; + } + + if (!isPredictableTrue(Node)) + LoopStart = needPrefix(true); + + LoopEnd = Loops[Node->getEntry()]; + wireFlow(false, LoopEnd); + while (!Visited.count(LoopEnd)) { + handleLoops(false, LoopEnd); + } + + // If the start of the loop is the entry block, we can't branch to it so + // insert a new dummy entry block. + Function *LoopFunc = LoopStart->getParent(); + if (LoopStart == &LoopFunc->getEntryBlock()) { + LoopStart->setName("entry.orig"); + + BasicBlock *NewEntry = + BasicBlock::Create(LoopStart->getContext(), + "entry", + LoopFunc, + LoopStart); + BranchInst::Create(LoopStart, NewEntry); + DT->setNewRoot(NewEntry); + } + + // Create an extra loop end node + LoopEnd = needPrefix(false); + BasicBlock *Next = needPostfix(LoopEnd, ExitUseAllowed); + LoopConds.push_back(BranchInst::Create(Next, LoopStart, + BoolUndef, LoopEnd)); + addPhiValues(LoopEnd, LoopStart); + setPrevNode(Next); +} + +/// After this function runs, control flow looks like it should, but +/// branches and PHI nodes only have undefined conditions. +void StructurizeCFG::createFlow() { + BasicBlock *Exit = ParentRegion->getExit(); + bool EntryDominatesExit = DT->dominates(ParentRegion->getEntry(), Exit); + + DeletedPhis.clear(); + AddedPhis.clear(); + Conditions.clear(); + LoopConds.clear(); + + PrevNode = nullptr; + Visited.clear(); + + while (!Order.empty()) { + handleLoops(EntryDominatesExit, nullptr); + } + + if (PrevNode) + changeExit(PrevNode, Exit, EntryDominatesExit); + else + assert(EntryDominatesExit); +} + +/// Handle a rare case where the disintegrated nodes' instructions +/// no longer dominate all their uses. 
Not sure if this is really necessary. +void StructurizeCFG::rebuildSSA() { + SSAUpdater Updater; + for (BasicBlock *BB : ParentRegion->blocks()) + for (Instruction &I : *BB) { + bool Initialized = false; + // We may modify the use list as we iterate over it, so be careful to + // compute the next element in the use list at the top of the loop. + for (auto UI = I.use_begin(), E = I.use_end(); UI != E;) { + Use &U = *UI++; + Instruction *User = cast<Instruction>(U.getUser()); + if (User->getParent() == BB) { + continue; + } else if (PHINode *UserPN = dyn_cast<PHINode>(User)) { + if (UserPN->getIncomingBlock(U) == BB) + continue; + } + + if (DT->dominates(&I, User)) + continue; + + if (!Initialized) { + Value *Undef = UndefValue::get(I.getType()); + Updater.Initialize(I.getType(), ""); + Updater.AddAvailableValue(&Func->getEntryBlock(), Undef); + Updater.AddAvailableValue(BB, &I); + Initialized = true; + } + Updater.RewriteUseAfterInsertions(U); + } + } +} + +static bool hasOnlyUniformBranches(const Region *R, + const DivergenceAnalysis &DA) { + for (const BasicBlock *BB : R->blocks()) { + const BranchInst *Br = dyn_cast<BranchInst>(BB->getTerminator()); + if (!Br || !Br->isConditional()) + continue; + + if (!DA.isUniform(Br->getCondition())) + return false; + DEBUG(dbgs() << "BB: " << BB->getName() << " has uniform terminator\n"); + } + return true; +} + +/// \brief Run the transformation for each region found +bool StructurizeCFG::runOnRegion(Region *R, RGPassManager &RGM) { + if (R->isTopLevelRegion()) + return false; + + if (SkipUniformRegions) { + // TODO: We could probably be smarter here with how we handle sub-regions. + auto &DA = getAnalysis<DivergenceAnalysis>(); + if (hasOnlyUniformBranches(R, DA)) { + DEBUG(dbgs() << "Skipping region with uniform control flow: " << *R << '\n'); + + // Mark all direct child block terminators as having been treated as + // uniform. To account for a possible future in which non-uniform + // sub-regions are treated more cleverly, indirect children are not + // marked as uniform. + MDNode *MD = MDNode::get(R->getEntry()->getParent()->getContext(), {}); + for (RegionNode *E : R->elements()) { + if (E->isSubRegion()) + continue; + + if (Instruction *Term = E->getEntry()->getTerminator()) + Term->setMetadata("structurizecfg.uniform", MD); + } + + return false; + } + } + + Func = R->getEntry()->getParent(); + ParentRegion = R; + + DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); + LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); + + orderNodes(); + collectInfos(); + createFlow(); + insertConditions(false); + insertConditions(true); + setPhiValues(); + rebuildSSA(); + + // Cleanup + Order.clear(); + Visited.clear(); + DeletedPhis.clear(); + AddedPhis.clear(); + Predicates.clear(); + Conditions.clear(); + Loops.clear(); + LoopPreds.clear(); + LoopConds.clear(); + + return true; +} + +Pass *llvm::createStructurizeCFGPass(bool SkipUniformRegions) { + return new StructurizeCFG(SkipUniformRegions); +} diff --git a/contrib/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp b/contrib/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp new file mode 100644 index 000000000000..2a1106b41de2 --- /dev/null +++ b/contrib/llvm/lib/Transforms/Scalar/TailRecursionElimination.cpp @@ -0,0 +1,853 @@ +//===- TailRecursionElimination.cpp - Eliminate Tail Calls ----------------===// +// +// The LLVM Compiler Infrastructure +// +// This file is distributed under the University of Illinois Open Source +// License. See LICENSE.TXT for details.
+// +//===----------------------------------------------------------------------===// +// +// This file transforms calls of the current function (self recursion) followed +// by a return instruction with a branch to the entry of the function, creating +// a loop. This pass also implements the following extensions to the basic +// algorithm: +// +// 1. Trivial instructions between the call and return do not prevent the +// transformation from taking place, though currently the analysis cannot +// support moving any really useful instructions (only dead ones). +// 2. This pass transforms functions that are prevented from being tail +// recursive by an associative and commutative expression to use an +// accumulator variable, thus compiling the typical naive factorial or +// 'fib' implementation into efficient code (see the sketch below). +// 3. TRE is performed if the function returns void, if the return +// returns the result returned by the call, or if the function returns a +// run-time constant on all exits from the function. It is possible, though +// unlikely, that the return returns something else (like constant 0), and +// can still be TRE'd. It can be TRE'd if ALL OTHER return instructions in +// the function return the exact same value. +// 4. If it can prove that callees do not access their caller stack frame, +// they are marked as eligible for tail call elimination (by the code +// generator). +// +// There are several improvements that could be made: +// +// 1. If the function has any alloca instructions, these instructions will be +// moved out of the entry block of the function, causing them to be +// evaluated each time through the tail recursion. Safely keeping allocas +// in the entry block requires analysis to prove that the tail-called +// function does not read or write the stack object. +// 2. Tail recursion is only performed if the call immediately precedes the +// return instruction. It's possible that there could be a jump between +// the call and the return. +// 3. There can be intervening operations between the call and the return that +// prevent the TRE from occurring. For example, there could be GEP's and +// stores to memory that will not be read or written by the call. This +// requires some substantial analysis (such as with DSA) to prove safe to +// move ahead of the call, but doing so could allow many more TREs to be +// performed, for example in TreeAdd/TreeAlloc from the treeadd benchmark. +// 4. The algorithm we use to detect if callees access their caller stack +// frames is very primitive.
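+// +// As an illustrative sketch of extension 2 above (editor's example, not from +// the source tree), the accumulator transformation turns +// +// int fact(int n) { if (n <= 1) return 1; return n * fact(n - 1); } +// +// into the equivalent of +// +// int fact(int n) { +// int accumulator = 1; // seeded from the common return value +// for (; n > 1; --n) // the former recursion, now a loop +// accumulator *= n; +// return accumulator; +// } +// +// where the multiply that blocked the tail call feeds the accumulator instead.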
+// +//===----------------------------------------------------------------------===// + +#include "llvm/Transforms/Scalar/TailRecursionElimination.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallPtrSet.h" +#include "llvm/ADT/Statistic.h" +#include "llvm/Analysis/CFG.h" +#include "llvm/Analysis/CaptureTracking.h" +#include "llvm/Analysis/GlobalsModRef.h" +#include "llvm/Analysis/InlineCost.h" +#include "llvm/Analysis/InstructionSimplify.h" +#include "llvm/Analysis/Loads.h" +#include "llvm/Analysis/OptimizationRemarkEmitter.h" +#include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/CFG.h" +#include "llvm/IR/CallSite.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/DataLayout.h" +#include "llvm/IR/DerivedTypes.h" +#include "llvm/IR/DiagnosticInfo.h" +#include "llvm/IR/Function.h" +#include "llvm/IR/InstIterator.h" +#include "llvm/IR/Instructions.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Module.h" +#include "llvm/IR/ValueHandle.h" +#include "llvm/Pass.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" +using namespace llvm; + +#define DEBUG_TYPE "tailcallelim" + +STATISTIC(NumEliminated, "Number of tail calls removed"); +STATISTIC(NumRetDuped, "Number of return duplicated"); +STATISTIC(NumAccumAdded, "Number of accumulators introduced"); + +/// \brief Scan the specified function for alloca instructions. +/// If it contains any dynamic allocas, returns false. +static bool canTRE(Function &F) { + // Because of PR962, we don't TRE dynamic allocas. + return llvm::all_of(instructions(F), [](Instruction &I) { + auto *AI = dyn_cast<AllocaInst>(&I); + return !AI || AI->isStaticAlloca(); + }); +} + +namespace { +struct AllocaDerivedValueTracker { + // Start at a root value and walk its use-def chain to mark calls that use the + // value or a derived value in AllocaUsers, and places where it may escape in + // EscapePoints. + void walk(Value *Root) { + SmallVector<Use *, 32> Worklist; + SmallPtrSet<Use *, 32> Visited; + + auto AddUsesToWorklist = [&](Value *V) { + for (auto &U : V->uses()) { + if (!Visited.insert(&U).second) + continue; + Worklist.push_back(&U); + } + }; + + AddUsesToWorklist(Root); + + while (!Worklist.empty()) { + Use *U = Worklist.pop_back_val(); + Instruction *I = cast<Instruction>(U->getUser()); + + switch (I->getOpcode()) { + case Instruction::Call: + case Instruction::Invoke: { + CallSite CS(I); + bool IsNocapture = + CS.isDataOperand(U) && CS.doesNotCapture(CS.getDataOperandNo(U)); + callUsesLocalStack(CS, IsNocapture); + if (IsNocapture) { + // If the alloca-derived argument is passed in as nocapture, then it + // can't propagate to the call's return. That would be capturing. + continue; + } + break; + } + case Instruction::Load: { + // The result of a load is not alloca-derived (unless an alloca has + // otherwise escaped, but this is a local analysis). + continue; + } + case Instruction::Store: { + if (U->getOperandNo() == 0) + EscapePoints.insert(I); + continue; // Stores have no users to analyze. + } + case Instruction::BitCast: + case Instruction::GetElementPtr: + case Instruction::PHI: + case Instruction::Select: + case Instruction::AddrSpaceCast: + break; + default: + EscapePoints.insert(I); + break; + } + + AddUsesToWorklist(I); + } + } + + void callUsesLocalStack(CallSite CS, bool IsNocapture) { + // Add it to the list of alloca users. 
+ AllocaUsers.insert(CS.getInstruction()); + + // If it's nocapture then it can't capture this alloca. + if (IsNocapture) + return; + + // If it can write to memory, it can leak the alloca value. + if (!CS.onlyReadsMemory()) + EscapePoints.insert(CS.getInstruction()); + } + + SmallPtrSet<Instruction *, 32> AllocaUsers; + SmallPtrSet<Instruction *, 32> EscapePoints; +}; +} + +static bool markTails(Function &F, bool &AllCallsAreTailCalls, + OptimizationRemarkEmitter *ORE) { + if (F.callsFunctionThatReturnsTwice()) + return false; + AllCallsAreTailCalls = true; + + // The local stack holds all alloca instructions and all byval arguments. + AllocaDerivedValueTracker Tracker; + for (Argument &Arg : F.args()) { + if (Arg.hasByValAttr()) + Tracker.walk(&Arg); + } + for (auto &BB : F) { + for (auto &I : BB) + if (AllocaInst *AI = dyn_cast<AllocaInst>(&I)) + Tracker.walk(AI); + } + + bool Modified = false; + + // Track whether a block is reachable after an alloca has escaped. Blocks that + // contain the escaping instruction will be marked as being visited without an + // escaped alloca, since that is how the block began. + enum VisitType { + UNVISITED, + UNESCAPED, + ESCAPED + }; + DenseMap<BasicBlock *, VisitType> Visited; + + // We propagate the fact that an alloca has escaped from block to successor. + // Visit the blocks that are propagating the escapedness first. To do this, we + // maintain two worklists. + SmallVector<BasicBlock *, 32> WorklistUnescaped, WorklistEscaped; + + // We may enter a block and visit it thinking that no alloca has escaped yet, + // then see an escape point and go back around a loop edge and come back to + // the same block twice. Because of this, we defer setting tail on calls when + // we first encounter them in a block. Every entry in this list does not + // statically use an alloca via use-def chain analysis, but may find an alloca + // through other means if the block turns out to be reachable after an escape + // point. + SmallVector<CallInst *, 32> DeferredTails; + + BasicBlock *BB = &F.getEntryBlock(); + VisitType Escaped = UNESCAPED; + do { + for (auto &I : *BB) { + if (Tracker.EscapePoints.count(&I)) + Escaped = ESCAPED; + + CallInst *CI = dyn_cast<CallInst>(&I); + if (!CI || CI->isTailCall() || isa<DbgInfoIntrinsic>(&I)) + continue; + + bool IsNoTail = CI->isNoTailCall() || CI->hasOperandBundles(); + + if (!IsNoTail && CI->doesNotAccessMemory()) { + // A call to a readnone function whose arguments are all things computed + // outside this function can be marked tail. Even if you stored the + // alloca address into a global, a readnone function can't load the + // global anyhow. + // + // Note that this runs whether we know an alloca has escaped or not. If + // it has, then we can't trust Tracker.AllocaUsers to be accurate. 
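+ // + // Editor's sketch (illustrative): given + // %a = alloca i32 + // call void @capture(i32* %a) ; the alloca escapes here + // %r = call i32 @pure(i32 %x) ; readnone, %x not alloca-derived + // @pure may still be marked tail: even though %a has escaped, a readnone + // callee cannot read it back through memory, so reusing the caller's stack + // frame is safe.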
+ bool SafeToTail = true; + for (auto &Arg : CI->arg_operands()) { + if (isa<Constant>(Arg.getUser())) + continue; + if (Argument *A = dyn_cast<Argument>(Arg.getUser())) + if (!A->hasByValAttr()) + continue; + SafeToTail = false; + break; + } + if (SafeToTail) { + using namespace ore; + ORE->emit([&]() { + return OptimizationRemark(DEBUG_TYPE, "tailcall-readnone", CI) + << "marked as tail call candidate (readnone)"; + }); + CI->setTailCall(); + Modified = true; + continue; + } + } + + if (!IsNoTail && Escaped == UNESCAPED && !Tracker.AllocaUsers.count(CI)) { + DeferredTails.push_back(CI); + } else { + AllCallsAreTailCalls = false; + } + } + + for (auto *SuccBB : make_range(succ_begin(BB), succ_end(BB))) { + auto &State = Visited[SuccBB]; + if (State < Escaped) { + State = Escaped; + if (State == ESCAPED) + WorklistEscaped.push_back(SuccBB); + else + WorklistUnescaped.push_back(SuccBB); + } + } + + if (!WorklistEscaped.empty()) { + BB = WorklistEscaped.pop_back_val(); + Escaped = ESCAPED; + } else { + BB = nullptr; + while (!WorklistUnescaped.empty()) { + auto *NextBB = WorklistUnescaped.pop_back_val(); + if (Visited[NextBB] == UNESCAPED) { + BB = NextBB; + Escaped = UNESCAPED; + break; + } + } + } + } while (BB); + + for (CallInst *CI : DeferredTails) { + if (Visited[CI->getParent()] != ESCAPED) { + // If the escape point was part way through the block, calls after the + // escape point wouldn't have been put into DeferredTails. + DEBUG(dbgs() << "Marked as tail call candidate: " << *CI << "\n"); + CI->setTailCall(); + Modified = true; + } else { + AllCallsAreTailCalls = false; + } + } + + return Modified; +} + +/// Return true if it is safe to move the specified +/// instruction from after the call to before the call, assuming that all +/// instructions between the call and this instruction are movable. +/// +static bool canMoveAboveCall(Instruction *I, CallInst *CI, AliasAnalysis *AA) { + // FIXME: We can move load/store/call/free instructions above the call if the + // call does not mod/ref the memory location being processed. + if (I->mayHaveSideEffects()) // This also handles volatile loads. + return false; + + if (LoadInst *L = dyn_cast<LoadInst>(I)) { + // Loads may always be moved above calls without side effects. + if (CI->mayHaveSideEffects()) { + // Non-volatile loads may be moved above a call with side effects if it + // does not write to memory and the load provably won't trap. + // Writes to memory only matter if they may alias the pointer + // being loaded from. + const DataLayout &DL = L->getModule()->getDataLayout(); + if (isModSet(AA->getModRefInfo(CI, MemoryLocation::get(L))) || + !isSafeToLoadUnconditionally(L->getPointerOperand(), + L->getAlignment(), DL, L)) + return false; + } + } + + // Otherwise, if this is a side-effect free instruction, check to make sure + // that it does not use the return value of the call. If it doesn't use the + // return value of the call, it must only use things that are defined before + // the call, or movable instructions between the call and the instruction + // itself. + return !is_contained(I->operands(), CI); +} + +/// Return true if the specified value is the same when the return would exit +/// as it was when the initial iteration of the recursive function was executed. +/// +/// We currently handle static constants and arguments that are not modified as +/// part of the recursion. 
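+/// +/// For example (editor's illustration): in f(int n, int step), 'step' is +/// dynamically constant if every recursive call passes it through unchanged +/// (the CI->getArgOperand(ArgNo) == Arg check below), even though it is not +/// a compile-time constant.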
+static bool isDynamicConstant(Value *V, CallInst *CI, ReturnInst *RI) { + if (isa<Constant>(V)) return true; // Static constants are always dyn consts + + // Check to see if this is an immutable argument; if so, the value + // will be available to initialize the accumulator. + if (Argument *Arg = dyn_cast<Argument>(V)) { + // Figure out which argument number this is... + unsigned ArgNo = 0; + Function *F = CI->getParent()->getParent(); + for (Function::arg_iterator AI = F->arg_begin(); &*AI != Arg; ++AI) + ++ArgNo; + + // If we are passing this argument into the call as the corresponding + // argument operand, then the argument is dynamically constant. + // Otherwise, we cannot transform this function safely. + if (CI->getArgOperand(ArgNo) == Arg) + return true; + } + + // Switch cases are always constant integers. If the value is being switched + // on and the return is only reachable from one of its cases, it's + // effectively constant. + if (BasicBlock *UniquePred = RI->getParent()->getUniquePredecessor()) + if (SwitchInst *SI = dyn_cast<SwitchInst>(UniquePred->getTerminator())) + if (SI->getCondition() == V) + return SI->getDefaultDest() != RI->getParent(); + + // Not a constant or immutable argument, we can't safely transform. + return false; +} + +/// Check to see if the function containing the specified tail call consistently +/// returns the same runtime-constant value at all exit points except for +/// IgnoreRI. If so, return the returned value. +static Value *getCommonReturnValue(ReturnInst *IgnoreRI, CallInst *CI) { + Function *F = CI->getParent()->getParent(); + Value *ReturnedValue = nullptr; + + for (BasicBlock &BBI : *F) { + ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator()); + if (RI == nullptr || RI == IgnoreRI) continue; + + // We can only perform this transformation if the value returned is + // evaluatable at the start of the initial invocation of the function, + // instead of at the end of the evaluation. + // + Value *RetOp = RI->getOperand(0); + if (!isDynamicConstant(RetOp, CI, RI)) + return nullptr; + + if (ReturnedValue && RetOp != ReturnedValue) + return nullptr; // Cannot transform if differing values are returned. + ReturnedValue = RetOp; + } + return ReturnedValue; +} + +/// If the specified instruction can be transformed using accumulator recursion +/// elimination, return the constant which is the start of the accumulator +/// value. Otherwise return null. +static Value *canTransformAccumulatorRecursion(Instruction *I, CallInst *CI) { + if (!I->isAssociative() || !I->isCommutative()) return nullptr; + assert(I->getNumOperands() == 2 && + "Associative/commutative operations should have 2 args!"); + + // Exactly one operand should be the result of the call instruction. + if ((I->getOperand(0) == CI && I->getOperand(1) == CI) || + (I->getOperand(0) != CI && I->getOperand(1) != CI)) + return nullptr; + + // The only user of this instruction we allow is a single return instruction. + if (!I->hasOneUse() || !isa<ReturnInst>(I->user_back())) + return nullptr; + + // Ok, now we have to check all of the other return instructions in this + // function. If they return non-constants or differing values, then we cannot + // transform the function safely.
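+ // + // Concretely (editor's sketch): in + // %r = call i32 @f(i32 %m) ; CI + // %v = mul i32 %r, %n ; I: associative, commutative, single use + // ret i32 %v + // the mul qualifies, and the common value returned by all other exits + // (e.g. the base case's constant 1) becomes the accumulator's start value.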
+ return getCommonReturnValue(cast<ReturnInst>(I->user_back()), CI); +} + +static Instruction *firstNonDbg(BasicBlock::iterator I) { + while (isa<DbgInfoIntrinsic>(I)) + ++I; + return &*I; +} + +static CallInst *findTRECandidate(Instruction *TI, + bool CannotTailCallElimCallsMarkedTail, + const TargetTransformInfo *TTI) { + BasicBlock *BB = TI->getParent(); + Function *F = BB->getParent(); + + if (&BB->front() == TI) // Make sure there is something before the terminator. + return nullptr; + + // Scan backwards from the return, checking to see if there is a tail call in + // this block. If so, set CI to it. + CallInst *CI = nullptr; + BasicBlock::iterator BBI(TI); + while (true) { + CI = dyn_cast<CallInst>(BBI); + if (CI && CI->getCalledFunction() == F) + break; + + if (BBI == BB->begin()) + return nullptr; // Didn't find a potential tail call. + --BBI; + } + + // If this call is marked as a tail call, and if there are dynamic allocas in + // the function, we cannot perform this optimization. + if (CI->isTailCall() && CannotTailCallElimCallsMarkedTail) + return nullptr; + + // As a special case, detect code like this: + // double fabs(double f) { return __builtin_fabs(f); } // a 'fabs' call + // and disable this xform in this case, because the code generator will + // lower the call to fabs into inline code. + if (BB == &F->getEntryBlock() && + firstNonDbg(BB->front().getIterator()) == CI && + firstNonDbg(std::next(BB->begin())) == TI && CI->getCalledFunction() && + !TTI->isLoweredToCall(CI->getCalledFunction())) { + // A single-block function with just a call and a return. Check that + // the arguments match. + CallSite::arg_iterator I = CallSite(CI).arg_begin(), + E = CallSite(CI).arg_end(); + Function::arg_iterator FI = F->arg_begin(), + FE = F->arg_end(); + for (; I != E && FI != FE; ++I, ++FI) + if (*I != &*FI) break; + if (I == E && FI == FE) + return nullptr; + } + + return CI; +} + +static bool eliminateRecursiveTailCall(CallInst *CI, ReturnInst *Ret, + BasicBlock *&OldEntry, + bool &TailCallsAreMarkedTail, + SmallVectorImpl<PHINode *> &ArgumentPHIs, + AliasAnalysis *AA, + OptimizationRemarkEmitter *ORE) { + // If we are introducing accumulator recursion to eliminate operations after + // the call instruction that are both associative and commutative, the initial + // value for the accumulator is placed in this variable. If this value is set + // then we actually perform accumulator recursion elimination instead of + // simple tail recursion elimination. If the operation is an LLVM instruction + // (eg: "add") then it is recorded in AccumulatorRecursionInstr. If not, then + // we are handling the case when the return instruction returns a constant C + // which is different to the constant returned by other return instructions + // (which is recorded in AccumulatorRecursionEliminationInitVal). This is a + // special case of accumulator recursion, the operation being "return C". + Value *AccumulatorRecursionEliminationInitVal = nullptr; + Instruction *AccumulatorRecursionInstr = nullptr; + + // Ok, we found a potential tail call. We can currently only transform the + // tail call if all of the instructions between the call and the return are + // movable to above the call itself, leaving the call next to the return. + // Check that this is the case now. 
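+ // + // For instance (editor's note): an intervening instruction such as + // %r = call i32 @f(i32 %n) + // %d = add i32 %x, %y ; independent of %r + // ret i32 %r + // can be moved above the call, while an accumulating instruction like + // %v = mul i32 %r, %n is instead handled by the accumulator-recursion path + // below.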
+  BasicBlock::iterator BBI(CI);
+  for (++BBI; &*BBI != Ret; ++BBI) {
+    if (canMoveAboveCall(&*BBI, CI, AA))
+      continue;
+
+    // If we can't move the instruction above the call, it might be because it
+    // is an associative and commutative operation that could be transformed
+    // using accumulator recursion elimination. Check to see if this is the
+    // case, and if so, remember the initial accumulator value for later.
+    if ((AccumulatorRecursionEliminationInitVal =
+             canTransformAccumulatorRecursion(&*BBI, CI))) {
+      // Yes, this is accumulator recursion. Remember which instruction
+      // accumulates.
+      AccumulatorRecursionInstr = &*BBI;
+    } else {
+      return false; // Otherwise, we cannot eliminate the tail recursion!
+    }
+  }
+
+  // We can only transform call/return pairs that either ignore the return
+  // value of the call and return void, ignore the value of the call and
+  // return a constant, return the value returned by the tail call, or that
+  // are eliminated via accumulator recursion.
+  if (Ret->getNumOperands() == 1 && Ret->getReturnValue() != CI &&
+      !isa<UndefValue>(Ret->getReturnValue()) &&
+      AccumulatorRecursionEliminationInitVal == nullptr &&
+      !getCommonReturnValue(nullptr, CI)) {
+    // One case remains that we are able to handle: the current return
+    // instruction returns a constant, and all other return instructions
+    // return a different constant.
+    if (!isDynamicConstant(Ret->getReturnValue(), CI, Ret))
+      return false; // Current return instruction does not return a constant.
+    // Check that all other return instructions return a common constant. If
+    // so, record it in AccumulatorRecursionEliminationInitVal.
+    AccumulatorRecursionEliminationInitVal = getCommonReturnValue(Ret, CI);
+    if (!AccumulatorRecursionEliminationInitVal)
+      return false;
+  }
+
+  BasicBlock *BB = Ret->getParent();
+  Function *F = BB->getParent();
+
+  using namespace ore;
+  ORE->emit([&]() {
+    return OptimizationRemark(DEBUG_TYPE, "tailcall-recursion", CI)
+           << "transforming tail recursion into loop";
+  });
+
+  // OK! We can transform this tail call. If this is the first one found,
+  // create the new entry block, allowing us to branch back to the old entry.
+  if (!OldEntry) {
+    OldEntry = &F->getEntryBlock();
+    BasicBlock *NewEntry = BasicBlock::Create(F->getContext(), "", F, OldEntry);
+    NewEntry->takeName(OldEntry);
+    OldEntry->setName("tailrecurse");
+    BranchInst::Create(OldEntry, NewEntry);
+
+    // If this tail call is marked 'tail' and if there are any allocas in the
+    // entry block, move them up to the new entry block.
+    TailCallsAreMarkedTail = CI->isTailCall();
+    if (TailCallsAreMarkedTail)
+      // Move all fixed-sized allocas from OldEntry to NewEntry.
+      for (BasicBlock::iterator OEBI = OldEntry->begin(), E = OldEntry->end(),
+                                NEBI = NewEntry->begin(); OEBI != E; )
+        if (AllocaInst *AI = dyn_cast<AllocaInst>(OEBI++))
+          if (isa<ConstantInt>(AI->getArraySize()))
+            AI->moveBefore(&*NEBI);
+
+    // Now that we have created a new block, which jumps to the entry
+    // block, insert a PHI node for each argument of the function.
+    // For now, we initialize each PHI to only have the real arguments
+    // which are passed in.
+    Instruction *InsertPos = &OldEntry->front();
+    for (Function::arg_iterator I = F->arg_begin(), E = F->arg_end();
+         I != E; ++I) {
+      PHINode *PN = PHINode::Create(I->getType(), 2,
+                                    I->getName() + ".tr", InsertPos);
+      I->replaceAllUsesWith(PN); // Every use now goes through the PHI node.
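+      // Ordering note: the real argument is wired in only after the
+      // replaceAllUsesWith() above; adding the incoming value first would
+      // let RAUW rewrite the PHI's own incoming operand to point at the PHI
+      // itself.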
+      PN->addIncoming(&*I, NewEntry);
+      ArgumentPHIs.push_back(PN);
+    }
+  }
+
+  // If this function has self-recursive calls in the tail position where some
+  // are marked tail and some are not, only transform one flavor or another.
+  // We have to choose whether we move allocas in the entry block to the new
+  // entry block or not, so we can't make a good choice for both. NOTE: We
+  // could do slightly better here in the case that the function has no entry
+  // block allocas.
+  if (TailCallsAreMarkedTail && !CI->isTailCall())
+    return false;
+
+  // Ok, now that we know we have a pseudo-entry block WITH all of the
+  // required PHI nodes, add entries into the PHI node for the actual
+  // parameters passed into the tail-recursive call.
+  for (unsigned i = 0, e = CI->getNumArgOperands(); i != e; ++i)
+    ArgumentPHIs[i]->addIncoming(CI->getArgOperand(i), BB);
+
+  // If we are introducing an accumulator variable to eliminate the recursion,
+  // do so now. Note that we _know_ that no subsequent tail recursion
+  // eliminations will happen on this function because of the way the
+  // accumulator recursion predicate is set up.
+  //
+  if (AccumulatorRecursionEliminationInitVal) {
+    Instruction *AccRecInstr = AccumulatorRecursionInstr;
+    // Start by inserting a new PHI node for the accumulator.
+    pred_iterator PB = pred_begin(OldEntry), PE = pred_end(OldEntry);
+    PHINode *AccPN = PHINode::Create(
+        AccumulatorRecursionEliminationInitVal->getType(),
+        std::distance(PB, PE) + 1, "accumulator.tr", &OldEntry->front());
+
+    // Loop over all of the predecessors of the tail recursion block. For the
+    // real entry into the function we seed the PHI with the initial value,
+    // computed earlier. For any other existing branches to this block (due to
+    // other tail recursions eliminated) the accumulator is not modified.
+    // Because we haven't added the branch in the current block to OldEntry
+    // yet, it will not show up as a predecessor.
+    for (pred_iterator PI = PB; PI != PE; ++PI) {
+      BasicBlock *P = *PI;
+      if (P == &F->getEntryBlock())
+        AccPN->addIncoming(AccumulatorRecursionEliminationInitVal, P);
+      else
+        AccPN->addIncoming(AccPN, P);
+    }
+
+    if (AccRecInstr) {
+      // Add an incoming argument for the current block, which is computed by
+      // our associative and commutative accumulator instruction.
+      AccPN->addIncoming(AccRecInstr, BB);
+
+      // Next, rewrite the accumulator recursion instruction so that it no
+      // longer uses the result of the call; instead, it uses the PHI node we
+      // just inserted.
+      AccRecInstr->setOperand(AccRecInstr->getOperand(0) != CI, AccPN);
+    } else {
+      // Add an incoming argument for the current block, which is just the
+      // constant returned by the current return instruction.
+      AccPN->addIncoming(Ret->getReturnValue(), BB);
+    }
+
+    // Finally, rewrite any return instructions in the program to return the
+    // PHI node instead of the "initval" they currently return. This loop will
+    // actually rewrite the return value we are destroying, but that's ok.
+    for (BasicBlock &BBI : *F)
+      if (ReturnInst *RI = dyn_cast<ReturnInst>(BBI.getTerminator()))
+        RI->setOperand(0, AccPN);
+    ++NumAccumAdded;
+  }
+
+  // Now that all of the PHI nodes are in place, remove the call and
+  // ret instructions, replacing them with an unconditional branch.
+  BranchInst *NewBI = BranchInst::Create(OldEntry, Ret);
+  NewBI->setDebugLoc(CI->getDebugLoc());
+
+  BB->getInstList().erase(Ret); // Remove return.
+  BB->getInstList().erase(CI);  // Remove call.
+  ++NumEliminated;
+  return true;
+}
+
+static bool foldReturnAndProcessPred(
+    BasicBlock *BB, ReturnInst *Ret, BasicBlock *&OldEntry,
+    bool &TailCallsAreMarkedTail, SmallVectorImpl<PHINode *> &ArgumentPHIs,
+    bool CannotTailCallElimCallsMarkedTail, const TargetTransformInfo *TTI,
+    AliasAnalysis *AA, OptimizationRemarkEmitter *ORE) {
+  bool Change = false;
+
+  // Make sure this block is a trivial return block.
+  assert(BB->getFirstNonPHIOrDbg() == Ret &&
+         "Trying to fold non-trivial return block");
+
+  // If the return block contains nothing but the return and PHIs,
+  // there might be an opportunity to duplicate the return in its
+  // predecessors and perform TRE there. Look for predecessors that end
+  // in an unconditional branch and recursive call(s).
+  SmallVector<BranchInst*, 8> UncondBranchPreds;
+  for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI) {
+    BasicBlock *Pred = *PI;
+    TerminatorInst *PTI = Pred->getTerminator();
+    if (BranchInst *BI = dyn_cast<BranchInst>(PTI))
+      if (BI->isUnconditional())
+        UncondBranchPreds.push_back(BI);
+  }
+
+  while (!UncondBranchPreds.empty()) {
+    BranchInst *BI = UncondBranchPreds.pop_back_val();
+    BasicBlock *Pred = BI->getParent();
+    if (CallInst *CI =
+            findTRECandidate(BI, CannotTailCallElimCallsMarkedTail, TTI)) {
+      DEBUG(dbgs() << "FOLDING: " << *BB
+                   << "INTO UNCOND BRANCH PRED: " << *Pred);
+      ReturnInst *RI = FoldReturnIntoUncondBranch(Ret, BB, Pred);
+
+      // Cleanup: if all predecessors of BB have been eliminated by
+      // FoldReturnIntoUncondBranch, delete it. It is important to empty it,
+      // because the ret instruction in there is still using a value which
+      // eliminateRecursiveTailCall will attempt to remove.
+      if (!BB->hasAddressTaken() && pred_begin(BB) == pred_end(BB))
+        BB->eraseFromParent();
+
+      eliminateRecursiveTailCall(CI, RI, OldEntry, TailCallsAreMarkedTail,
+                                 ArgumentPHIs, AA, ORE);
+      ++NumRetDuped;
+      Change = true;
+    }
+  }
+
+  return Change;
+}
+
+static bool processReturningBlock(ReturnInst *Ret, BasicBlock *&OldEntry,
+                                  bool &TailCallsAreMarkedTail,
+                                  SmallVectorImpl<PHINode *> &ArgumentPHIs,
+                                  bool CannotTailCallElimCallsMarkedTail,
+                                  const TargetTransformInfo *TTI,
+                                  AliasAnalysis *AA,
+                                  OptimizationRemarkEmitter *ORE) {
+  CallInst *CI = findTRECandidate(Ret, CannotTailCallElimCallsMarkedTail, TTI);
+  if (!CI)
+    return false;
+
+  return eliminateRecursiveTailCall(CI, Ret, OldEntry, TailCallsAreMarkedTail,
+                                    ArgumentPHIs, AA, ORE);
+}
+
+static bool eliminateTailRecursion(Function &F, const TargetTransformInfo *TTI,
+                                   AliasAnalysis *AA,
+                                   OptimizationRemarkEmitter *ORE) {
+  if (F.getFnAttribute("disable-tail-calls").getValueAsString() == "true")
+    return false;
+
+  bool MadeChange = false;
+  bool AllCallsAreTailCalls = false;
+  MadeChange |= markTails(F, AllCallsAreTailCalls, ORE);
+  if (!AllCallsAreTailCalls)
+    return MadeChange;
+
+  // If this function is a varargs function, we won't be able to PHI the args
+  // right, so don't even try to convert it...
+  if (F.getFunctionType()->isVarArg())
+    return false;
+
+  BasicBlock *OldEntry = nullptr;
+  bool TailCallsAreMarkedTail = false;
+  SmallVector<PHINode*, 8> ArgumentPHIs;
+
+  // If false, we cannot perform TRE on tail calls marked with the 'tail'
+  // attribute, because doing so would cause the stack size to increase (real
+  // TRE would deallocate variable-sized allocas, TRE doesn't).
+  bool CanTRETailMarkedCall = canTRE(F);
+
+  // Change any tail recursive calls to loops.
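+  //
+  // Conceptually (illustrative only, not from the original source), this
+  // rewrites
+  //   int sum(int n, int acc) { return n == 0 ? acc : sum(n - 1, acc + n); }
+  // into a branch back to the "tailrecurse" header, whose PHI nodes carry
+  // the updated values of n and acc around the loop.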
+  //
+  // FIXME: The code generator produces really bad code when an 'escaping
+  // alloca' is changed from being a static alloca to being a dynamic alloca.
+  // Until this is resolved, disable this transformation if that would ever
+  // happen. This bug is PR962.
+  for (Function::iterator BBI = F.begin(), E = F.end(); BBI != E; /*in loop*/) {
+    BasicBlock *BB = &*BBI++; // foldReturnAndProcessPred may delete BB.
+    if (ReturnInst *Ret = dyn_cast<ReturnInst>(BB->getTerminator())) {
+      bool Change = processReturningBlock(Ret, OldEntry, TailCallsAreMarkedTail,
+                                          ArgumentPHIs, !CanTRETailMarkedCall,
+                                          TTI, AA, ORE);
+      if (!Change && BB->getFirstNonPHIOrDbg() == Ret)
+        Change = foldReturnAndProcessPred(BB, Ret, OldEntry,
+                                          TailCallsAreMarkedTail, ArgumentPHIs,
+                                          !CanTRETailMarkedCall, TTI, AA, ORE);
+      MadeChange |= Change;
+    }
+  }
+
+  // If we eliminated any tail recursions, it's possible that we inserted some
+  // silly PHI nodes which just merge an initial value (the incoming operand)
+  // with themselves. Check to see if we did, and clean up our mess if so.
+  // This occurs when a function passes an argument straight through to its
+  // tail call.
+  for (PHINode *PN : ArgumentPHIs) {
+    // If the PHI node is a dynamic constant, replace it with that value.
+    if (Value *PNV = SimplifyInstruction(PN, F.getParent()->getDataLayout())) {
+      PN->replaceAllUsesWith(PNV);
+      PN->eraseFromParent();
+    }
+  }
+
+  return MadeChange;
+}
+
+namespace {
+struct TailCallElim : public FunctionPass {
+  static char ID; // Pass identification, replacement for typeid
+  TailCallElim() : FunctionPass(ID) {
+    initializeTailCallElimPass(*PassRegistry::getPassRegistry());
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<TargetTransformInfoWrapperPass>();
+    AU.addRequired<AAResultsWrapperPass>();
+    AU.addRequired<OptimizationRemarkEmitterWrapperPass>();
+    AU.addPreserved<GlobalsAAWrapperPass>();
+  }
+
+  bool runOnFunction(Function &F) override {
+    if (skipFunction(F))
+      return false;
+
+    return eliminateTailRecursion(
+        F, &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F),
+        &getAnalysis<AAResultsWrapperPass>().getAAResults(),
+        &getAnalysis<OptimizationRemarkEmitterWrapperPass>().getORE());
+  }
+};
+}
+
+char TailCallElim::ID = 0;
+INITIALIZE_PASS_BEGIN(TailCallElim, "tailcallelim", "Tail Call Elimination",
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(OptimizationRemarkEmitterWrapperPass)
+INITIALIZE_PASS_END(TailCallElim, "tailcallelim", "Tail Call Elimination",
+                    false, false)
+
+// Public interface to the TailCallElimination pass
+FunctionPass *llvm::createTailCallEliminationPass() {
+  return new TailCallElim();
+}
+
+PreservedAnalyses TailCallElimPass::run(Function &F,
+                                        FunctionAnalysisManager &AM) {
+  TargetTransformInfo &TTI = AM.getResult<TargetIRAnalysis>(F);
+  AliasAnalysis &AA = AM.getResult<AAManager>(F);
+  auto &ORE = AM.getResult<OptimizationRemarkEmitterAnalysis>(F);
+
+  bool Changed = eliminateTailRecursion(F, &TTI, &AA, &ORE);
+
+  if (!Changed)
+    return PreservedAnalyses::all();
+  PreservedAnalyses PA;
+  PA.preserve<GlobalsAA>();
+  return PA;
+}