Diffstat (limited to 'lib/Transforms/Scalar/MergeICmps.cpp')
-rw-r--r-- | lib/Transforms/Scalar/MergeICmps.cpp | 728 |
1 file changed, 423 insertions, 305 deletions
diff --git a/lib/Transforms/Scalar/MergeICmps.cpp b/lib/Transforms/Scalar/MergeICmps.cpp index 69fd8b163a07..3d047a193267 100644 --- a/lib/Transforms/Scalar/MergeICmps.cpp +++ b/lib/Transforms/Scalar/MergeICmps.cpp @@ -1,9 +1,8 @@ //===- MergeICmps.cpp - Optimize chains of integer comparisons ------------===// // -// The LLVM Compiler Infrastructure -// -// This file is distributed under the University of Illinois Open Source -// License. See LICENSE.TXT for details. +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // @@ -11,29 +10,54 @@ // later typically inlined as a chain of efficient hardware comparisons). This // typically benefits c++ member or nonmember operator==(). // -// The basic idea is to replace a larger chain of integer comparisons loaded -// from contiguous memory locations into a smaller chain of such integer +// The basic idea is to replace a longer chain of integer comparisons loaded +// from contiguous memory locations into a shorter chain of larger integer // comparisons. Benefits are double: // - There are less jumps, and therefore less opportunities for mispredictions // and I-cache misses. // - Code size is smaller, both because jumps are removed and because the // encoding of a 2*n byte compare is smaller than that of two n-byte // compares. - +// +// Example: +// +// struct S { +// int a; +// char b; +// char c; +// uint16_t d; +// bool operator==(const S& o) const { +// return a == o.a && b == o.b && c == o.c && d == o.d; +// } +// }; +// +// Is optimized as : +// +// bool S::operator==(const S& o) const { +// return memcmp(this, &o, 8) == 0; +// } +// +// Which will later be expanded (ExpandMemCmp) as a single 8-bytes icmp. +// //===----------------------------------------------------------------------===// -#include <algorithm> -#include <numeric> -#include <utility> -#include <vector> +#include "llvm/Transforms/Scalar/MergeICmps.h" +#include "llvm/Analysis/DomTreeUpdater.h" +#include "llvm/Analysis/GlobalsModRef.h" #include "llvm/Analysis/Loads.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" +#include "llvm/IR/Dominators.h" #include "llvm/IR/Function.h" #include "llvm/IR/IRBuilder.h" #include "llvm/Pass.h" #include "llvm/Transforms/Scalar.h" +#include "llvm/Transforms/Utils/BasicBlockUtils.h" #include "llvm/Transforms/Utils/BuildLibCalls.h" +#include <algorithm> +#include <numeric> +#include <utility> +#include <vector> using namespace llvm; @@ -50,76 +74,109 @@ static bool isSimpleLoadOrStore(const Instruction *I) { return false; } -// A BCE atom. +// A BCE atom "Binary Compare Expression Atom" represents an integer load +// that is a constant offset from a base value, e.g. `a` or `o.c` in the example +// at the top. struct BCEAtom { - BCEAtom() : GEP(nullptr), LoadI(nullptr), Offset() {} - - const Value *Base() const { return GEP ? 
GEP->getPointerOperand() : nullptr; } - + BCEAtom() = default; + BCEAtom(GetElementPtrInst *GEP, LoadInst *LoadI, int BaseId, APInt Offset) + : GEP(GEP), LoadI(LoadI), BaseId(BaseId), Offset(Offset) {} + + BCEAtom(const BCEAtom &) = delete; + BCEAtom &operator=(const BCEAtom &) = delete; + + BCEAtom(BCEAtom &&that) = default; + BCEAtom &operator=(BCEAtom &&that) { + if (this == &that) + return *this; + GEP = that.GEP; + LoadI = that.LoadI; + BaseId = that.BaseId; + Offset = std::move(that.Offset); + return *this; + } + + // We want to order BCEAtoms by (Base, Offset). However we cannot use + // the pointer values for Base because these are non-deterministic. + // To make sure that the sort order is stable, we first assign to each atom + // base value an index based on its order of appearance in the chain of + // comparisons. We call this index `BaseOrdering`. For example, for: + // b[3] == c[2] && a[1] == d[1] && b[4] == c[3] + // | block 1 | | block 2 | | block 3 | + // b gets assigned index 0 and a index 1, because b appears as LHS in block 1, + // which is before block 2. + // We then sort by (BaseOrdering[LHS.Base()], LHS.Offset), which is stable. bool operator<(const BCEAtom &O) const { - assert(Base() && "invalid atom"); - assert(O.Base() && "invalid atom"); - // Just ordering by (Base(), Offset) is sufficient. However because this - // means that the ordering will depend on the addresses of the base - // values, which are not reproducible from run to run. To guarantee - // stability, we use the names of the values if they exist; we sort by: - // (Base.getName(), Base(), Offset). - const int NameCmp = Base()->getName().compare(O.Base()->getName()); - if (NameCmp == 0) { - if (Base() == O.Base()) { - return Offset.slt(O.Offset); - } - return Base() < O.Base(); - } - return NameCmp < 0; + return BaseId != O.BaseId ? BaseId < O.BaseId : Offset.slt(O.Offset); } - GetElementPtrInst *GEP; - LoadInst *LoadI; + GetElementPtrInst *GEP = nullptr; + LoadInst *LoadI = nullptr; + unsigned BaseId = 0; APInt Offset; }; +// A class that assigns increasing ids to values in the order in which they are +// seen. See comment in `BCEAtom::operator<()``. +class BaseIdentifier { +public: + // Returns the id for value `Base`, after assigning one if `Base` has not been + // seen before. + int getBaseId(const Value *Base) { + assert(Base && "invalid base"); + const auto Insertion = BaseToIndex.try_emplace(Base, Order); + if (Insertion.second) + ++Order; + return Insertion.first->second; + } + +private: + unsigned Order = 1; + DenseMap<const Value*, int> BaseToIndex; +}; + // If this value is a load from a constant offset w.r.t. a base address, and // there are no other users of the load or address, returns the base address and // the offset. 
-BCEAtom visitICmpLoadOperand(Value *const Val) { - BCEAtom Result; - if (auto *const LoadI = dyn_cast<LoadInst>(Val)) { - LLVM_DEBUG(dbgs() << "load\n"); - if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) { - LLVM_DEBUG(dbgs() << "used outside of block\n"); - return {}; - } - // Do not optimize atomic loads to non-atomic memcmp - if (!LoadI->isSimple()) { - LLVM_DEBUG(dbgs() << "volatile or atomic\n"); - return {}; - } - Value *const Addr = LoadI->getOperand(0); - if (auto *const GEP = dyn_cast<GetElementPtrInst>(Addr)) { - LLVM_DEBUG(dbgs() << "GEP\n"); - if (GEP->isUsedOutsideOfBlock(LoadI->getParent())) { - LLVM_DEBUG(dbgs() << "used outside of block\n"); - return {}; - } - const auto &DL = GEP->getModule()->getDataLayout(); - if (!isDereferenceablePointer(GEP, DL)) { - LLVM_DEBUG(dbgs() << "not dereferenceable\n"); - // We need to make sure that we can do comparison in any order, so we - // require memory to be unconditionnally dereferencable. - return {}; - } - Result.Offset = APInt(DL.getPointerTypeSizeInBits(GEP->getType()), 0); - if (GEP->accumulateConstantOffset(DL, Result.Offset)) { - Result.GEP = GEP; - Result.LoadI = LoadI; - } - } +BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) { + auto *const LoadI = dyn_cast<LoadInst>(Val); + if (!LoadI) + return {}; + LLVM_DEBUG(dbgs() << "load\n"); + if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) { + LLVM_DEBUG(dbgs() << "used outside of block\n"); + return {}; + } + // Do not optimize atomic loads to non-atomic memcmp + if (!LoadI->isSimple()) { + LLVM_DEBUG(dbgs() << "volatile or atomic\n"); + return {}; } - return Result; + Value *const Addr = LoadI->getOperand(0); + auto *const GEP = dyn_cast<GetElementPtrInst>(Addr); + if (!GEP) + return {}; + LLVM_DEBUG(dbgs() << "GEP\n"); + if (GEP->isUsedOutsideOfBlock(LoadI->getParent())) { + LLVM_DEBUG(dbgs() << "used outside of block\n"); + return {}; + } + const auto &DL = GEP->getModule()->getDataLayout(); + if (!isDereferenceablePointer(GEP, LoadI->getType(), DL)) { + LLVM_DEBUG(dbgs() << "not dereferenceable\n"); + // We need to make sure that we can do comparison in any order, so we + // require memory to be unconditionnally dereferencable. + return {}; + } + APInt Offset = APInt(DL.getPointerTypeSizeInBits(GEP->getType()), 0); + if (!GEP->accumulateConstantOffset(DL, Offset)) + return {}; + return BCEAtom(GEP, LoadI, BaseId.getBaseId(GEP->getPointerOperand()), + Offset); } -// A basic block with a comparison between two BCE atoms. +// A basic block with a comparison between two BCE atoms, e.g. `a == o.a` in the +// example at the top. // The block might do extra work besides the atom comparison, in which case // doesOtherWork() returns true. Under some conditions, the block can be // split into the atom comparison part and the "other work" part @@ -133,13 +190,11 @@ class BCECmpBlock { BCECmpBlock() {} BCECmpBlock(BCEAtom L, BCEAtom R, int SizeBits) - : Lhs_(L), Rhs_(R), SizeBits_(SizeBits) { + : Lhs_(std::move(L)), Rhs_(std::move(R)), SizeBits_(SizeBits) { if (Rhs_ < Lhs_) std::swap(Rhs_, Lhs_); } - bool IsValid() const { - return Lhs_.Base() != nullptr && Rhs_.Base() != nullptr; - } + bool IsValid() const { return Lhs_.BaseId != 0 && Rhs_.BaseId != 0; } // Assert the block is consistent: If valid, it should also have // non-null members besides Lhs_ and Rhs_. @@ -160,19 +215,19 @@ class BCECmpBlock { // Returns true if the non-BCE-cmp instructions can be separated from BCE-cmp // instructions in the block. 
- bool canSplit(AliasAnalysis *AA) const; + bool canSplit(AliasAnalysis &AA) const; // Return true if this all the relevant instructions in the BCE-cmp-block can // be sunk below this instruction. By doing this, we know we can separate the // BCE-cmp-block instructions from the non-BCE-cmp-block instructions in the // block. bool canSinkBCECmpInst(const Instruction *, DenseSet<Instruction *> &, - AliasAnalysis *AA) const; + AliasAnalysis &AA) const; // We can separate the BCE-cmp-block instructions and the non-BCE-cmp-block // instructions. Split the old block and move all non-BCE-cmp-insts into the // new parent block. - void split(BasicBlock *NewParent, AliasAnalysis *AA) const; + void split(BasicBlock *NewParent, AliasAnalysis &AA) const; // The basic block where this comparison happens. BasicBlock *BB = nullptr; @@ -191,7 +246,7 @@ private: bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, DenseSet<Instruction *> &BlockInsts, - AliasAnalysis *AA) const { + AliasAnalysis &AA) const { // If this instruction has side effects and its in middle of the BCE cmp block // instructions, then bail for now. if (Inst->mayHaveSideEffects()) { @@ -201,9 +256,9 @@ bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, // Disallow stores that might alias the BCE operands MemoryLocation LLoc = MemoryLocation::get(Lhs_.LoadI); MemoryLocation RLoc = MemoryLocation::get(Rhs_.LoadI); - if (isModSet(AA->getModRefInfo(Inst, LLoc)) || - isModSet(AA->getModRefInfo(Inst, RLoc))) - return false; + if (isModSet(AA.getModRefInfo(Inst, LLoc)) || + isModSet(AA.getModRefInfo(Inst, RLoc))) + return false; } // Make sure this instruction does not use any of the BCE cmp block // instructions as operand. @@ -214,7 +269,7 @@ bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst, return true; } -void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis *AA) const { +void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis &AA) const { DenseSet<Instruction *> BlockInsts( {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI}); llvm::SmallVector<Instruction *, 4> OtherInsts; @@ -234,7 +289,7 @@ void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis *AA) const { } } -bool BCECmpBlock::canSplit(AliasAnalysis *AA) const { +bool BCECmpBlock::canSplit(AliasAnalysis &AA) const { DenseSet<Instruction *> BlockInsts( {Lhs_.GEP, Rhs_.GEP, Lhs_.LoadI, Rhs_.LoadI, CmpI, BranchI}); for (Instruction &Inst : *BB) { @@ -265,7 +320,8 @@ bool BCECmpBlock::doesOtherWork() const { // Visit the given comparison. If this is a comparison between two valid // BCE atoms, returns the comparison. BCECmpBlock visitICmp(const ICmpInst *const CmpI, - const ICmpInst::Predicate ExpectedPredicate) { + const ICmpInst::Predicate ExpectedPredicate, + BaseIdentifier &BaseId) { // The comparison can only be used once: // - For intermediate blocks, as a branch condition. // - For the final block, as an incoming value for the Phi. @@ -275,25 +331,27 @@ BCECmpBlock visitICmp(const ICmpInst *const CmpI, LLVM_DEBUG(dbgs() << "cmp has several uses\n"); return {}; } - if (CmpI->getPredicate() == ExpectedPredicate) { - LLVM_DEBUG(dbgs() << "cmp " - << (ExpectedPredicate == ICmpInst::ICMP_EQ ? 
"eq" : "ne") - << "\n"); - auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0)); - if (!Lhs.Base()) return {}; - auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1)); - if (!Rhs.Base()) return {}; - const auto &DL = CmpI->getModule()->getDataLayout(); - return BCECmpBlock(std::move(Lhs), std::move(Rhs), - DL.getTypeSizeInBits(CmpI->getOperand(0)->getType())); - } - return {}; + if (CmpI->getPredicate() != ExpectedPredicate) + return {}; + LLVM_DEBUG(dbgs() << "cmp " + << (ExpectedPredicate == ICmpInst::ICMP_EQ ? "eq" : "ne") + << "\n"); + auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0), BaseId); + if (!Lhs.BaseId) + return {}; + auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1), BaseId); + if (!Rhs.BaseId) + return {}; + const auto &DL = CmpI->getModule()->getDataLayout(); + return BCECmpBlock(std::move(Lhs), std::move(Rhs), + DL.getTypeSizeInBits(CmpI->getOperand(0)->getType())); } // Visit the given comparison block. If this is a comparison between two valid // BCE atoms, returns the comparison. BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, - const BasicBlock *const PhiBlock) { + const BasicBlock *const PhiBlock, + BaseIdentifier &BaseId) { if (Block->empty()) return {}; auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator()); if (!BranchI) return {}; @@ -306,7 +364,7 @@ BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, auto *const CmpI = dyn_cast<ICmpInst>(Val); if (!CmpI) return {}; LLVM_DEBUG(dbgs() << "icmp\n"); - auto Result = visitICmp(CmpI, ICmpInst::ICMP_EQ); + auto Result = visitICmp(CmpI, ICmpInst::ICMP_EQ, BaseId); Result.CmpI = CmpI; Result.BranchI = BranchI; return Result; @@ -323,7 +381,8 @@ BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch"); BasicBlock *const FalseBlock = BranchI->getSuccessor(1); auto Result = visitICmp( - CmpI, FalseBlock == PhiBlock ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE); + CmpI, FalseBlock == PhiBlock ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE, + BaseId); Result.CmpI = CmpI; Result.BranchI = BranchI; return Result; @@ -332,47 +391,41 @@ BCECmpBlock visitCmpBlock(Value *const Val, BasicBlock *const Block, } static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons, - BCECmpBlock &Comparison) { + BCECmpBlock &&Comparison) { LLVM_DEBUG(dbgs() << "Block '" << Comparison.BB->getName() << "': Found cmp of " << Comparison.SizeBits() - << " bits between " << Comparison.Lhs().Base() << " + " + << " bits between " << Comparison.Lhs().BaseId << " + " << Comparison.Lhs().Offset << " and " - << Comparison.Rhs().Base() << " + " + << Comparison.Rhs().BaseId << " + " << Comparison.Rhs().Offset << "\n"); LLVM_DEBUG(dbgs() << "\n"); - Comparisons.push_back(Comparison); + Comparisons.push_back(std::move(Comparison)); } // A chain of comparisons. 
class BCECmpChain { public: - BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, - AliasAnalysis *AA); + BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, + AliasAnalysis &AA); - int size() const { return Comparisons_.size(); } + int size() const { return Comparisons_.size(); } #ifdef MERGEICMPS_DOT_ON void dump() const; #endif // MERGEICMPS_DOT_ON - bool simplify(const TargetLibraryInfo *const TLI, AliasAnalysis *AA); + bool simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, + DomTreeUpdater &DTU); - private: +private: static bool IsContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) { - return First.Lhs().Base() == Second.Lhs().Base() && - First.Rhs().Base() == Second.Rhs().Base() && + return First.Lhs().BaseId == Second.Lhs().BaseId && + First.Rhs().BaseId == Second.Rhs().BaseId && First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset && First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset; } - // Merges the given comparison blocks into one memcmp block and update - // branches. Comparisons are assumed to be continguous. If NextBBInChain is - // null, the merged block will link to the phi block. - void mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, - BasicBlock *const NextBBInChain, PHINode &Phi, - const TargetLibraryInfo *const TLI, AliasAnalysis *AA); - PHINode &Phi_; std::vector<BCECmpBlock> Comparisons_; // The original entry block (before sorting); @@ -380,16 +433,17 @@ class BCECmpChain { }; BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, - AliasAnalysis *AA) + AliasAnalysis &AA) : Phi_(Phi) { assert(!Blocks.empty() && "a chain should have at least one block"); // Now look inside blocks to check for BCE comparisons. std::vector<BCECmpBlock> Comparisons; + BaseIdentifier BaseId; for (size_t BlockIdx = 0; BlockIdx < Blocks.size(); ++BlockIdx) { BasicBlock *const Block = Blocks[BlockIdx]; assert(Block && "invalid block"); BCECmpBlock Comparison = visitCmpBlock(Phi.getIncomingValueForBlock(Block), - Block, Phi.getParent()); + Block, Phi.getParent(), BaseId); Comparison.BB = Block; if (!Comparison.IsValid()) { LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n"); @@ -411,13 +465,13 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, // chain before sorting. Unless we can abort the chain at this point // and start anew. // - // NOTE: we only handle block with single predecessor for now. + // NOTE: we only handle blocks a with single predecessor for now. if (Comparison.canSplit(AA)) { LLVM_DEBUG(dbgs() << "Split initial block '" << Comparison.BB->getName() << "' that does extra work besides compare\n"); Comparison.RequireSplit = true; - enqueueBlock(Comparisons, Comparison); + enqueueBlock(Comparisons, std::move(Comparison)); } else { LLVM_DEBUG(dbgs() << "ignoring initial block '" << Comparison.BB->getName() @@ -450,7 +504,7 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, // We could still merge bb1 and bb2 though. return; } - enqueueBlock(Comparisons, Comparison); + enqueueBlock(Comparisons, std::move(Comparison)); } // It is possible we have no suitable comparison to merge. @@ -466,9 +520,11 @@ BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi, #endif // MERGEICMPS_DOT_ON // Reorder blocks by LHS. We can do that without changing the // semantics because we are only accessing dereferencable memory. 
- llvm::sort(Comparisons_, [](const BCECmpBlock &a, const BCECmpBlock &b) { - return a.Lhs() < b.Lhs(); - }); + llvm::sort(Comparisons_, + [](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) { + return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) < + std::tie(RhsBlock.Lhs(), RhsBlock.Rhs()); + }); #ifdef MERGEICMPS_DOT_ON errs() << "AFTER REORDERING:\n\n"; dump(); @@ -498,162 +554,205 @@ void BCECmpChain::dump() const { } #endif // MERGEICMPS_DOT_ON -bool BCECmpChain::simplify(const TargetLibraryInfo *const TLI, - AliasAnalysis *AA) { - // First pass to check if there is at least one merge. If not, we don't do - // anything and we keep analysis passes intact. - { - bool AtLeastOneMerged = false; - for (size_t I = 1; I < Comparisons_.size(); ++I) { - if (IsContiguous(Comparisons_[I - 1], Comparisons_[I])) { - AtLeastOneMerged = true; - break; +namespace { + +// A class to compute the name of a set of merged basic blocks. +// This is optimized for the common case of no block names. +class MergedBlockName { + // Storage for the uncommon case of several named blocks. + SmallString<16> Scratch; + +public: + explicit MergedBlockName(ArrayRef<BCECmpBlock> Comparisons) + : Name(makeName(Comparisons)) {} + const StringRef Name; + +private: + StringRef makeName(ArrayRef<BCECmpBlock> Comparisons) { + assert(!Comparisons.empty() && "no basic block"); + // Fast path: only one block, or no names at all. + if (Comparisons.size() == 1) + return Comparisons[0].BB->getName(); + const int size = std::accumulate(Comparisons.begin(), Comparisons.end(), 0, + [](int i, const BCECmpBlock &Cmp) { + return i + Cmp.BB->getName().size(); + }); + if (size == 0) + return StringRef("", 0); + + // Slow path: at least two blocks, at least one block with a name. + Scratch.clear(); + // We'll have `size` bytes for name and `Comparisons.size() - 1` bytes for + // separators. + Scratch.reserve(size + Comparisons.size() - 1); + const auto append = [this](StringRef str) { + Scratch.append(str.begin(), str.end()); + }; + append(Comparisons[0].BB->getName()); + for (int I = 1, E = Comparisons.size(); I < E; ++I) { + const BasicBlock *const BB = Comparisons[I].BB; + if (!BB->getName().empty()) { + append("+"); + append(BB->getName()); } } - if (!AtLeastOneMerged) return false; + return StringRef(Scratch); } +}; +} // namespace + +// Merges the given contiguous comparison blocks into one memcmp block. +static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, + BasicBlock *const InsertBefore, + BasicBlock *const NextCmpBlock, + PHINode &Phi, const TargetLibraryInfo &TLI, + AliasAnalysis &AA, DomTreeUpdater &DTU) { + assert(!Comparisons.empty() && "merging zero comparisons"); + LLVMContext &Context = NextCmpBlock->getContext(); + const BCECmpBlock &FirstCmp = Comparisons[0]; + + // Create a new cmp block before next cmp block. + BasicBlock *const BB = + BasicBlock::Create(Context, MergedBlockName(Comparisons).Name, + NextCmpBlock->getParent(), InsertBefore); + IRBuilder<> Builder(BB); + // Add the GEPs from the first BCECmpBlock. 
+ Value *const Lhs = Builder.Insert(FirstCmp.Lhs().GEP->clone()); + Value *const Rhs = Builder.Insert(FirstCmp.Rhs().GEP->clone()); + + Value *IsEqual = nullptr; + LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons -> " + << BB->getName() << "\n"); + if (Comparisons.size() == 1) { + LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n"); + Value *const LhsLoad = + Builder.CreateLoad(FirstCmp.Lhs().LoadI->getType(), Lhs); + Value *const RhsLoad = + Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs); + // There are no blocks to merge, just do the comparison. + IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad); + } else { + // If there is one block that requires splitting, we do it now, i.e. + // just before we know we will collapse the chain. The instructions + // can be executed before any of the instructions in the chain. + const auto ToSplit = + std::find_if(Comparisons.begin(), Comparisons.end(), + [](const BCECmpBlock &B) { return B.RequireSplit; }); + if (ToSplit != Comparisons.end()) { + LLVM_DEBUG(dbgs() << "Splitting non_BCE work to header\n"); + ToSplit->split(BB, AA); + } - // Remove phi references to comparison blocks, they will be rebuilt as we - // merge the blocks. - for (const auto &Comparison : Comparisons_) { - Phi_.removeIncomingValue(Comparison.BB, false); - } + const unsigned TotalSizeBits = std::accumulate( + Comparisons.begin(), Comparisons.end(), 0u, + [](int Size, const BCECmpBlock &C) { return Size + C.SizeBits(); }); - // If entry block is part of the chain, we need to make the first block - // of the chain the new entry block of the function. - BasicBlock *Entry = &Comparisons_[0].BB->getParent()->getEntryBlock(); - for (size_t I = 1; I < Comparisons_.size(); ++I) { - if (Entry == Comparisons_[I].BB) { - BasicBlock *NEntryBB = BasicBlock::Create(Entry->getContext(), "", - Entry->getParent(), Entry); - BranchInst::Create(Entry, NEntryBB); - break; - } + // Create memcmp() == 0. + const auto &DL = Phi.getModule()->getDataLayout(); + Value *const MemCmpCall = emitMemCmp( + Lhs, Rhs, + ConstantInt::get(DL.getIntPtrType(Context), TotalSizeBits / 8), Builder, + DL, &TLI); + IsEqual = Builder.CreateICmpEQ( + MemCmpCall, ConstantInt::get(Type::getInt32Ty(Context), 0)); } - // Point the predecessors of the chain to the first comparison block (which is - // the new entry point) and update the entry block of the chain. - if (EntryBlock_ != Comparisons_[0].BB) { - EntryBlock_->replaceAllUsesWith(Comparisons_[0].BB); - EntryBlock_ = Comparisons_[0].BB; + BasicBlock *const PhiBB = Phi.getParent(); + // Add a branch to the next basic block in the chain. + if (NextCmpBlock == PhiBB) { + // Continue to phi, passing it the comparison result. + Builder.CreateBr(PhiBB); + Phi.addIncoming(IsEqual, BB); + DTU.applyUpdates({{DominatorTree::Insert, BB, PhiBB}}); + } else { + // Continue to next block if equal, exit to phi else. + Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB); + Phi.addIncoming(ConstantInt::getFalse(Context), BB); + DTU.applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock}, + {DominatorTree::Insert, BB, PhiBB}}); } + return BB; +} + +bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA, + DomTreeUpdater &DTU) { + assert(Comparisons_.size() >= 2 && "simplifying trivial BCECmpChain"); + // First pass to check if there is at least one merge. If not, we don't do + // anything and we keep analysis passes intact. 
+ const auto AtLeastOneMerged = [this]() { + for (size_t I = 1; I < Comparisons_.size(); ++I) { + if (IsContiguous(Comparisons_[I - 1], Comparisons_[I])) + return true; + } + return false; + }; + if (!AtLeastOneMerged()) + return false; - // Effectively merge blocks. + LLVM_DEBUG(dbgs() << "Simplifying comparison chain starting at block " + << EntryBlock_->getName() << "\n"); + + // Effectively merge blocks. We go in the reverse direction from the phi block + // so that the next block is always available to branch to. + const auto mergeRange = [this, &TLI, &AA, &DTU](int I, int Num, + BasicBlock *InsertBefore, + BasicBlock *Next) { + return mergeComparisons(makeArrayRef(Comparisons_).slice(I, Num), + InsertBefore, Next, Phi_, TLI, AA, DTU); + }; int NumMerged = 1; - for (size_t I = 1; I < Comparisons_.size(); ++I) { - if (IsContiguous(Comparisons_[I - 1], Comparisons_[I])) { + BasicBlock *NextCmpBlock = Phi_.getParent(); + for (int I = static_cast<int>(Comparisons_.size()) - 2; I >= 0; --I) { + if (IsContiguous(Comparisons_[I], Comparisons_[I + 1])) { + LLVM_DEBUG(dbgs() << "Merging block " << Comparisons_[I].BB->getName() + << " into " << Comparisons_[I + 1].BB->getName() + << "\n"); ++NumMerged; } else { - // Merge all previous comparisons and start a new merge block. - mergeComparisons( - makeArrayRef(Comparisons_).slice(I - NumMerged, NumMerged), - Comparisons_[I].BB, Phi_, TLI, AA); + NextCmpBlock = mergeRange(I + 1, NumMerged, NextCmpBlock, NextCmpBlock); NumMerged = 1; } } - mergeComparisons(makeArrayRef(Comparisons_) - .slice(Comparisons_.size() - NumMerged, NumMerged), - nullptr, Phi_, TLI, AA); - - return true; -} - -void BCECmpChain::mergeComparisons(ArrayRef<BCECmpBlock> Comparisons, - BasicBlock *const NextBBInChain, - PHINode &Phi, - const TargetLibraryInfo *const TLI, - AliasAnalysis *AA) { - assert(!Comparisons.empty()); - const auto &FirstComparison = *Comparisons.begin(); - BasicBlock *const BB = FirstComparison.BB; - LLVMContext &Context = BB->getContext(); - - if (Comparisons.size() >= 2) { - // If there is one block that requires splitting, we do it now, i.e. - // just before we know we will collapse the chain. The instructions - // can be executed before any of the instructions in the chain. - auto C = std::find_if(Comparisons.begin(), Comparisons.end(), - [](const BCECmpBlock &B) { return B.RequireSplit; }); - if (C != Comparisons.end()) - C->split(EntryBlock_, AA); - - LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons\n"); - const auto TotalSize = - std::accumulate(Comparisons.begin(), Comparisons.end(), 0, - [](int Size, const BCECmpBlock &C) { - return Size + C.SizeBits(); - }) / - 8; - - // Incoming edges do not need to be updated, and both GEPs are already - // computing the right address, we just need to: - // - replace the two loads and the icmp with the memcmp - // - update the branch - // - update the incoming values in the phi. 
- FirstComparison.BranchI->eraseFromParent(); - FirstComparison.CmpI->eraseFromParent(); - FirstComparison.Lhs().LoadI->eraseFromParent(); - FirstComparison.Rhs().LoadI->eraseFromParent(); - - IRBuilder<> Builder(BB); - const auto &DL = Phi.getModule()->getDataLayout(); - Value *const MemCmpCall = emitMemCmp( - FirstComparison.Lhs().GEP, FirstComparison.Rhs().GEP, - ConstantInt::get(DL.getIntPtrType(Context), TotalSize), - Builder, DL, TLI); - Value *const MemCmpIsZero = Builder.CreateICmpEQ( - MemCmpCall, ConstantInt::get(Type::getInt32Ty(Context), 0)); + // Insert the entry block for the new chain before the old entry block. + // If the old entry block was the function entry, this ensures that the new + // entry can become the function entry. + NextCmpBlock = mergeRange(0, NumMerged, EntryBlock_, NextCmpBlock); + + // Replace the original cmp chain with the new cmp chain by pointing all + // predecessors of EntryBlock_ to NextCmpBlock instead. This makes all cmp + // blocks in the old chain unreachable. + while (!pred_empty(EntryBlock_)) { + BasicBlock* const Pred = *pred_begin(EntryBlock_); + LLVM_DEBUG(dbgs() << "Updating jump into old chain from " << Pred->getName() + << "\n"); + Pred->getTerminator()->replaceUsesOfWith(EntryBlock_, NextCmpBlock); + DTU.applyUpdates({{DominatorTree::Delete, Pred, EntryBlock_}, + {DominatorTree::Insert, Pred, NextCmpBlock}}); + } - // Add a branch to the next basic block in the chain. - if (NextBBInChain) { - Builder.CreateCondBr(MemCmpIsZero, NextBBInChain, Phi.getParent()); - Phi.addIncoming(ConstantInt::getFalse(Context), BB); - } else { - Builder.CreateBr(Phi.getParent()); - Phi.addIncoming(MemCmpIsZero, BB); - } + // If the old cmp chain was the function entry, we need to update the function + // entry. + const bool ChainEntryIsFnEntry = + (EntryBlock_ == &EntryBlock_->getParent()->getEntryBlock()); + if (ChainEntryIsFnEntry && DTU.hasDomTree()) { + LLVM_DEBUG(dbgs() << "Changing function entry from " + << EntryBlock_->getName() << " to " + << NextCmpBlock->getName() << "\n"); + DTU.getDomTree().setNewRoot(NextCmpBlock); + DTU.applyUpdates({{DominatorTree::Delete, NextCmpBlock, EntryBlock_}}); + } + EntryBlock_ = nullptr; - // Delete merged blocks. - for (size_t I = 1; I < Comparisons.size(); ++I) { - BasicBlock *CBB = Comparisons[I].BB; - CBB->replaceAllUsesWith(BB); - CBB->eraseFromParent(); - } - } else { - assert(Comparisons.size() == 1); - // There are no blocks to merge, but we still need to update the branches. - LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n"); - if (NextBBInChain) { - if (FirstComparison.BranchI->isConditional()) { - LLVM_DEBUG(dbgs() << "conditional -> conditional\n"); - // Just update the "true" target, the "false" target should already be - // the phi block. - assert(FirstComparison.BranchI->getSuccessor(1) == Phi.getParent()); - FirstComparison.BranchI->setSuccessor(0, NextBBInChain); - Phi.addIncoming(ConstantInt::getFalse(Context), BB); - } else { - LLVM_DEBUG(dbgs() << "unconditional -> conditional\n"); - // Replace the unconditional branch by a conditional one. - FirstComparison.BranchI->eraseFromParent(); - IRBuilder<> Builder(BB); - Builder.CreateCondBr(FirstComparison.CmpI, NextBBInChain, - Phi.getParent()); - Phi.addIncoming(FirstComparison.CmpI, BB); - } - } else { - if (FirstComparison.BranchI->isConditional()) { - LLVM_DEBUG(dbgs() << "conditional -> unconditional\n"); - // Replace the conditional branch by an unconditional one. 
- FirstComparison.BranchI->eraseFromParent(); - IRBuilder<> Builder(BB); - Builder.CreateBr(Phi.getParent()); - Phi.addIncoming(FirstComparison.CmpI, BB); - } else { - LLVM_DEBUG(dbgs() << "unconditional -> unconditional\n"); - Phi.addIncoming(FirstComparison.CmpI, BB); - } - } + // Delete merged blocks. This also removes incoming values in phi. + SmallVector<BasicBlock *, 16> DeadBlocks; + for (auto &Cmp : Comparisons_) { + LLVM_DEBUG(dbgs() << "Deleting merged block " << Cmp.BB->getName() << "\n"); + DeadBlocks.push_back(Cmp.BB); } + DeleteDeadBlocks(DeadBlocks, &DTU); + + Comparisons_.clear(); + return true; } std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, @@ -691,8 +790,8 @@ std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi, return Blocks; } -bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI, - AliasAnalysis *AA) { +bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, AliasAnalysis &AA, + DomTreeUpdater &DTU) { LLVM_DEBUG(dbgs() << "processPhi()\n"); if (Phi.getNumIncomingValues() <= 1) { LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n"); @@ -757,24 +856,54 @@ bool processPhi(PHINode &Phi, const TargetLibraryInfo *const TLI, return false; } - return CmpChain.simplify(TLI, AA); + return CmpChain.simplify(TLI, AA, DTU); } -class MergeICmps : public FunctionPass { - public: +static bool runImpl(Function &F, const TargetLibraryInfo &TLI, + const TargetTransformInfo &TTI, AliasAnalysis &AA, + DominatorTree *DT) { + LLVM_DEBUG(dbgs() << "MergeICmpsLegacyPass: " << F.getName() << "\n"); + + // We only try merging comparisons if the target wants to expand memcmp later. + // The rationale is to avoid turning small chains into memcmp calls. + if (!TTI.enableMemCmpExpansion(F.hasOptSize(), true)) + return false; + + // If we don't have memcmp avaiable we can't emit calls to it. + if (!TLI.has(LibFunc_memcmp)) + return false; + + DomTreeUpdater DTU(DT, /*PostDominatorTree*/ nullptr, + DomTreeUpdater::UpdateStrategy::Eager); + + bool MadeChange = false; + + for (auto BBIt = ++F.begin(); BBIt != F.end(); ++BBIt) { + // A Phi operation is always first in a basic block. + if (auto *const Phi = dyn_cast<PHINode>(&*BBIt->begin())) + MadeChange |= processPhi(*Phi, TLI, AA, DTU); + } + + return MadeChange; +} + +class MergeICmpsLegacyPass : public FunctionPass { +public: static char ID; - MergeICmps() : FunctionPass(ID) { - initializeMergeICmpsPass(*PassRegistry::getPassRegistry()); + MergeICmpsLegacyPass() : FunctionPass(ID) { + initializeMergeICmpsLegacyPassPass(*PassRegistry::getPassRegistry()); } bool runOnFunction(Function &F) override { if (skipFunction(F)) return false; const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); - AliasAnalysis *AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); - auto PA = runImpl(F, &TLI, &TTI, AA); - return !PA.areAllPreserved(); + // MergeICmps does not need the DominatorTree, but we update it if it's + // already available. + auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>(); + auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults(); + return runImpl(F, TLI, TTI, AA, DTWP ? 
&DTWP->getDomTree() : nullptr); } private: @@ -782,46 +911,35 @@ class MergeICmps : public FunctionPass { AU.addRequired<TargetLibraryInfoWrapperPass>(); AU.addRequired<TargetTransformInfoWrapperPass>(); AU.addRequired<AAResultsWrapperPass>(); + AU.addPreserved<GlobalsAAWrapperPass>(); + AU.addPreserved<DominatorTreeWrapperPass>(); } - - PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, AliasAnalysis *AA); }; -PreservedAnalyses MergeICmps::runImpl(Function &F, const TargetLibraryInfo *TLI, - const TargetTransformInfo *TTI, - AliasAnalysis *AA) { - LLVM_DEBUG(dbgs() << "MergeICmpsPass: " << F.getName() << "\n"); - - // We only try merging comparisons if the target wants to expand memcmp later. - // The rationale is to avoid turning small chains into memcmp calls. - if (!TTI->enableMemCmpExpansion(true)) return PreservedAnalyses::all(); - - // If we don't have memcmp avaiable we can't emit calls to it. - if (!TLI->has(LibFunc_memcmp)) - return PreservedAnalyses::all(); - - bool MadeChange = false; - - for (auto BBIt = ++F.begin(); BBIt != F.end(); ++BBIt) { - // A Phi operation is always first in a basic block. - if (auto *const Phi = dyn_cast<PHINode>(&*BBIt->begin())) - MadeChange |= processPhi(*Phi, TLI, AA); - } - - if (MadeChange) return PreservedAnalyses::none(); - return PreservedAnalyses::all(); -} +} // namespace -} // namespace - -char MergeICmps::ID = 0; -INITIALIZE_PASS_BEGIN(MergeICmps, "mergeicmps", +char MergeICmpsLegacyPass::ID = 0; +INITIALIZE_PASS_BEGIN(MergeICmpsLegacyPass, "mergeicmps", "Merge contiguous icmps into a memcmp", false, false) INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass) -INITIALIZE_PASS_END(MergeICmps, "mergeicmps", +INITIALIZE_PASS_END(MergeICmpsLegacyPass, "mergeicmps", "Merge contiguous icmps into a memcmp", false, false) -Pass *llvm::createMergeICmpsPass() { return new MergeICmps(); } +Pass *llvm::createMergeICmpsLegacyPass() { return new MergeICmpsLegacyPass(); } + +PreservedAnalyses MergeICmpsPass::run(Function &F, + FunctionAnalysisManager &AM) { + auto &TLI = AM.getResult<TargetLibraryAnalysis>(F); + auto &TTI = AM.getResult<TargetIRAnalysis>(F); + auto &AA = AM.getResult<AAManager>(F); + auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F); + const bool MadeChanges = runImpl(F, TLI, TTI, AA, DT); + if (!MadeChanges) + return PreservedAnalyses::all(); + PreservedAnalyses PA; + PA.preserve<GlobalsAA>(); + PA.preserve<DominatorTreeAnalysis>(); + return PA; +} |
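
The new BCEAtom ordering deliberately avoids comparing raw pointer values: each base gets an id in order of first appearance (BaseIdentifier), and atoms sort by (BaseId, Offset). A minimal standalone sketch of that idea, not part of the patch, using std::map instead of llvm::DenseMap and hypothetical names (Atom, BaseIds) purely for illustration:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <tuple>
#include <vector>

// Hypothetical stand-ins for BCEAtom/BaseIdentifier; not the LLVM types.
struct Atom {
  unsigned BaseId;  // id of the base pointer, assigned on first appearance
  int64_t Offset;   // constant byte offset from that base
  bool operator<(const Atom &O) const {
    // Same ordering rule as BCEAtom::operator<: (BaseId, Offset).
    return std::tie(BaseId, Offset) < std::tie(O.BaseId, O.Offset);
  }
};

class BaseIds {
  unsigned Order = 1;  // 0 is reserved to mean "invalid atom", as in the patch
  std::map<const void *, unsigned> Ids;

public:
  unsigned get(const void *Base) {
    auto It = Ids.emplace(Base, Order);
    if (It.second)
      ++Order;
    return It.first->second;
  }
};

int main() {
  int a = 0, b = 0;  // stand-ins for the base values of the loads
  BaseIds Ids;
  // LHS atoms of: b[3] == c[2] && a[1] == d[1] && b[4] == c[3]
  // (the example from the BCEAtom comment above).
  std::vector<Atom> Lhs = {{Ids.get(&b), 3}, {Ids.get(&a), 1}, {Ids.get(&b), 4}};
  std::sort(Lhs.begin(), Lhs.end());
  for (const Atom &A : Lhs)
    std::cout << "base#" << A.BaseId << " + " << A.Offset << "\n";
  // Prints b's atoms (id 1) before a's (id 2) regardless of where the objects
  // happen to live in memory, so the sort order is deterministic from run to run.
}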
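
IsContiguous() is what decides whether two adjacent comparisons in the sorted chain can be folded into one memcmp: both sides must use the same bases, and the second comparison must start exactly where the first one ends. A standalone sketch with illustrative names (Cmp, isContiguous), replayed on the struct S layout from the file header, assuming a typical ABI with a 4-byte int and no padding (a at byte 0/32 bits, b at 4/8, c at 5/8, d at 6/16):

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative stand-in for one BCECmpBlock: both sides load SizeBits bits,
// LhsOffset/RhsOffset bytes past their respective bases.
struct Cmp {
  unsigned LhsBase, RhsBase;
  int64_t LhsOffset, RhsOffset;
  int SizeBits;
};

// Same condition as BCECmpChain::IsContiguous in the patch.
bool isContiguous(const Cmp &First, const Cmp &Second) {
  return First.LhsBase == Second.LhsBase && First.RhsBase == Second.RhsBase &&
         First.LhsOffset + First.SizeBits / 8 == Second.LhsOffset &&
         First.RhsOffset + First.SizeBits / 8 == Second.RhsOffset;
}

int main() {
  // struct S { int a; char b; char c; uint16_t d; };  (8 bytes total)
  // Comparing *this (base 0) against o (base 1), field by field:
  std::vector<Cmp> Chain = {
      {0, 1, 0, 0, 32},  // a == o.a
      {0, 1, 4, 4, 8},   // b == o.b
      {0, 1, 5, 5, 8},   // c == o.c
      {0, 1, 6, 6, 16},  // d == o.d
  };
  int TotalBits = Chain[0].SizeBits;
  for (std::size_t I = 1; I < Chain.size(); ++I) {
    assert(isContiguous(Chain[I - 1], Chain[I]));
    TotalBits += Chain[I].SizeBits;
  }
  // All four comparisons are contiguous, so the chain collapses into a single
  // memcmp(this, &o, TotalBits / 8) == 0, i.e. the 8-byte compare from the header.
  std::cout << "memcmp size: " << TotalBits / 8 << " bytes\n";
}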
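
simplify() now walks the sorted chain backwards from the phi block, counting how long the current contiguous run is (NumMerged) and emitting one merged block per run, so that the block to branch to next always exists by the time it is needed. A small standalone sketch of just that bookkeeping, with hypothetical names and with "emit a merged block" reduced to printing the run:

#include <iostream>
#include <string>
#include <vector>

// Minimal stand-in: each comparison only records whether it is contiguous
// with the next one (a precomputed IsContiguous(this, next)).
struct Cmp {
  std::string Name;
  bool ContiguousWithNext;
};

int main() {
  // Say the sorted chain is A B C | D | E F, where | marks breaks in contiguity.
  std::vector<Cmp> Chain = {{"A", true},  {"B", true}, {"C", false},
                            {"D", false}, {"E", true}, {"F", false}};

  auto EmitMerged = [](const std::vector<Cmp> &Chain, int First, int Num) {
    std::cout << "merged block:";
    for (int I = First; I < First + Num; ++I)
      std::cout << " " << Chain[I].Name;
    std::cout << "\n";
  };

  // Walk from the back, like BCECmpChain::simplify(): the most recently
  // emitted block is the one the previous run branches to on equality.
  int NumMerged = 1;
  for (int I = static_cast<int>(Chain.size()) - 2; I >= 0; --I) {
    if (Chain[I].ContiguousWithNext) {
      ++NumMerged;
    } else {
      EmitMerged(Chain, I + 1, NumMerged);
      NumMerged = 1;
    }
  }
  EmitMerged(Chain, 0, NumMerged);  // the run containing the entry block
  // Output, back to front: "E F", then "D", then "A B C".
}

In the real pass each emitted run becomes a new memcmp block inserted before the previously emitted one, the first run emitted branches straight to the phi block, and the last run emitted becomes the new entry of the chain.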
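
The patch also adds a new-pass-manager entry point, MergeICmpsPass::run. A sketch of running just this pass over an existing function with the usual PassBuilder boilerplate; obtaining the Function and linking against LLVM are assumed, and this snippet is not part of the patch itself:

#include "llvm/Passes/PassBuilder.h"
#include "llvm/Transforms/Scalar/MergeICmps.h"

using namespace llvm;

// Runs MergeICmps on F, assuming F belongs to a valid Module.
static void runMergeICmpsOn(Function &F) {
  LoopAnalysisManager LAM;
  FunctionAnalysisManager FAM;
  CGSCCAnalysisManager CGAM;
  ModuleAnalysisManager MAM;

  PassBuilder PB;
  PB.registerModuleAnalyses(MAM);
  PB.registerCGSCCAnalyses(CGAM);
  PB.registerFunctionAnalyses(FAM);
  PB.registerLoopAnalyses(LAM);
  PB.crossRegisterProxies(LAM, FAM, CGAM, MAM);

  FunctionPassManager FPM;
  FPM.addPass(MergeICmpsPass());
  // The pass pulls TLI, TTI and AA from the analysis manager; per the patch,
  // the DominatorTree is only updated if it happens to be cached already.
  FPM.run(F, FAM);
}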